Update 2026-05-13 16:43:53
This commit is contained in:
@@ -0,0 +1,619 @@
|
||||
"""Test session management with cache-friendly message handling."""
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
from nanobot.session.manager import Session, SessionManager
|
||||
|
||||
# Test constants
|
||||
MEMORY_WINDOW = 50
|
||||
KEEP_COUNT = MEMORY_WINDOW // 2 # 25
|
||||
|
||||
|
||||
def create_session_with_messages(key: str, count: int, role: str = "user") -> Session:
|
||||
"""Create a session and add the specified number of messages.
|
||||
|
||||
Args:
|
||||
key: Session identifier
|
||||
count: Number of messages to add
|
||||
role: Message role (default: "user")
|
||||
|
||||
Returns:
|
||||
Session with the specified messages
|
||||
"""
|
||||
session = Session(key=key)
|
||||
for i in range(count):
|
||||
session.add_message(role, f"msg{i}")
|
||||
return session
|
||||
|
||||
|
||||
def assert_messages_content(messages: list, start_index: int, end_index: int) -> None:
|
||||
"""Assert that messages contain expected content from start to end index.
|
||||
|
||||
Args:
|
||||
messages: List of message dictionaries
|
||||
start_index: Expected first message index
|
||||
end_index: Expected last message index
|
||||
"""
|
||||
assert len(messages) > 0
|
||||
assert messages[0]["content"] == f"msg{start_index}"
|
||||
assert messages[-1]["content"] == f"msg{end_index}"
|
||||
|
||||
|
||||
def get_old_messages(session: Session, last_consolidated: int, keep_count: int) -> list:
|
||||
"""Extract messages that would be consolidated using the standard slice logic.
|
||||
|
||||
Args:
|
||||
session: The session containing messages
|
||||
last_consolidated: Index of last consolidated message
|
||||
keep_count: Number of recent messages to keep
|
||||
|
||||
Returns:
|
||||
List of messages that would be consolidated
|
||||
"""
|
||||
return session.messages[last_consolidated:-keep_count]
|
||||
|
||||
|
||||
class TestSessionLastConsolidated:
|
||||
"""Test last_consolidated tracking to avoid duplicate processing."""
|
||||
|
||||
def test_initial_last_consolidated_zero(self) -> None:
|
||||
"""Test that new session starts with last_consolidated=0."""
|
||||
session = Session(key="test:initial")
|
||||
assert session.last_consolidated == 0
|
||||
|
||||
def test_last_consolidated_persistence(self, tmp_path) -> None:
|
||||
"""Test that last_consolidated persists across save/load."""
|
||||
manager = SessionManager(Path(tmp_path))
|
||||
session1 = create_session_with_messages("test:persist", 20)
|
||||
session1.last_consolidated = 15
|
||||
manager.save(session1)
|
||||
|
||||
session2 = manager.get_or_create("test:persist")
|
||||
assert session2.last_consolidated == 15
|
||||
assert len(session2.messages) == 20
|
||||
|
||||
def test_clear_resets_last_consolidated(self) -> None:
|
||||
"""Test that clear() resets last_consolidated to 0."""
|
||||
session = create_session_with_messages("test:clear", 10)
|
||||
session.last_consolidated = 5
|
||||
|
||||
session.clear()
|
||||
assert len(session.messages) == 0
|
||||
assert session.last_consolidated == 0
|
||||
|
||||
|
||||
class TestSessionImmutableHistory:
|
||||
"""Test Session message immutability for cache efficiency."""
|
||||
|
||||
def test_initial_state(self) -> None:
|
||||
"""Test that new session has empty messages list."""
|
||||
session = Session(key="test:initial")
|
||||
assert len(session.messages) == 0
|
||||
|
||||
def test_add_messages_appends_only(self) -> None:
|
||||
"""Test that adding messages only appends, never modifies."""
|
||||
session = Session(key="test:preserve")
|
||||
session.add_message("user", "msg1")
|
||||
session.add_message("assistant", "resp1")
|
||||
session.add_message("user", "msg2")
|
||||
assert len(session.messages) == 3
|
||||
assert session.messages[0]["content"] == "msg1"
|
||||
|
||||
def test_get_history_returns_most_recent(self) -> None:
|
||||
"""Test get_history returns the most recent messages."""
|
||||
session = Session(key="test:history")
|
||||
for i in range(10):
|
||||
session.add_message("user", f"msg{i}")
|
||||
session.add_message("assistant", f"resp{i}")
|
||||
|
||||
history = session.get_history(max_messages=6)
|
||||
assert len(history) == 6
|
||||
assert history[0]["content"] == "msg7"
|
||||
assert history[-1]["content"] == "resp9"
|
||||
|
||||
def test_get_history_with_all_messages(self) -> None:
|
||||
"""Test get_history with max_messages larger than actual."""
|
||||
session = create_session_with_messages("test:all", 5)
|
||||
history = session.get_history(max_messages=100)
|
||||
assert len(history) == 5
|
||||
assert history[0]["content"] == "msg0"
|
||||
|
||||
def test_get_history_stable_for_same_session(self) -> None:
|
||||
"""Test that get_history returns same content for same max_messages."""
|
||||
session = create_session_with_messages("test:stable", 20)
|
||||
history1 = session.get_history(max_messages=10)
|
||||
history2 = session.get_history(max_messages=10)
|
||||
assert history1 == history2
|
||||
|
||||
def test_messages_list_never_modified(self) -> None:
|
||||
"""Test that messages list is never modified after creation."""
|
||||
session = create_session_with_messages("test:immutable", 5)
|
||||
original_len = len(session.messages)
|
||||
|
||||
session.get_history(max_messages=2)
|
||||
assert len(session.messages) == original_len
|
||||
|
||||
for _ in range(10):
|
||||
session.get_history(max_messages=3)
|
||||
assert len(session.messages) == original_len
|
||||
|
||||
|
||||
class TestSessionPersistence:
|
||||
"""Test Session persistence and reload."""
|
||||
|
||||
@pytest.fixture
|
||||
def temp_manager(self, tmp_path):
|
||||
return SessionManager(Path(tmp_path))
|
||||
|
||||
def test_persistence_roundtrip(self, temp_manager):
|
||||
"""Test that messages persist across save/load."""
|
||||
session1 = create_session_with_messages("test:persistence", 20)
|
||||
temp_manager.save(session1)
|
||||
|
||||
session2 = temp_manager.get_or_create("test:persistence")
|
||||
assert len(session2.messages) == 20
|
||||
assert session2.messages[0]["content"] == "msg0"
|
||||
assert session2.messages[-1]["content"] == "msg19"
|
||||
|
||||
def test_get_history_after_reload(self, temp_manager):
|
||||
"""Test that get_history works correctly after reload."""
|
||||
session1 = create_session_with_messages("test:reload", 30)
|
||||
temp_manager.save(session1)
|
||||
|
||||
session2 = temp_manager.get_or_create("test:reload")
|
||||
history = session2.get_history(max_messages=10)
|
||||
assert len(history) == 10
|
||||
assert history[0]["content"] == "msg20"
|
||||
assert history[-1]["content"] == "msg29"
|
||||
|
||||
def test_clear_resets_session(self, temp_manager):
|
||||
"""Test that clear() properly resets session."""
|
||||
session = create_session_with_messages("test:clear", 10)
|
||||
assert len(session.messages) == 10
|
||||
|
||||
session.clear()
|
||||
assert len(session.messages) == 0
|
||||
|
||||
|
||||
class TestConsolidationTriggerConditions:
|
||||
"""Test consolidation trigger conditions and logic."""
|
||||
|
||||
def test_consolidation_needed_when_messages_exceed_window(self):
|
||||
"""Test consolidation logic: should trigger when messages exceed the window."""
|
||||
session = create_session_with_messages("test:trigger", 60)
|
||||
|
||||
total_messages = len(session.messages)
|
||||
messages_to_process = total_messages - session.last_consolidated
|
||||
|
||||
assert total_messages > MEMORY_WINDOW
|
||||
assert messages_to_process > 0
|
||||
|
||||
expected_consolidate_count = total_messages - KEEP_COUNT
|
||||
assert expected_consolidate_count == 35
|
||||
|
||||
def test_consolidation_skipped_when_within_keep_count(self):
|
||||
"""Test consolidation skipped when total messages <= keep_count."""
|
||||
session = create_session_with_messages("test:skip", 20)
|
||||
|
||||
total_messages = len(session.messages)
|
||||
assert total_messages <= KEEP_COUNT
|
||||
|
||||
old_messages = get_old_messages(session, session.last_consolidated, KEEP_COUNT)
|
||||
assert len(old_messages) == 0
|
||||
|
||||
def test_consolidation_skipped_when_no_new_messages(self):
|
||||
"""Test consolidation skipped when messages_to_process <= 0."""
|
||||
session = create_session_with_messages("test:already_consolidated", 40)
|
||||
session.last_consolidated = len(session.messages) - KEEP_COUNT # 15
|
||||
|
||||
# Add a few more messages
|
||||
for i in range(40, 42):
|
||||
session.add_message("user", f"msg{i}")
|
||||
|
||||
total_messages = len(session.messages)
|
||||
messages_to_process = total_messages - session.last_consolidated
|
||||
assert messages_to_process > 0
|
||||
|
||||
# Simulate last_consolidated catching up
|
||||
session.last_consolidated = total_messages - KEEP_COUNT
|
||||
old_messages = get_old_messages(session, session.last_consolidated, KEEP_COUNT)
|
||||
assert len(old_messages) == 0
|
||||
|
||||
|
||||
class TestLastConsolidatedEdgeCases:
|
||||
"""Test last_consolidated edge cases and data corruption scenarios."""
|
||||
|
||||
def test_last_consolidated_exceeds_message_count(self):
|
||||
"""Test behavior when last_consolidated > len(messages) (data corruption)."""
|
||||
session = create_session_with_messages("test:corruption", 10)
|
||||
session.last_consolidated = 20
|
||||
|
||||
total_messages = len(session.messages)
|
||||
messages_to_process = total_messages - session.last_consolidated
|
||||
assert messages_to_process <= 0
|
||||
|
||||
old_messages = get_old_messages(session, session.last_consolidated, 5)
|
||||
assert len(old_messages) == 0
|
||||
|
||||
def test_last_consolidated_negative_value(self):
|
||||
"""Test behavior with negative last_consolidated (invalid state)."""
|
||||
session = create_session_with_messages("test:negative", 10)
|
||||
session.last_consolidated = -5
|
||||
|
||||
keep_count = 3
|
||||
old_messages = get_old_messages(session, session.last_consolidated, keep_count)
|
||||
|
||||
# messages[-5:-3] with 10 messages gives indices 5,6
|
||||
assert len(old_messages) == 2
|
||||
assert old_messages[0]["content"] == "msg5"
|
||||
assert old_messages[-1]["content"] == "msg6"
|
||||
|
||||
def test_messages_added_after_consolidation(self):
|
||||
"""Test correct behavior when new messages arrive after consolidation."""
|
||||
session = create_session_with_messages("test:new_messages", 40)
|
||||
session.last_consolidated = len(session.messages) - KEEP_COUNT # 15
|
||||
|
||||
# Add new messages after consolidation
|
||||
for i in range(40, 50):
|
||||
session.add_message("user", f"msg{i}")
|
||||
|
||||
total_messages = len(session.messages)
|
||||
old_messages = get_old_messages(session, session.last_consolidated, KEEP_COUNT)
|
||||
expected_consolidate_count = total_messages - KEEP_COUNT - session.last_consolidated
|
||||
|
||||
assert len(old_messages) == expected_consolidate_count
|
||||
assert_messages_content(old_messages, 15, 24)
|
||||
|
||||
def test_slice_behavior_when_indices_overlap(self):
|
||||
"""Test slice behavior when last_consolidated >= total - keep_count."""
|
||||
session = create_session_with_messages("test:overlap", 30)
|
||||
session.last_consolidated = 12
|
||||
|
||||
old_messages = get_old_messages(session, session.last_consolidated, 20)
|
||||
assert len(old_messages) == 0
|
||||
|
||||
|
||||
class TestArchiveAllMode:
|
||||
"""Test archive_all mode (used by /new command)."""
|
||||
|
||||
def test_archive_all_consolidates_everything(self):
|
||||
"""Test archive_all=True consolidates all messages."""
|
||||
session = create_session_with_messages("test:archive_all", 50)
|
||||
|
||||
archive_all = True
|
||||
if archive_all:
|
||||
old_messages = session.messages
|
||||
assert len(old_messages) == 50
|
||||
|
||||
assert session.last_consolidated == 0
|
||||
|
||||
def test_archive_all_resets_last_consolidated(self):
|
||||
"""Test that archive_all mode resets last_consolidated to 0."""
|
||||
session = create_session_with_messages("test:reset", 40)
|
||||
session.last_consolidated = 15
|
||||
|
||||
archive_all = True
|
||||
if archive_all:
|
||||
session.last_consolidated = 0
|
||||
|
||||
assert session.last_consolidated == 0
|
||||
assert len(session.messages) == 40
|
||||
|
||||
def test_archive_all_vs_normal_consolidation(self):
|
||||
"""Test difference between archive_all and normal consolidation."""
|
||||
# Normal consolidation
|
||||
session1 = create_session_with_messages("test:normal", 60)
|
||||
session1.last_consolidated = len(session1.messages) - KEEP_COUNT
|
||||
|
||||
# archive_all mode
|
||||
session2 = create_session_with_messages("test:all", 60)
|
||||
session2.last_consolidated = 0
|
||||
|
||||
assert session1.last_consolidated == 35
|
||||
assert len(session1.messages) == 60
|
||||
assert session2.last_consolidated == 0
|
||||
assert len(session2.messages) == 60
|
||||
|
||||
|
||||
class TestCacheImmutability:
|
||||
"""Test that consolidation doesn't modify session.messages (cache safety)."""
|
||||
|
||||
def test_consolidation_does_not_modify_messages_list(self):
|
||||
"""Test that consolidation leaves messages list unchanged."""
|
||||
session = create_session_with_messages("test:immutable", 50)
|
||||
|
||||
original_messages = session.messages.copy()
|
||||
original_len = len(session.messages)
|
||||
session.last_consolidated = original_len - KEEP_COUNT
|
||||
|
||||
assert len(session.messages) == original_len
|
||||
assert session.messages == original_messages
|
||||
|
||||
def test_get_history_does_not_modify_messages(self):
|
||||
"""Test that get_history doesn't modify messages list."""
|
||||
session = create_session_with_messages("test:history_immutable", 40)
|
||||
original_messages = [m.copy() for m in session.messages]
|
||||
|
||||
for _ in range(5):
|
||||
history = session.get_history(max_messages=10)
|
||||
assert len(history) == 10
|
||||
|
||||
assert len(session.messages) == 40
|
||||
for i, msg in enumerate(session.messages):
|
||||
assert msg["content"] == original_messages[i]["content"]
|
||||
|
||||
def test_consolidation_only_updates_last_consolidated(self):
|
||||
"""Test that consolidation only updates last_consolidated field."""
|
||||
session = create_session_with_messages("test:field_only", 60)
|
||||
|
||||
original_messages = session.messages.copy()
|
||||
original_key = session.key
|
||||
original_metadata = session.metadata.copy()
|
||||
|
||||
session.last_consolidated = len(session.messages) - KEEP_COUNT
|
||||
|
||||
assert session.messages == original_messages
|
||||
assert session.key == original_key
|
||||
assert session.metadata == original_metadata
|
||||
assert session.last_consolidated == 35
|
||||
|
||||
|
||||
class TestSliceLogic:
|
||||
"""Test the slice logic: messages[last_consolidated:-keep_count]."""
|
||||
|
||||
def test_slice_extracts_correct_range(self):
|
||||
"""Test that slice extracts the correct message range."""
|
||||
session = create_session_with_messages("test:slice", 60)
|
||||
|
||||
old_messages = get_old_messages(session, 0, KEEP_COUNT)
|
||||
|
||||
assert len(old_messages) == 35
|
||||
assert_messages_content(old_messages, 0, 34)
|
||||
|
||||
remaining = session.messages[-KEEP_COUNT:]
|
||||
assert len(remaining) == 25
|
||||
assert_messages_content(remaining, 35, 59)
|
||||
|
||||
def test_slice_with_partial_consolidation(self):
|
||||
"""Test slice when some messages already consolidated."""
|
||||
session = create_session_with_messages("test:partial", 70)
|
||||
|
||||
last_consolidated = 30
|
||||
old_messages = get_old_messages(session, last_consolidated, KEEP_COUNT)
|
||||
|
||||
assert len(old_messages) == 15
|
||||
assert_messages_content(old_messages, 30, 44)
|
||||
|
||||
def test_slice_with_various_keep_counts(self):
|
||||
"""Test slice behavior with different keep_count values."""
|
||||
session = create_session_with_messages("test:keep_counts", 50)
|
||||
|
||||
test_cases = [(10, 40), (20, 30), (30, 20), (40, 10)]
|
||||
|
||||
for keep_count, expected_count in test_cases:
|
||||
old_messages = session.messages[0:-keep_count]
|
||||
assert len(old_messages) == expected_count
|
||||
|
||||
def test_slice_when_keep_count_exceeds_messages(self):
|
||||
"""Test slice when keep_count > len(messages)."""
|
||||
session = create_session_with_messages("test:exceed", 10)
|
||||
|
||||
old_messages = session.messages[0:-20]
|
||||
assert len(old_messages) == 0
|
||||
|
||||
|
||||
class TestEmptyAndBoundarySessions:
|
||||
"""Test empty sessions and boundary conditions."""
|
||||
|
||||
def test_empty_session_consolidation(self):
|
||||
"""Test consolidation behavior with empty session."""
|
||||
session = Session(key="test:empty")
|
||||
|
||||
assert len(session.messages) == 0
|
||||
assert session.last_consolidated == 0
|
||||
|
||||
messages_to_process = len(session.messages) - session.last_consolidated
|
||||
assert messages_to_process == 0
|
||||
|
||||
old_messages = get_old_messages(session, session.last_consolidated, KEEP_COUNT)
|
||||
assert len(old_messages) == 0
|
||||
|
||||
def test_single_message_session(self):
|
||||
"""Test consolidation with single message."""
|
||||
session = Session(key="test:single")
|
||||
session.add_message("user", "only message")
|
||||
|
||||
assert len(session.messages) == 1
|
||||
|
||||
old_messages = get_old_messages(session, session.last_consolidated, KEEP_COUNT)
|
||||
assert len(old_messages) == 0
|
||||
|
||||
def test_exactly_keep_count_messages(self):
|
||||
"""Test session with exactly keep_count messages."""
|
||||
session = create_session_with_messages("test:exact", KEEP_COUNT)
|
||||
|
||||
assert len(session.messages) == KEEP_COUNT
|
||||
|
||||
old_messages = get_old_messages(session, session.last_consolidated, KEEP_COUNT)
|
||||
assert len(old_messages) == 0
|
||||
|
||||
def test_just_over_keep_count(self):
|
||||
"""Test session with one message over keep_count."""
|
||||
session = create_session_with_messages("test:over", KEEP_COUNT + 1)
|
||||
|
||||
assert len(session.messages) == 26
|
||||
|
||||
old_messages = get_old_messages(session, session.last_consolidated, KEEP_COUNT)
|
||||
assert len(old_messages) == 1
|
||||
assert old_messages[0]["content"] == "msg0"
|
||||
|
||||
def test_very_large_session(self):
|
||||
"""Test consolidation with very large message count."""
|
||||
session = create_session_with_messages("test:large", 1000)
|
||||
|
||||
assert len(session.messages) == 1000
|
||||
|
||||
old_messages = get_old_messages(session, session.last_consolidated, KEEP_COUNT)
|
||||
assert len(old_messages) == 975
|
||||
assert_messages_content(old_messages, 0, 974)
|
||||
|
||||
remaining = session.messages[-KEEP_COUNT:]
|
||||
assert len(remaining) == 25
|
||||
assert_messages_content(remaining, 975, 999)
|
||||
|
||||
def test_session_with_gaps_in_consolidation(self):
|
||||
"""Test session with potential gaps in consolidation history."""
|
||||
session = create_session_with_messages("test:gaps", 50)
|
||||
session.last_consolidated = 10
|
||||
|
||||
# Add more messages
|
||||
for i in range(50, 60):
|
||||
session.add_message("user", f"msg{i}")
|
||||
|
||||
old_messages = get_old_messages(session, session.last_consolidated, KEEP_COUNT)
|
||||
|
||||
expected_count = 60 - KEEP_COUNT - 10
|
||||
assert len(old_messages) == expected_count
|
||||
assert_messages_content(old_messages, 10, 34)
|
||||
|
||||
|
||||
class TestNewCommandArchival:
|
||||
"""Test /new archival behavior with the simplified consolidation flow."""
|
||||
|
||||
@staticmethod
|
||||
def _make_loop(tmp_path: Path):
|
||||
from nanobot.agent.loop import AgentLoop
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.providers.base import LLMResponse
|
||||
|
||||
bus = MessageBus()
|
||||
provider = MagicMock()
|
||||
provider.get_default_model.return_value = "test-model"
|
||||
provider.estimate_prompt_tokens.return_value = (10_000, "test")
|
||||
loop = AgentLoop(
|
||||
bus=bus,
|
||||
provider=provider,
|
||||
workspace=tmp_path,
|
||||
model="test-model",
|
||||
context_window_tokens=1,
|
||||
)
|
||||
loop.provider.chat_with_retry = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
|
||||
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||
return loop
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_new_clears_session_immediately_even_if_archive_fails(self, tmp_path: Path) -> None:
|
||||
"""/new clears session immediately; archive_messages retries until raw dump."""
|
||||
from nanobot.bus.events import InboundMessage
|
||||
|
||||
loop = self._make_loop(tmp_path)
|
||||
session = loop.sessions.get_or_create("cli:test")
|
||||
for i in range(5):
|
||||
session.add_message("user", f"msg{i}")
|
||||
session.add_message("assistant", f"resp{i}")
|
||||
loop.sessions.save(session)
|
||||
|
||||
call_count = 0
|
||||
|
||||
async def _failing_consolidate(_messages) -> bool:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return False
|
||||
|
||||
loop.memory_consolidator.consolidate_messages = _failing_consolidate # type: ignore[method-assign]
|
||||
|
||||
new_msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new")
|
||||
response = await loop._process_message(new_msg)
|
||||
|
||||
assert response is not None
|
||||
assert "new session started" in response.content.lower()
|
||||
|
||||
session_after = loop.sessions.get_or_create("cli:test")
|
||||
assert len(session_after.messages) == 0
|
||||
|
||||
await loop.close_mcp()
|
||||
assert call_count == 3 # retried up to raw-archive threshold
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_new_archives_only_unconsolidated_messages(self, tmp_path: Path) -> None:
|
||||
from nanobot.bus.events import InboundMessage
|
||||
|
||||
loop = self._make_loop(tmp_path)
|
||||
session = loop.sessions.get_or_create("cli:test")
|
||||
for i in range(15):
|
||||
session.add_message("user", f"msg{i}")
|
||||
session.add_message("assistant", f"resp{i}")
|
||||
session.last_consolidated = len(session.messages) - 3
|
||||
loop.sessions.save(session)
|
||||
|
||||
archived_count = -1
|
||||
|
||||
async def _fake_consolidate(messages) -> bool:
|
||||
nonlocal archived_count
|
||||
archived_count = len(messages)
|
||||
return True
|
||||
|
||||
loop.memory_consolidator.consolidate_messages = _fake_consolidate # type: ignore[method-assign]
|
||||
|
||||
new_msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new")
|
||||
response = await loop._process_message(new_msg)
|
||||
|
||||
assert response is not None
|
||||
assert "new session started" in response.content.lower()
|
||||
|
||||
await loop.close_mcp()
|
||||
assert archived_count == 3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_new_clears_session_and_responds(self, tmp_path: Path) -> None:
|
||||
from nanobot.bus.events import InboundMessage
|
||||
|
||||
loop = self._make_loop(tmp_path)
|
||||
session = loop.sessions.get_or_create("cli:test")
|
||||
for i in range(3):
|
||||
session.add_message("user", f"msg{i}")
|
||||
session.add_message("assistant", f"resp{i}")
|
||||
loop.sessions.save(session)
|
||||
|
||||
async def _ok_consolidate(_messages) -> bool:
|
||||
return True
|
||||
|
||||
loop.memory_consolidator.consolidate_messages = _ok_consolidate # type: ignore[method-assign]
|
||||
|
||||
new_msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new")
|
||||
response = await loop._process_message(new_msg)
|
||||
|
||||
assert response is not None
|
||||
assert "new session started" in response.content.lower()
|
||||
assert loop.sessions.get_or_create("cli:test").messages == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_close_mcp_drains_background_tasks(self, tmp_path: Path) -> None:
|
||||
"""close_mcp waits for background tasks to complete."""
|
||||
from nanobot.bus.events import InboundMessage
|
||||
|
||||
loop = self._make_loop(tmp_path)
|
||||
session = loop.sessions.get_or_create("cli:test")
|
||||
for i in range(3):
|
||||
session.add_message("user", f"msg{i}")
|
||||
session.add_message("assistant", f"resp{i}")
|
||||
loop.sessions.save(session)
|
||||
|
||||
archived = asyncio.Event()
|
||||
|
||||
async def _slow_consolidate(_messages) -> bool:
|
||||
await asyncio.sleep(0.1)
|
||||
archived.set()
|
||||
return True
|
||||
|
||||
loop.memory_consolidator.consolidate_messages = _slow_consolidate # type: ignore[method-assign]
|
||||
|
||||
new_msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new")
|
||||
await loop._process_message(new_msg)
|
||||
|
||||
assert not archived.is_set()
|
||||
await loop.close_mcp()
|
||||
assert archived.is_set()
|
||||
@@ -0,0 +1,73 @@
|
||||
"""Tests for cache-friendly prompt construction."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime as real_datetime
|
||||
from importlib.resources import files as pkg_files
|
||||
from pathlib import Path
|
||||
import datetime as datetime_module
|
||||
|
||||
from nanobot.agent.context import ContextBuilder
|
||||
|
||||
|
||||
class _FakeDatetime(real_datetime):
|
||||
current = real_datetime(2026, 2, 24, 13, 59)
|
||||
|
||||
@classmethod
|
||||
def now(cls, tz=None): # type: ignore[override]
|
||||
return cls.current
|
||||
|
||||
|
||||
def _make_workspace(tmp_path: Path) -> Path:
|
||||
workspace = tmp_path / "workspace"
|
||||
workspace.mkdir(parents=True)
|
||||
return workspace
|
||||
|
||||
|
||||
def test_bootstrap_files_are_backed_by_templates() -> None:
|
||||
template_dir = pkg_files("nanobot") / "templates"
|
||||
|
||||
for filename in ContextBuilder.BOOTSTRAP_FILES:
|
||||
assert (template_dir / filename).is_file(), f"missing bootstrap template: {filename}"
|
||||
|
||||
|
||||
def test_system_prompt_stays_stable_when_clock_changes(tmp_path, monkeypatch) -> None:
|
||||
"""System prompt should not change just because wall clock minute changes."""
|
||||
monkeypatch.setattr(datetime_module, "datetime", _FakeDatetime)
|
||||
|
||||
workspace = _make_workspace(tmp_path)
|
||||
builder = ContextBuilder(workspace)
|
||||
|
||||
_FakeDatetime.current = real_datetime(2026, 2, 24, 13, 59)
|
||||
prompt1 = builder.build_system_prompt()
|
||||
|
||||
_FakeDatetime.current = real_datetime(2026, 2, 24, 14, 0)
|
||||
prompt2 = builder.build_system_prompt()
|
||||
|
||||
assert prompt1 == prompt2
|
||||
|
||||
|
||||
def test_runtime_context_is_separate_untrusted_user_message(tmp_path) -> None:
|
||||
"""Runtime metadata should be merged with the user message."""
|
||||
workspace = _make_workspace(tmp_path)
|
||||
builder = ContextBuilder(workspace)
|
||||
|
||||
messages = builder.build_messages(
|
||||
history=[],
|
||||
current_message="Return exactly: OK",
|
||||
channel="cli",
|
||||
chat_id="direct",
|
||||
)
|
||||
|
||||
assert messages[0]["role"] == "system"
|
||||
assert "## Current Session" not in messages[0]["content"]
|
||||
|
||||
# Runtime context is now merged with user message into a single message
|
||||
assert messages[-1]["role"] == "user"
|
||||
user_content = messages[-1]["content"]
|
||||
assert isinstance(user_content, str)
|
||||
assert ContextBuilder._RUNTIME_CONTEXT_TAG in user_content
|
||||
assert "Current Time:" in user_content
|
||||
assert "Channel: cli" in user_content
|
||||
assert "Chat ID: direct" in user_content
|
||||
assert "Return exactly: OK" in user_content
|
||||
@@ -0,0 +1,63 @@
|
||||
import pytest
|
||||
|
||||
from nanobot.utils.evaluator import evaluate_response
|
||||
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
|
||||
|
||||
|
||||
class DummyProvider(LLMProvider):
|
||||
def __init__(self, responses: list[LLMResponse]):
|
||||
super().__init__()
|
||||
self._responses = list(responses)
|
||||
|
||||
async def chat(self, *args, **kwargs) -> LLMResponse:
|
||||
if self._responses:
|
||||
return self._responses.pop(0)
|
||||
return LLMResponse(content="", tool_calls=[])
|
||||
|
||||
def get_default_model(self) -> str:
|
||||
return "test-model"
|
||||
|
||||
|
||||
def _eval_tool_call(should_notify: bool, reason: str = "") -> LLMResponse:
|
||||
return LLMResponse(
|
||||
content="",
|
||||
tool_calls=[
|
||||
ToolCallRequest(
|
||||
id="eval_1",
|
||||
name="evaluate_notification",
|
||||
arguments={"should_notify": should_notify, "reason": reason},
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_should_notify_true() -> None:
|
||||
provider = DummyProvider([_eval_tool_call(True, "user asked to be reminded")])
|
||||
result = await evaluate_response("Task completed with results", "check emails", provider, "m")
|
||||
assert result is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_should_notify_false() -> None:
|
||||
provider = DummyProvider([_eval_tool_call(False, "routine check, nothing new")])
|
||||
result = await evaluate_response("All clear, no updates", "check status", provider, "m")
|
||||
assert result is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fallback_on_error() -> None:
|
||||
class FailingProvider(DummyProvider):
|
||||
async def chat(self, *args, **kwargs) -> LLMResponse:
|
||||
raise RuntimeError("provider down")
|
||||
|
||||
provider = FailingProvider([])
|
||||
result = await evaluate_response("some response", "some task", provider, "m")
|
||||
assert result is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_tool_call_fallback() -> None:
|
||||
provider = DummyProvider([LLMResponse(content="I think you should notify", tool_calls=[])])
|
||||
result = await evaluate_response("some response", "some task", provider, "m")
|
||||
assert result is True
|
||||
@@ -0,0 +1,200 @@
|
||||
"""Tests for Gemini thought_signature round-trip through extra_content.
|
||||
|
||||
The Gemini OpenAI-compatibility API returns tool calls with an extra_content
|
||||
field: ``{"google": {"thought_signature": "..."}}``. This MUST survive the
|
||||
parse → serialize round-trip so the model can continue reasoning.
|
||||
"""
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import patch
|
||||
|
||||
from nanobot.providers.base import ToolCallRequest
|
||||
from nanobot.providers.openai_compat_provider import OpenAICompatProvider
|
||||
|
||||
|
||||
GEMINI_EXTRA = {"google": {"thought_signature": "sig-abc-123"}}
|
||||
|
||||
|
||||
# ── ToolCallRequest serialization ──────────────────────────────────────
|
||||
|
||||
def test_tool_call_request_serializes_extra_content() -> None:
|
||||
tc = ToolCallRequest(
|
||||
id="abc123xyz",
|
||||
name="read_file",
|
||||
arguments={"path": "todo.md"},
|
||||
extra_content=GEMINI_EXTRA,
|
||||
)
|
||||
|
||||
payload = tc.to_openai_tool_call()
|
||||
|
||||
assert payload["extra_content"] == GEMINI_EXTRA
|
||||
assert payload["function"]["arguments"] == '{"path": "todo.md"}'
|
||||
|
||||
|
||||
def test_tool_call_request_serializes_provider_fields() -> None:
|
||||
tc = ToolCallRequest(
|
||||
id="abc123xyz",
|
||||
name="read_file",
|
||||
arguments={"path": "todo.md"},
|
||||
provider_specific_fields={"custom_key": "custom_val"},
|
||||
function_provider_specific_fields={"inner": "value"},
|
||||
)
|
||||
|
||||
payload = tc.to_openai_tool_call()
|
||||
|
||||
assert payload["provider_specific_fields"] == {"custom_key": "custom_val"}
|
||||
assert payload["function"]["provider_specific_fields"] == {"inner": "value"}
|
||||
|
||||
|
||||
def test_tool_call_request_omits_absent_extras() -> None:
|
||||
tc = ToolCallRequest(id="x", name="fn", arguments={})
|
||||
payload = tc.to_openai_tool_call()
|
||||
|
||||
assert "extra_content" not in payload
|
||||
assert "provider_specific_fields" not in payload
|
||||
assert "provider_specific_fields" not in payload["function"]
|
||||
|
||||
|
||||
# ── _parse: SDK-object branch ──────────────────────────────────────────
|
||||
|
||||
def _make_sdk_response_with_extra_content():
|
||||
"""Simulate a Gemini response via the OpenAI SDK (SimpleNamespace)."""
|
||||
fn = SimpleNamespace(name="get_weather", arguments='{"city":"Tokyo"}')
|
||||
tc = SimpleNamespace(
|
||||
id="call_1",
|
||||
index=0,
|
||||
type="function",
|
||||
function=fn,
|
||||
extra_content=GEMINI_EXTRA,
|
||||
)
|
||||
msg = SimpleNamespace(
|
||||
content=None,
|
||||
tool_calls=[tc],
|
||||
reasoning_content=None,
|
||||
)
|
||||
choice = SimpleNamespace(message=msg, finish_reason="tool_calls")
|
||||
usage = SimpleNamespace(prompt_tokens=10, completion_tokens=5, total_tokens=15)
|
||||
return SimpleNamespace(choices=[choice], usage=usage)
|
||||
|
||||
|
||||
def test_parse_sdk_object_preserves_extra_content() -> None:
|
||||
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
|
||||
provider = OpenAICompatProvider()
|
||||
|
||||
result = provider._parse(_make_sdk_response_with_extra_content())
|
||||
|
||||
assert len(result.tool_calls) == 1
|
||||
tc = result.tool_calls[0]
|
||||
assert tc.name == "get_weather"
|
||||
assert tc.extra_content == GEMINI_EXTRA
|
||||
|
||||
payload = tc.to_openai_tool_call()
|
||||
assert payload["extra_content"] == GEMINI_EXTRA
|
||||
|
||||
|
||||
# ── _parse: dict/mapping branch ───────────────────────────────────────
|
||||
|
||||
def test_parse_dict_preserves_extra_content() -> None:
|
||||
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
|
||||
provider = OpenAICompatProvider()
|
||||
|
||||
response_dict = {
|
||||
"choices": [{
|
||||
"message": {
|
||||
"content": None,
|
||||
"tool_calls": [{
|
||||
"id": "call_1",
|
||||
"type": "function",
|
||||
"function": {"name": "get_weather", "arguments": '{"city":"Tokyo"}'},
|
||||
"extra_content": GEMINI_EXTRA,
|
||||
}],
|
||||
},
|
||||
"finish_reason": "tool_calls",
|
||||
}],
|
||||
"usage": {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15},
|
||||
}
|
||||
|
||||
result = provider._parse(response_dict)
|
||||
|
||||
assert len(result.tool_calls) == 1
|
||||
tc = result.tool_calls[0]
|
||||
assert tc.name == "get_weather"
|
||||
assert tc.extra_content == GEMINI_EXTRA
|
||||
|
||||
payload = tc.to_openai_tool_call()
|
||||
assert payload["extra_content"] == GEMINI_EXTRA
|
||||
|
||||
|
||||
# ── _parse_chunks: streaming round-trip ───────────────────────────────
|
||||
|
||||
def test_parse_chunks_sdk_preserves_extra_content() -> None:
|
||||
fn_delta = SimpleNamespace(name="get_weather", arguments='{"city":"Tokyo"}')
|
||||
tc_delta = SimpleNamespace(
|
||||
id="call_1",
|
||||
index=0,
|
||||
function=fn_delta,
|
||||
extra_content=GEMINI_EXTRA,
|
||||
)
|
||||
delta = SimpleNamespace(content=None, tool_calls=[tc_delta])
|
||||
choice = SimpleNamespace(finish_reason="tool_calls", delta=delta)
|
||||
chunk = SimpleNamespace(choices=[choice], usage=None)
|
||||
|
||||
result = OpenAICompatProvider._parse_chunks([chunk])
|
||||
|
||||
assert len(result.tool_calls) == 1
|
||||
tc = result.tool_calls[0]
|
||||
assert tc.extra_content == GEMINI_EXTRA
|
||||
|
||||
payload = tc.to_openai_tool_call()
|
||||
assert payload["extra_content"] == GEMINI_EXTRA
|
||||
|
||||
|
||||
def test_parse_chunks_dict_preserves_extra_content() -> None:
|
||||
chunk = {
|
||||
"choices": [{
|
||||
"finish_reason": "tool_calls",
|
||||
"delta": {
|
||||
"content": None,
|
||||
"tool_calls": [{
|
||||
"index": 0,
|
||||
"id": "call_1",
|
||||
"function": {"name": "get_weather", "arguments": '{"city":"Tokyo"}'},
|
||||
"extra_content": GEMINI_EXTRA,
|
||||
}],
|
||||
},
|
||||
}],
|
||||
}
|
||||
|
||||
result = OpenAICompatProvider._parse_chunks([chunk])
|
||||
|
||||
assert len(result.tool_calls) == 1
|
||||
tc = result.tool_calls[0]
|
||||
assert tc.extra_content == GEMINI_EXTRA
|
||||
|
||||
payload = tc.to_openai_tool_call()
|
||||
assert payload["extra_content"] == GEMINI_EXTRA
|
||||
|
||||
|
||||
# ── Model switching: stale extras shouldn't break other providers ─────
|
||||
|
||||
def test_stale_extra_content_in_tool_calls_survives_sanitize() -> None:
|
||||
"""When switching from Gemini to OpenAI, extra_content inside tool_calls
|
||||
should survive message sanitization (it lives inside the tool_call dict,
|
||||
not at message level, so it bypasses _ALLOWED_MSG_KEYS filtering)."""
|
||||
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
|
||||
provider = OpenAICompatProvider()
|
||||
|
||||
messages = [{
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"tool_calls": [{
|
||||
"id": "call_1",
|
||||
"type": "function",
|
||||
"function": {"name": "fn", "arguments": "{}"},
|
||||
"extra_content": GEMINI_EXTRA,
|
||||
}],
|
||||
}]
|
||||
|
||||
sanitized = provider._sanitize_messages(messages)
|
||||
|
||||
assert sanitized[0]["tool_calls"][0]["extra_content"] == GEMINI_EXTRA
|
||||
@@ -0,0 +1,289 @@
|
||||
import asyncio
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.heartbeat.service import HeartbeatService
|
||||
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
|
||||
|
||||
|
||||
class DummyProvider(LLMProvider):
|
||||
def __init__(self, responses: list[LLMResponse]):
|
||||
super().__init__()
|
||||
self._responses = list(responses)
|
||||
self.calls = 0
|
||||
|
||||
async def chat(self, *args, **kwargs) -> LLMResponse:
|
||||
self.calls += 1
|
||||
if self._responses:
|
||||
return self._responses.pop(0)
|
||||
return LLMResponse(content="", tool_calls=[])
|
||||
|
||||
def get_default_model(self) -> str:
|
||||
return "test-model"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_is_idempotent(tmp_path) -> None:
|
||||
provider = DummyProvider([])
|
||||
|
||||
service = HeartbeatService(
|
||||
workspace=tmp_path,
|
||||
provider=provider,
|
||||
model="openai/gpt-4o-mini",
|
||||
interval_s=9999,
|
||||
enabled=True,
|
||||
)
|
||||
|
||||
await service.start()
|
||||
first_task = service._task
|
||||
await service.start()
|
||||
|
||||
assert service._task is first_task
|
||||
|
||||
service.stop()
|
||||
await asyncio.sleep(0)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decide_returns_skip_when_no_tool_call(tmp_path) -> None:
|
||||
provider = DummyProvider([LLMResponse(content="no tool call", tool_calls=[])])
|
||||
service = HeartbeatService(
|
||||
workspace=tmp_path,
|
||||
provider=provider,
|
||||
model="openai/gpt-4o-mini",
|
||||
)
|
||||
|
||||
action, tasks = await service._decide("heartbeat content")
|
||||
assert action == "skip"
|
||||
assert tasks == ""
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_trigger_now_executes_when_decision_is_run(tmp_path) -> None:
|
||||
(tmp_path / "HEARTBEAT.md").write_text("- [ ] do thing", encoding="utf-8")
|
||||
|
||||
provider = DummyProvider([
|
||||
LLMResponse(
|
||||
content="",
|
||||
tool_calls=[
|
||||
ToolCallRequest(
|
||||
id="hb_1",
|
||||
name="heartbeat",
|
||||
arguments={"action": "run", "tasks": "check open tasks"},
|
||||
)
|
||||
],
|
||||
)
|
||||
])
|
||||
|
||||
called_with: list[str] = []
|
||||
|
||||
async def _on_execute(tasks: str) -> str:
|
||||
called_with.append(tasks)
|
||||
return "done"
|
||||
|
||||
service = HeartbeatService(
|
||||
workspace=tmp_path,
|
||||
provider=provider,
|
||||
model="openai/gpt-4o-mini",
|
||||
on_execute=_on_execute,
|
||||
)
|
||||
|
||||
result = await service.trigger_now()
|
||||
assert result == "done"
|
||||
assert called_with == ["check open tasks"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_trigger_now_returns_none_when_decision_is_skip(tmp_path) -> None:
|
||||
(tmp_path / "HEARTBEAT.md").write_text("- [ ] do thing", encoding="utf-8")
|
||||
|
||||
provider = DummyProvider([
|
||||
LLMResponse(
|
||||
content="",
|
||||
tool_calls=[
|
||||
ToolCallRequest(
|
||||
id="hb_1",
|
||||
name="heartbeat",
|
||||
arguments={"action": "skip"},
|
||||
)
|
||||
],
|
||||
)
|
||||
])
|
||||
|
||||
async def _on_execute(tasks: str) -> str:
|
||||
return tasks
|
||||
|
||||
service = HeartbeatService(
|
||||
workspace=tmp_path,
|
||||
provider=provider,
|
||||
model="openai/gpt-4o-mini",
|
||||
on_execute=_on_execute,
|
||||
)
|
||||
|
||||
assert await service.trigger_now() is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tick_notifies_when_evaluator_says_yes(tmp_path, monkeypatch) -> None:
|
||||
"""Phase 1 run -> Phase 2 execute -> Phase 3 evaluate=notify -> on_notify called."""
|
||||
(tmp_path / "HEARTBEAT.md").write_text("- [ ] check deployments", encoding="utf-8")
|
||||
|
||||
provider = DummyProvider([
|
||||
LLMResponse(
|
||||
content="",
|
||||
tool_calls=[
|
||||
ToolCallRequest(
|
||||
id="hb_1",
|
||||
name="heartbeat",
|
||||
arguments={"action": "run", "tasks": "check deployments"},
|
||||
)
|
||||
],
|
||||
),
|
||||
])
|
||||
|
||||
executed: list[str] = []
|
||||
notified: list[str] = []
|
||||
|
||||
async def _on_execute(tasks: str) -> str:
|
||||
executed.append(tasks)
|
||||
return "deployment failed on staging"
|
||||
|
||||
async def _on_notify(response: str) -> None:
|
||||
notified.append(response)
|
||||
|
||||
service = HeartbeatService(
|
||||
workspace=tmp_path,
|
||||
provider=provider,
|
||||
model="openai/gpt-4o-mini",
|
||||
on_execute=_on_execute,
|
||||
on_notify=_on_notify,
|
||||
)
|
||||
|
||||
async def _eval_notify(*a, **kw):
|
||||
return True
|
||||
|
||||
monkeypatch.setattr("nanobot.utils.evaluator.evaluate_response", _eval_notify)
|
||||
|
||||
await service._tick()
|
||||
assert executed == ["check deployments"]
|
||||
assert notified == ["deployment failed on staging"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tick_suppresses_when_evaluator_says_no(tmp_path, monkeypatch) -> None:
|
||||
"""Phase 1 run -> Phase 2 execute -> Phase 3 evaluate=silent -> on_notify NOT called."""
|
||||
(tmp_path / "HEARTBEAT.md").write_text("- [ ] check status", encoding="utf-8")
|
||||
|
||||
provider = DummyProvider([
|
||||
LLMResponse(
|
||||
content="",
|
||||
tool_calls=[
|
||||
ToolCallRequest(
|
||||
id="hb_1",
|
||||
name="heartbeat",
|
||||
arguments={"action": "run", "tasks": "check status"},
|
||||
)
|
||||
],
|
||||
),
|
||||
])
|
||||
|
||||
executed: list[str] = []
|
||||
notified: list[str] = []
|
||||
|
||||
async def _on_execute(tasks: str) -> str:
|
||||
executed.append(tasks)
|
||||
return "everything is fine, no issues"
|
||||
|
||||
async def _on_notify(response: str) -> None:
|
||||
notified.append(response)
|
||||
|
||||
service = HeartbeatService(
|
||||
workspace=tmp_path,
|
||||
provider=provider,
|
||||
model="openai/gpt-4o-mini",
|
||||
on_execute=_on_execute,
|
||||
on_notify=_on_notify,
|
||||
)
|
||||
|
||||
async def _eval_silent(*a, **kw):
|
||||
return False
|
||||
|
||||
monkeypatch.setattr("nanobot.utils.evaluator.evaluate_response", _eval_silent)
|
||||
|
||||
await service._tick()
|
||||
assert executed == ["check status"]
|
||||
assert notified == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decide_retries_transient_error_then_succeeds(tmp_path, monkeypatch) -> None:
|
||||
provider = DummyProvider([
|
||||
LLMResponse(content="429 rate limit", finish_reason="error"),
|
||||
LLMResponse(
|
||||
content="",
|
||||
tool_calls=[
|
||||
ToolCallRequest(
|
||||
id="hb_1",
|
||||
name="heartbeat",
|
||||
arguments={"action": "run", "tasks": "check open tasks"},
|
||||
)
|
||||
],
|
||||
),
|
||||
])
|
||||
|
||||
delays: list[int] = []
|
||||
|
||||
async def _fake_sleep(delay: int) -> None:
|
||||
delays.append(delay)
|
||||
|
||||
monkeypatch.setattr(asyncio, "sleep", _fake_sleep)
|
||||
|
||||
service = HeartbeatService(
|
||||
workspace=tmp_path,
|
||||
provider=provider,
|
||||
model="openai/gpt-4o-mini",
|
||||
)
|
||||
|
||||
action, tasks = await service._decide("heartbeat content")
|
||||
|
||||
assert action == "run"
|
||||
assert tasks == "check open tasks"
|
||||
assert provider.calls == 2
|
||||
assert delays == [1]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decide_prompt_includes_current_time(tmp_path) -> None:
|
||||
"""Phase 1 user prompt must contain current time so the LLM can judge task urgency."""
|
||||
|
||||
captured_messages: list[dict] = []
|
||||
|
||||
class CapturingProvider(LLMProvider):
|
||||
async def chat(self, *, messages=None, **kwargs) -> LLMResponse:
|
||||
if messages:
|
||||
captured_messages.extend(messages)
|
||||
return LLMResponse(
|
||||
content="",
|
||||
tool_calls=[
|
||||
ToolCallRequest(
|
||||
id="hb_1", name="heartbeat",
|
||||
arguments={"action": "skip"},
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
def get_default_model(self) -> str:
|
||||
return "test-model"
|
||||
|
||||
service = HeartbeatService(
|
||||
workspace=tmp_path,
|
||||
provider=CapturingProvider(),
|
||||
model="test-model",
|
||||
)
|
||||
|
||||
await service._decide("- [ ] check servers at 10:00 UTC")
|
||||
|
||||
user_msg = captured_messages[1]
|
||||
assert user_msg["role"] == "user"
|
||||
assert "Current Time:" in user_msg["content"]
|
||||
|
||||
@@ -0,0 +1,196 @@
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.agent.loop import AgentLoop
|
||||
import nanobot.agent.memory as memory_module
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.providers.base import LLMResponse
|
||||
|
||||
|
||||
def _make_loop(tmp_path, *, estimated_tokens: int, context_window_tokens: int) -> AgentLoop:
|
||||
from nanobot.providers.base import GenerationSettings
|
||||
provider = MagicMock()
|
||||
provider.get_default_model.return_value = "test-model"
|
||||
provider.generation = GenerationSettings(max_tokens=0)
|
||||
provider.estimate_prompt_tokens.return_value = (estimated_tokens, "test-counter")
|
||||
_response = LLMResponse(content="ok", tool_calls=[])
|
||||
provider.chat_with_retry = AsyncMock(return_value=_response)
|
||||
provider.chat_stream_with_retry = AsyncMock(return_value=_response)
|
||||
|
||||
loop = AgentLoop(
|
||||
bus=MessageBus(),
|
||||
provider=provider,
|
||||
workspace=tmp_path,
|
||||
model="test-model",
|
||||
context_window_tokens=context_window_tokens,
|
||||
)
|
||||
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||
loop.memory_consolidator._SAFETY_BUFFER = 0
|
||||
return loop
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prompt_below_threshold_does_not_consolidate(tmp_path) -> None:
|
||||
loop = _make_loop(tmp_path, estimated_tokens=100, context_window_tokens=200)
|
||||
loop.memory_consolidator.consolidate_messages = AsyncMock(return_value=True) # type: ignore[method-assign]
|
||||
|
||||
await loop.process_direct("hello", session_key="cli:test")
|
||||
|
||||
loop.memory_consolidator.consolidate_messages.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prompt_above_threshold_triggers_consolidation(tmp_path, monkeypatch) -> None:
|
||||
loop = _make_loop(tmp_path, estimated_tokens=1000, context_window_tokens=200)
|
||||
loop.memory_consolidator.consolidate_messages = AsyncMock(return_value=True) # type: ignore[method-assign]
|
||||
session = loop.sessions.get_or_create("cli:test")
|
||||
session.messages = [
|
||||
{"role": "user", "content": "u1", "timestamp": "2026-01-01T00:00:00"},
|
||||
{"role": "assistant", "content": "a1", "timestamp": "2026-01-01T00:00:01"},
|
||||
{"role": "user", "content": "u2", "timestamp": "2026-01-01T00:00:02"},
|
||||
]
|
||||
loop.sessions.save(session)
|
||||
monkeypatch.setattr(memory_module, "estimate_message_tokens", lambda _message: 500)
|
||||
|
||||
await loop.process_direct("hello", session_key="cli:test")
|
||||
|
||||
assert loop.memory_consolidator.consolidate_messages.await_count >= 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prompt_above_threshold_archives_until_next_user_boundary(tmp_path, monkeypatch) -> None:
|
||||
loop = _make_loop(tmp_path, estimated_tokens=1000, context_window_tokens=200)
|
||||
loop.memory_consolidator.consolidate_messages = AsyncMock(return_value=True) # type: ignore[method-assign]
|
||||
|
||||
session = loop.sessions.get_or_create("cli:test")
|
||||
session.messages = [
|
||||
{"role": "user", "content": "u1", "timestamp": "2026-01-01T00:00:00"},
|
||||
{"role": "assistant", "content": "a1", "timestamp": "2026-01-01T00:00:01"},
|
||||
{"role": "user", "content": "u2", "timestamp": "2026-01-01T00:00:02"},
|
||||
{"role": "assistant", "content": "a2", "timestamp": "2026-01-01T00:00:03"},
|
||||
{"role": "user", "content": "u3", "timestamp": "2026-01-01T00:00:04"},
|
||||
]
|
||||
loop.sessions.save(session)
|
||||
|
||||
token_map = {"u1": 120, "a1": 120, "u2": 120, "a2": 120, "u3": 120}
|
||||
monkeypatch.setattr(memory_module, "estimate_message_tokens", lambda message: token_map[message["content"]])
|
||||
|
||||
await loop.memory_consolidator.maybe_consolidate_by_tokens(session)
|
||||
|
||||
archived_chunk = loop.memory_consolidator.consolidate_messages.await_args.args[0]
|
||||
assert [message["content"] for message in archived_chunk] == ["u1", "a1", "u2", "a2"]
|
||||
assert session.last_consolidated == 4
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_consolidation_loops_until_target_met(tmp_path, monkeypatch) -> None:
|
||||
"""Verify maybe_consolidate_by_tokens keeps looping until under threshold."""
|
||||
loop = _make_loop(tmp_path, estimated_tokens=0, context_window_tokens=200)
|
||||
loop.memory_consolidator.consolidate_messages = AsyncMock(return_value=True) # type: ignore[method-assign]
|
||||
|
||||
session = loop.sessions.get_or_create("cli:test")
|
||||
session.messages = [
|
||||
{"role": "user", "content": "u1", "timestamp": "2026-01-01T00:00:00"},
|
||||
{"role": "assistant", "content": "a1", "timestamp": "2026-01-01T00:00:01"},
|
||||
{"role": "user", "content": "u2", "timestamp": "2026-01-01T00:00:02"},
|
||||
{"role": "assistant", "content": "a2", "timestamp": "2026-01-01T00:00:03"},
|
||||
{"role": "user", "content": "u3", "timestamp": "2026-01-01T00:00:04"},
|
||||
{"role": "assistant", "content": "a3", "timestamp": "2026-01-01T00:00:05"},
|
||||
{"role": "user", "content": "u4", "timestamp": "2026-01-01T00:00:06"},
|
||||
]
|
||||
loop.sessions.save(session)
|
||||
|
||||
call_count = [0]
|
||||
def mock_estimate(_session):
|
||||
call_count[0] += 1
|
||||
if call_count[0] == 1:
|
||||
return (500, "test")
|
||||
if call_count[0] == 2:
|
||||
return (300, "test")
|
||||
return (80, "test")
|
||||
|
||||
loop.memory_consolidator.estimate_session_prompt_tokens = mock_estimate # type: ignore[method-assign]
|
||||
monkeypatch.setattr(memory_module, "estimate_message_tokens", lambda _m: 100)
|
||||
|
||||
await loop.memory_consolidator.maybe_consolidate_by_tokens(session)
|
||||
|
||||
assert loop.memory_consolidator.consolidate_messages.await_count == 2
|
||||
assert session.last_consolidated == 6
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_consolidation_continues_below_trigger_until_half_target(tmp_path, monkeypatch) -> None:
|
||||
"""Once triggered, consolidation should continue until it drops below half threshold."""
|
||||
loop = _make_loop(tmp_path, estimated_tokens=0, context_window_tokens=200)
|
||||
loop.memory_consolidator.consolidate_messages = AsyncMock(return_value=True) # type: ignore[method-assign]
|
||||
|
||||
session = loop.sessions.get_or_create("cli:test")
|
||||
session.messages = [
|
||||
{"role": "user", "content": "u1", "timestamp": "2026-01-01T00:00:00"},
|
||||
{"role": "assistant", "content": "a1", "timestamp": "2026-01-01T00:00:01"},
|
||||
{"role": "user", "content": "u2", "timestamp": "2026-01-01T00:00:02"},
|
||||
{"role": "assistant", "content": "a2", "timestamp": "2026-01-01T00:00:03"},
|
||||
{"role": "user", "content": "u3", "timestamp": "2026-01-01T00:00:04"},
|
||||
{"role": "assistant", "content": "a3", "timestamp": "2026-01-01T00:00:05"},
|
||||
{"role": "user", "content": "u4", "timestamp": "2026-01-01T00:00:06"},
|
||||
]
|
||||
loop.sessions.save(session)
|
||||
|
||||
call_count = [0]
|
||||
|
||||
def mock_estimate(_session):
|
||||
call_count[0] += 1
|
||||
if call_count[0] == 1:
|
||||
return (500, "test")
|
||||
if call_count[0] == 2:
|
||||
return (150, "test")
|
||||
return (80, "test")
|
||||
|
||||
loop.memory_consolidator.estimate_session_prompt_tokens = mock_estimate # type: ignore[method-assign]
|
||||
monkeypatch.setattr(memory_module, "estimate_message_tokens", lambda _m: 100)
|
||||
|
||||
await loop.memory_consolidator.maybe_consolidate_by_tokens(session)
|
||||
|
||||
assert loop.memory_consolidator.consolidate_messages.await_count == 2
|
||||
assert session.last_consolidated == 6
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_preflight_consolidation_before_llm_call(tmp_path, monkeypatch) -> None:
|
||||
"""Verify preflight consolidation runs before the LLM call in process_direct."""
|
||||
order: list[str] = []
|
||||
|
||||
loop = _make_loop(tmp_path, estimated_tokens=0, context_window_tokens=200)
|
||||
|
||||
async def track_consolidate(messages):
|
||||
order.append("consolidate")
|
||||
return True
|
||||
loop.memory_consolidator.consolidate_messages = track_consolidate # type: ignore[method-assign]
|
||||
|
||||
async def track_llm(*args, **kwargs):
|
||||
order.append("llm")
|
||||
return LLMResponse(content="ok", tool_calls=[])
|
||||
loop.provider.chat_with_retry = track_llm
|
||||
loop.provider.chat_stream_with_retry = track_llm
|
||||
|
||||
session = loop.sessions.get_or_create("cli:test")
|
||||
session.messages = [
|
||||
{"role": "user", "content": "u1", "timestamp": "2026-01-01T00:00:00"},
|
||||
{"role": "assistant", "content": "a1", "timestamp": "2026-01-01T00:00:01"},
|
||||
{"role": "user", "content": "u2", "timestamp": "2026-01-01T00:00:02"},
|
||||
]
|
||||
loop.sessions.save(session)
|
||||
monkeypatch.setattr(memory_module, "estimate_message_tokens", lambda _m: 500)
|
||||
|
||||
call_count = [0]
|
||||
def mock_estimate(_session):
|
||||
call_count[0] += 1
|
||||
return (1000 if call_count[0] <= 1 else 80, "test")
|
||||
loop.memory_consolidator.estimate_session_prompt_tokens = mock_estimate # type: ignore[method-assign]
|
||||
|
||||
await loop.process_direct("hello", session_key="cli:test")
|
||||
|
||||
assert "consolidate" in order
|
||||
assert "llm" in order
|
||||
assert order.index("consolidate") < order.index("llm")
|
||||
@@ -0,0 +1,27 @@
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from nanobot.agent.loop import AgentLoop
|
||||
from nanobot.agent.tools.cron import CronTool
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.cron.service import CronService
|
||||
|
||||
|
||||
def test_agent_loop_registers_cron_tool_with_configured_timezone(tmp_path: Path) -> None:
|
||||
bus = MessageBus()
|
||||
provider = MagicMock()
|
||||
provider.get_default_model.return_value = "test-model"
|
||||
|
||||
loop = AgentLoop(
|
||||
bus=bus,
|
||||
provider=provider,
|
||||
workspace=tmp_path,
|
||||
model="test-model",
|
||||
cron_service=CronService(tmp_path / "cron" / "jobs.json"),
|
||||
timezone="Asia/Shanghai",
|
||||
)
|
||||
|
||||
cron_tool = loop.tools.get("cron")
|
||||
|
||||
assert isinstance(cron_tool, CronTool)
|
||||
assert cron_tool._default_timezone == "Asia/Shanghai"
|
||||
@@ -0,0 +1,74 @@
|
||||
from nanobot.agent.context import ContextBuilder
|
||||
from nanobot.agent.loop import AgentLoop
|
||||
from nanobot.session.manager import Session
|
||||
|
||||
|
||||
def _mk_loop() -> AgentLoop:
|
||||
loop = AgentLoop.__new__(AgentLoop)
|
||||
loop._TOOL_RESULT_MAX_CHARS = AgentLoop._TOOL_RESULT_MAX_CHARS
|
||||
return loop
|
||||
|
||||
|
||||
def test_save_turn_skips_multimodal_user_when_only_runtime_context() -> None:
|
||||
loop = _mk_loop()
|
||||
session = Session(key="test:runtime-only")
|
||||
runtime = ContextBuilder._RUNTIME_CONTEXT_TAG + "\nCurrent Time: now (UTC)"
|
||||
|
||||
loop._save_turn(
|
||||
session,
|
||||
[{"role": "user", "content": [{"type": "text", "text": runtime}]}],
|
||||
skip=0,
|
||||
)
|
||||
assert session.messages == []
|
||||
|
||||
|
||||
def test_save_turn_keeps_image_placeholder_with_path_after_runtime_strip() -> None:
|
||||
loop = _mk_loop()
|
||||
session = Session(key="test:image")
|
||||
runtime = ContextBuilder._RUNTIME_CONTEXT_TAG + "\nCurrent Time: now (UTC)"
|
||||
|
||||
loop._save_turn(
|
||||
session,
|
||||
[{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": runtime},
|
||||
{"type": "image_url", "image_url": {"url": "data:image/png;base64,abc"}, "_meta": {"path": "/media/feishu/photo.jpg"}},
|
||||
],
|
||||
}],
|
||||
skip=0,
|
||||
)
|
||||
assert session.messages[0]["content"] == [{"type": "text", "text": "[image: /media/feishu/photo.jpg]"}]
|
||||
|
||||
|
||||
def test_save_turn_keeps_image_placeholder_without_meta() -> None:
|
||||
loop = _mk_loop()
|
||||
session = Session(key="test:image-no-meta")
|
||||
runtime = ContextBuilder._RUNTIME_CONTEXT_TAG + "\nCurrent Time: now (UTC)"
|
||||
|
||||
loop._save_turn(
|
||||
session,
|
||||
[{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": runtime},
|
||||
{"type": "image_url", "image_url": {"url": "data:image/png;base64,abc"}},
|
||||
],
|
||||
}],
|
||||
skip=0,
|
||||
)
|
||||
assert session.messages[0]["content"] == [{"type": "text", "text": "[image]"}]
|
||||
|
||||
|
||||
def test_save_turn_keeps_tool_results_under_16k() -> None:
|
||||
loop = _mk_loop()
|
||||
session = Session(key="test:tool-result")
|
||||
content = "x" * 12_000
|
||||
|
||||
loop._save_turn(
|
||||
session,
|
||||
[{"role": "tool", "tool_call_id": "call_1", "name": "read_file", "content": content}],
|
||||
skip=0,
|
||||
)
|
||||
|
||||
assert session.messages[0]["content"] == content
|
||||
@@ -0,0 +1,478 @@
|
||||
"""Test MemoryStore.consolidate() handles non-string tool call arguments.
|
||||
|
||||
Regression test for https://github.com/HKUDS/nanobot/issues/1042
|
||||
When memory consolidation receives dict values instead of strings from the LLM
|
||||
tool call response, it should serialize them to JSON instead of raising TypeError.
|
||||
"""
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.agent.memory import MemoryStore
|
||||
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
|
||||
|
||||
|
||||
def _make_messages(message_count: int = 30):
|
||||
"""Create a list of mock messages."""
|
||||
return [
|
||||
{"role": "user", "content": f"msg{i}", "timestamp": "2026-01-01 00:00"}
|
||||
for i in range(message_count)
|
||||
]
|
||||
|
||||
|
||||
def _make_tool_response(history_entry, memory_update):
|
||||
"""Create an LLMResponse with a save_memory tool call."""
|
||||
return LLMResponse(
|
||||
content=None,
|
||||
tool_calls=[
|
||||
ToolCallRequest(
|
||||
id="call_1",
|
||||
name="save_memory",
|
||||
arguments={
|
||||
"history_entry": history_entry,
|
||||
"memory_update": memory_update,
|
||||
},
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
class ScriptedProvider(LLMProvider):
|
||||
def __init__(self, responses: list[LLMResponse]):
|
||||
super().__init__()
|
||||
self._responses = list(responses)
|
||||
self.calls = 0
|
||||
|
||||
async def chat(self, *args, **kwargs) -> LLMResponse:
|
||||
self.calls += 1
|
||||
if self._responses:
|
||||
return self._responses.pop(0)
|
||||
return LLMResponse(content="", tool_calls=[])
|
||||
|
||||
def get_default_model(self) -> str:
|
||||
return "test-model"
|
||||
|
||||
|
||||
class TestMemoryConsolidationTypeHandling:
|
||||
"""Test that consolidation handles various argument types correctly."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_string_arguments_work(self, tmp_path: Path) -> None:
|
||||
"""Normal case: LLM returns string arguments."""
|
||||
store = MemoryStore(tmp_path)
|
||||
provider = AsyncMock()
|
||||
provider.chat = AsyncMock(
|
||||
return_value=_make_tool_response(
|
||||
history_entry="[2026-01-01] User discussed testing.",
|
||||
memory_update="# Memory\nUser likes testing.",
|
||||
)
|
||||
)
|
||||
provider.chat_with_retry = provider.chat
|
||||
messages = _make_messages(message_count=60)
|
||||
|
||||
result = await store.consolidate(messages, provider, "test-model")
|
||||
|
||||
assert result is True
|
||||
assert store.history_file.exists()
|
||||
assert "[2026-01-01] User discussed testing." in store.history_file.read_text()
|
||||
assert "User likes testing." in store.memory_file.read_text()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dict_arguments_serialized_to_json(self, tmp_path: Path) -> None:
|
||||
"""Issue #1042: LLM returns dict instead of string — must not raise TypeError."""
|
||||
store = MemoryStore(tmp_path)
|
||||
provider = AsyncMock()
|
||||
provider.chat = AsyncMock(
|
||||
return_value=_make_tool_response(
|
||||
history_entry={"timestamp": "2026-01-01", "summary": "User discussed testing."},
|
||||
memory_update={"facts": ["User likes testing"], "topics": ["testing"]},
|
||||
)
|
||||
)
|
||||
provider.chat_with_retry = provider.chat
|
||||
messages = _make_messages(message_count=60)
|
||||
|
||||
result = await store.consolidate(messages, provider, "test-model")
|
||||
|
||||
assert result is True
|
||||
assert store.history_file.exists()
|
||||
history_content = store.history_file.read_text()
|
||||
parsed = json.loads(history_content.strip())
|
||||
assert parsed["summary"] == "User discussed testing."
|
||||
|
||||
memory_content = store.memory_file.read_text()
|
||||
parsed_mem = json.loads(memory_content)
|
||||
assert "User likes testing" in parsed_mem["facts"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_string_arguments_as_raw_json(self, tmp_path: Path) -> None:
|
||||
"""Some providers return arguments as a JSON string instead of parsed dict."""
|
||||
store = MemoryStore(tmp_path)
|
||||
provider = AsyncMock()
|
||||
|
||||
response = LLMResponse(
|
||||
content=None,
|
||||
tool_calls=[
|
||||
ToolCallRequest(
|
||||
id="call_1",
|
||||
name="save_memory",
|
||||
arguments=json.dumps({
|
||||
"history_entry": "[2026-01-01] User discussed testing.",
|
||||
"memory_update": "# Memory\nUser likes testing.",
|
||||
}),
|
||||
)
|
||||
],
|
||||
)
|
||||
provider.chat = AsyncMock(return_value=response)
|
||||
provider.chat_with_retry = provider.chat
|
||||
messages = _make_messages(message_count=60)
|
||||
|
||||
result = await store.consolidate(messages, provider, "test-model")
|
||||
|
||||
assert result is True
|
||||
assert "User discussed testing." in store.history_file.read_text()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_tool_call_returns_false(self, tmp_path: Path) -> None:
|
||||
"""When LLM doesn't use the save_memory tool, return False."""
|
||||
store = MemoryStore(tmp_path)
|
||||
provider = AsyncMock()
|
||||
provider.chat = AsyncMock(
|
||||
return_value=LLMResponse(content="I summarized the conversation.", tool_calls=[])
|
||||
)
|
||||
provider.chat_with_retry = provider.chat
|
||||
messages = _make_messages(message_count=60)
|
||||
|
||||
result = await store.consolidate(messages, provider, "test-model")
|
||||
|
||||
assert result is False
|
||||
assert not store.history_file.exists()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_skips_when_message_chunk_is_empty(self, tmp_path: Path) -> None:
|
||||
"""Consolidation should be a no-op when the selected chunk is empty."""
|
||||
store = MemoryStore(tmp_path)
|
||||
provider = AsyncMock()
|
||||
provider.chat_with_retry = provider.chat
|
||||
messages: list[dict] = []
|
||||
|
||||
result = await store.consolidate(messages, provider, "test-model")
|
||||
|
||||
assert result is True
|
||||
provider.chat.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_arguments_extracts_first_dict(self, tmp_path: Path) -> None:
|
||||
"""Some providers return arguments as a list - extract first element if it's a dict."""
|
||||
store = MemoryStore(tmp_path)
|
||||
provider = AsyncMock()
|
||||
|
||||
response = LLMResponse(
|
||||
content=None,
|
||||
tool_calls=[
|
||||
ToolCallRequest(
|
||||
id="call_1",
|
||||
name="save_memory",
|
||||
arguments=[{
|
||||
"history_entry": "[2026-01-01] User discussed testing.",
|
||||
"memory_update": "# Memory\nUser likes testing.",
|
||||
}],
|
||||
)
|
||||
],
|
||||
)
|
||||
provider.chat = AsyncMock(return_value=response)
|
||||
provider.chat_with_retry = provider.chat
|
||||
messages = _make_messages(message_count=60)
|
||||
|
||||
result = await store.consolidate(messages, provider, "test-model")
|
||||
|
||||
assert result is True
|
||||
assert "User discussed testing." in store.history_file.read_text()
|
||||
assert "User likes testing." in store.memory_file.read_text()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_arguments_empty_list_returns_false(self, tmp_path: Path) -> None:
|
||||
"""Empty list arguments should return False."""
|
||||
store = MemoryStore(tmp_path)
|
||||
provider = AsyncMock()
|
||||
|
||||
response = LLMResponse(
|
||||
content=None,
|
||||
tool_calls=[
|
||||
ToolCallRequest(
|
||||
id="call_1",
|
||||
name="save_memory",
|
||||
arguments=[],
|
||||
)
|
||||
],
|
||||
)
|
||||
provider.chat = AsyncMock(return_value=response)
|
||||
provider.chat_with_retry = provider.chat
|
||||
messages = _make_messages(message_count=60)
|
||||
|
||||
result = await store.consolidate(messages, provider, "test-model")
|
||||
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_arguments_non_dict_content_returns_false(self, tmp_path: Path) -> None:
|
||||
"""List with non-dict content should return False."""
|
||||
store = MemoryStore(tmp_path)
|
||||
provider = AsyncMock()
|
||||
|
||||
response = LLMResponse(
|
||||
content=None,
|
||||
tool_calls=[
|
||||
ToolCallRequest(
|
||||
id="call_1",
|
||||
name="save_memory",
|
||||
arguments=["string", "content"],
|
||||
)
|
||||
],
|
||||
)
|
||||
provider.chat = AsyncMock(return_value=response)
|
||||
provider.chat_with_retry = provider.chat
|
||||
messages = _make_messages(message_count=60)
|
||||
|
||||
result = await store.consolidate(messages, provider, "test-model")
|
||||
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_missing_history_entry_returns_false_without_writing(self, tmp_path: Path) -> None:
|
||||
"""Do not persist partial results when required fields are missing."""
|
||||
store = MemoryStore(tmp_path)
|
||||
provider = AsyncMock()
|
||||
provider.chat_with_retry = AsyncMock(
|
||||
return_value=LLMResponse(
|
||||
content=None,
|
||||
tool_calls=[
|
||||
ToolCallRequest(
|
||||
id="call_1",
|
||||
name="save_memory",
|
||||
arguments={"memory_update": "# Memory\nOnly memory update"},
|
||||
)
|
||||
],
|
||||
)
|
||||
)
|
||||
messages = _make_messages(message_count=60)
|
||||
|
||||
result = await store.consolidate(messages, provider, "test-model")
|
||||
|
||||
assert result is False
|
||||
assert not store.history_file.exists()
|
||||
assert not store.memory_file.exists()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_missing_memory_update_returns_false_without_writing(self, tmp_path: Path) -> None:
|
||||
"""Do not append history if memory_update is missing."""
|
||||
store = MemoryStore(tmp_path)
|
||||
provider = AsyncMock()
|
||||
provider.chat_with_retry = AsyncMock(
|
||||
return_value=LLMResponse(
|
||||
content=None,
|
||||
tool_calls=[
|
||||
ToolCallRequest(
|
||||
id="call_1",
|
||||
name="save_memory",
|
||||
arguments={"history_entry": "[2026-01-01] Partial output."},
|
||||
)
|
||||
],
|
||||
)
|
||||
)
|
||||
messages = _make_messages(message_count=60)
|
||||
|
||||
result = await store.consolidate(messages, provider, "test-model")
|
||||
|
||||
assert result is False
|
||||
assert not store.history_file.exists()
|
||||
assert not store.memory_file.exists()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_null_required_field_returns_false_without_writing(self, tmp_path: Path) -> None:
|
||||
"""Null required fields should be rejected before persistence."""
|
||||
store = MemoryStore(tmp_path)
|
||||
provider = AsyncMock()
|
||||
provider.chat_with_retry = AsyncMock(
|
||||
return_value=_make_tool_response(
|
||||
history_entry=None,
|
||||
memory_update="# Memory\nUser likes testing.",
|
||||
)
|
||||
)
|
||||
messages = _make_messages(message_count=60)
|
||||
|
||||
result = await store.consolidate(messages, provider, "test-model")
|
||||
|
||||
assert result is False
|
||||
assert not store.history_file.exists()
|
||||
assert not store.memory_file.exists()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_history_entry_returns_false_without_writing(self, tmp_path: Path) -> None:
|
||||
"""Empty history entries should be rejected to avoid blank archival records."""
|
||||
store = MemoryStore(tmp_path)
|
||||
provider = AsyncMock()
|
||||
provider.chat_with_retry = AsyncMock(
|
||||
return_value=_make_tool_response(
|
||||
history_entry=" ",
|
||||
memory_update="# Memory\nUser likes testing.",
|
||||
)
|
||||
)
|
||||
messages = _make_messages(message_count=60)
|
||||
|
||||
result = await store.consolidate(messages, provider, "test-model")
|
||||
|
||||
assert result is False
|
||||
assert not store.history_file.exists()
|
||||
assert not store.memory_file.exists()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retries_transient_error_then_succeeds(self, tmp_path: Path, monkeypatch) -> None:
|
||||
store = MemoryStore(tmp_path)
|
||||
provider = ScriptedProvider([
|
||||
LLMResponse(content="503 server error", finish_reason="error"),
|
||||
_make_tool_response(
|
||||
history_entry="[2026-01-01] User discussed testing.",
|
||||
memory_update="# Memory\nUser likes testing.",
|
||||
),
|
||||
])
|
||||
messages = _make_messages(message_count=60)
|
||||
delays: list[int] = []
|
||||
|
||||
async def _fake_sleep(delay: int) -> None:
|
||||
delays.append(delay)
|
||||
|
||||
monkeypatch.setattr("nanobot.providers.base.asyncio.sleep", _fake_sleep)
|
||||
|
||||
result = await store.consolidate(messages, provider, "test-model")
|
||||
|
||||
assert result is True
|
||||
assert provider.calls == 2
|
||||
assert delays == [1]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_consolidation_delegates_to_provider_defaults(self, tmp_path: Path) -> None:
|
||||
"""Consolidation no longer passes generation params — the provider owns them."""
|
||||
store = MemoryStore(tmp_path)
|
||||
provider = AsyncMock()
|
||||
provider.chat_with_retry = AsyncMock(
|
||||
return_value=_make_tool_response(
|
||||
history_entry="[2026-01-01] User discussed testing.",
|
||||
memory_update="# Memory\nUser likes testing.",
|
||||
)
|
||||
)
|
||||
messages = _make_messages(message_count=60)
|
||||
|
||||
result = await store.consolidate(messages, provider, "test-model")
|
||||
|
||||
assert result is True
|
||||
provider.chat_with_retry.assert_awaited_once()
|
||||
_, kwargs = provider.chat_with_retry.await_args
|
||||
assert kwargs["model"] == "test-model"
|
||||
assert "temperature" not in kwargs
|
||||
assert "max_tokens" not in kwargs
|
||||
assert "reasoning_effort" not in kwargs
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_choice_fallback_on_unsupported_error(self, tmp_path: Path) -> None:
|
||||
"""Forced tool_choice rejected by provider -> retry with auto and succeed."""
|
||||
store = MemoryStore(tmp_path)
|
||||
error_resp = LLMResponse(
|
||||
content="Error calling LLM: BadRequestError: "
|
||||
"The tool_choice parameter does not support being set to required or object",
|
||||
finish_reason="error",
|
||||
tool_calls=[],
|
||||
)
|
||||
ok_resp = _make_tool_response(
|
||||
history_entry="[2026-01-01] Fallback worked.",
|
||||
memory_update="# Memory\nFallback OK.",
|
||||
)
|
||||
|
||||
call_log: list[dict] = []
|
||||
|
||||
async def _tracking_chat(**kwargs):
|
||||
call_log.append(kwargs)
|
||||
return error_resp if len(call_log) == 1 else ok_resp
|
||||
|
||||
provider = AsyncMock()
|
||||
provider.chat_with_retry = AsyncMock(side_effect=_tracking_chat)
|
||||
messages = _make_messages(message_count=60)
|
||||
|
||||
result = await store.consolidate(messages, provider, "test-model")
|
||||
|
||||
assert result is True
|
||||
assert len(call_log) == 2
|
||||
assert isinstance(call_log[0]["tool_choice"], dict)
|
||||
assert call_log[1]["tool_choice"] == "auto"
|
||||
assert "Fallback worked." in store.history_file.read_text()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_choice_fallback_auto_no_tool_call(self, tmp_path: Path) -> None:
|
||||
"""Forced rejected, auto retry also produces no tool call -> return False."""
|
||||
store = MemoryStore(tmp_path)
|
||||
error_resp = LLMResponse(
|
||||
content="Error: tool_choice must be none or auto",
|
||||
finish_reason="error",
|
||||
tool_calls=[],
|
||||
)
|
||||
no_tool_resp = LLMResponse(
|
||||
content="Here is a summary.",
|
||||
finish_reason="stop",
|
||||
tool_calls=[],
|
||||
)
|
||||
|
||||
provider = AsyncMock()
|
||||
provider.chat_with_retry = AsyncMock(side_effect=[error_resp, no_tool_resp])
|
||||
messages = _make_messages(message_count=60)
|
||||
|
||||
result = await store.consolidate(messages, provider, "test-model")
|
||||
|
||||
assert result is False
|
||||
assert not store.history_file.exists()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_raw_archive_after_consecutive_failures(self, tmp_path: Path) -> None:
|
||||
"""After 3 consecutive failures, raw-archive messages and return True."""
|
||||
store = MemoryStore(tmp_path)
|
||||
no_tool = LLMResponse(content="No tool call.", finish_reason="stop", tool_calls=[])
|
||||
provider = AsyncMock()
|
||||
provider.chat_with_retry = AsyncMock(return_value=no_tool)
|
||||
messages = _make_messages(message_count=10)
|
||||
|
||||
assert await store.consolidate(messages, provider, "m") is False
|
||||
assert await store.consolidate(messages, provider, "m") is False
|
||||
assert await store.consolidate(messages, provider, "m") is True
|
||||
|
||||
assert store.history_file.exists()
|
||||
content = store.history_file.read_text()
|
||||
assert "[RAW]" in content
|
||||
assert "10 messages" in content
|
||||
assert "msg0" in content
|
||||
assert not store.memory_file.exists()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_raw_archive_counter_resets_on_success(self, tmp_path: Path) -> None:
|
||||
"""A successful consolidation resets the failure counter."""
|
||||
store = MemoryStore(tmp_path)
|
||||
no_tool = LLMResponse(content="Nope.", finish_reason="stop", tool_calls=[])
|
||||
ok_resp = _make_tool_response(
|
||||
history_entry="[2026-01-01] OK.",
|
||||
memory_update="# Memory\nOK.",
|
||||
)
|
||||
messages = _make_messages(message_count=10)
|
||||
|
||||
provider = AsyncMock()
|
||||
provider.chat_with_retry = AsyncMock(return_value=no_tool)
|
||||
assert await store.consolidate(messages, provider, "m") is False
|
||||
assert await store.consolidate(messages, provider, "m") is False
|
||||
assert store._consecutive_failures == 2
|
||||
|
||||
provider.chat_with_retry = AsyncMock(return_value=ok_resp)
|
||||
assert await store.consolidate(messages, provider, "m") is True
|
||||
assert store._consecutive_failures == 0
|
||||
|
||||
provider.chat_with_retry = AsyncMock(return_value=no_tool)
|
||||
assert await store.consolidate(messages, provider, "m") is False
|
||||
assert store._consecutive_failures == 1
|
||||
@@ -0,0 +1,495 @@
|
||||
"""Unit tests for onboard core logic functions.
|
||||
|
||||
These tests focus on the business logic behind the onboard wizard,
|
||||
without testing the interactive UI components.
|
||||
"""
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from typing import Any, cast
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from nanobot.cli import onboard as onboard_wizard
|
||||
|
||||
# Import functions to test
|
||||
from nanobot.cli.commands import _merge_missing_defaults
|
||||
from nanobot.cli.onboard import (
|
||||
_BACK_PRESSED,
|
||||
_configure_pydantic_model,
|
||||
_format_value,
|
||||
_get_field_display_name,
|
||||
_get_field_type_info,
|
||||
run_onboard,
|
||||
)
|
||||
from nanobot.config.schema import Config
|
||||
from nanobot.utils.helpers import sync_workspace_templates
|
||||
|
||||
|
||||
class TestMergeMissingDefaults:
|
||||
"""Tests for _merge_missing_defaults recursive config merging."""
|
||||
|
||||
def test_adds_missing_top_level_keys(self):
|
||||
existing = {"a": 1}
|
||||
defaults = {"a": 1, "b": 2, "c": 3}
|
||||
|
||||
result = _merge_missing_defaults(existing, defaults)
|
||||
|
||||
assert result == {"a": 1, "b": 2, "c": 3}
|
||||
|
||||
def test_preserves_existing_values(self):
|
||||
existing = {"a": "custom_value"}
|
||||
defaults = {"a": "default_value"}
|
||||
|
||||
result = _merge_missing_defaults(existing, defaults)
|
||||
|
||||
assert result == {"a": "custom_value"}
|
||||
|
||||
def test_merges_nested_dicts_recursively(self):
|
||||
existing = {
|
||||
"level1": {
|
||||
"level2": {
|
||||
"existing": "kept",
|
||||
}
|
||||
}
|
||||
}
|
||||
defaults = {
|
||||
"level1": {
|
||||
"level2": {
|
||||
"existing": "replaced",
|
||||
"added": "new",
|
||||
},
|
||||
"level2b": "also_new",
|
||||
}
|
||||
}
|
||||
|
||||
result = _merge_missing_defaults(existing, defaults)
|
||||
|
||||
assert result == {
|
||||
"level1": {
|
||||
"level2": {
|
||||
"existing": "kept",
|
||||
"added": "new",
|
||||
},
|
||||
"level2b": "also_new",
|
||||
}
|
||||
}
|
||||
|
||||
def test_returns_existing_if_not_dict(self):
|
||||
assert _merge_missing_defaults("string", {"a": 1}) == "string"
|
||||
assert _merge_missing_defaults([1, 2, 3], {"a": 1}) == [1, 2, 3]
|
||||
assert _merge_missing_defaults(None, {"a": 1}) is None
|
||||
assert _merge_missing_defaults(42, {"a": 1}) == 42
|
||||
|
||||
def test_returns_existing_if_defaults_not_dict(self):
|
||||
assert _merge_missing_defaults({"a": 1}, "string") == {"a": 1}
|
||||
assert _merge_missing_defaults({"a": 1}, None) == {"a": 1}
|
||||
|
||||
def test_handles_empty_dicts(self):
|
||||
assert _merge_missing_defaults({}, {"a": 1}) == {"a": 1}
|
||||
assert _merge_missing_defaults({"a": 1}, {}) == {"a": 1}
|
||||
assert _merge_missing_defaults({}, {}) == {}
|
||||
|
||||
def test_backfills_channel_config(self):
|
||||
"""Real-world scenario: backfill missing channel fields."""
|
||||
existing_channel = {
|
||||
"enabled": False,
|
||||
"appId": "",
|
||||
"secret": "",
|
||||
}
|
||||
default_channel = {
|
||||
"enabled": False,
|
||||
"appId": "",
|
||||
"secret": "",
|
||||
"msgFormat": "plain",
|
||||
"allowFrom": [],
|
||||
}
|
||||
|
||||
result = _merge_missing_defaults(existing_channel, default_channel)
|
||||
|
||||
assert result["msgFormat"] == "plain"
|
||||
assert result["allowFrom"] == []
|
||||
|
||||
|
||||
class TestGetFieldTypeInfo:
|
||||
"""Tests for _get_field_type_info type extraction."""
|
||||
|
||||
def test_extracts_str_type(self):
|
||||
class Model(BaseModel):
|
||||
field: str
|
||||
|
||||
type_name, inner = _get_field_type_info(Model.model_fields["field"])
|
||||
assert type_name == "str"
|
||||
assert inner is None
|
||||
|
||||
def test_extracts_int_type(self):
|
||||
class Model(BaseModel):
|
||||
count: int
|
||||
|
||||
type_name, inner = _get_field_type_info(Model.model_fields["count"])
|
||||
assert type_name == "int"
|
||||
assert inner is None
|
||||
|
||||
def test_extracts_bool_type(self):
|
||||
class Model(BaseModel):
|
||||
enabled: bool
|
||||
|
||||
type_name, inner = _get_field_type_info(Model.model_fields["enabled"])
|
||||
assert type_name == "bool"
|
||||
assert inner is None
|
||||
|
||||
def test_extracts_float_type(self):
|
||||
class Model(BaseModel):
|
||||
ratio: float
|
||||
|
||||
type_name, inner = _get_field_type_info(Model.model_fields["ratio"])
|
||||
assert type_name == "float"
|
||||
assert inner is None
|
||||
|
||||
def test_extracts_list_type_with_item_type(self):
|
||||
class Model(BaseModel):
|
||||
items: list[str]
|
||||
|
||||
type_name, inner = _get_field_type_info(Model.model_fields["items"])
|
||||
assert type_name == "list"
|
||||
assert inner is str
|
||||
|
||||
def test_extracts_list_type_without_item_type(self):
|
||||
# Plain list without type param falls back to str
|
||||
class Model(BaseModel):
|
||||
items: list # type: ignore
|
||||
|
||||
# Plain list annotation doesn't match list check, returns str
|
||||
type_name, inner = _get_field_type_info(Model.model_fields["items"])
|
||||
assert type_name == "str" # Falls back to str for untyped list
|
||||
assert inner is None
|
||||
|
||||
def test_extracts_dict_type(self):
|
||||
# Plain dict without type param falls back to str
|
||||
class Model(BaseModel):
|
||||
data: dict # type: ignore
|
||||
|
||||
# Plain dict annotation doesn't match dict check, returns str
|
||||
type_name, inner = _get_field_type_info(Model.model_fields["data"])
|
||||
assert type_name == "str" # Falls back to str for untyped dict
|
||||
assert inner is None
|
||||
|
||||
def test_extracts_optional_type(self):
|
||||
class Model(BaseModel):
|
||||
optional: str | None = None
|
||||
|
||||
type_name, inner = _get_field_type_info(Model.model_fields["optional"])
|
||||
# Should unwrap Optional and get str
|
||||
assert type_name == "str"
|
||||
assert inner is None
|
||||
|
||||
def test_extracts_nested_model_type(self):
|
||||
class Inner(BaseModel):
|
||||
x: int
|
||||
|
||||
class Outer(BaseModel):
|
||||
nested: Inner
|
||||
|
||||
type_name, inner = _get_field_type_info(Outer.model_fields["nested"])
|
||||
assert type_name == "model"
|
||||
assert inner is Inner
|
||||
|
||||
def test_handles_none_annotation(self):
|
||||
"""Field with None annotation defaults to str."""
|
||||
class Model(BaseModel):
|
||||
field: Any = None
|
||||
|
||||
# Create a mock field_info with None annotation
|
||||
field_info = SimpleNamespace(annotation=None)
|
||||
type_name, inner = _get_field_type_info(field_info)
|
||||
assert type_name == "str"
|
||||
assert inner is None
|
||||
|
||||
|
||||
class TestGetFieldDisplayName:
|
||||
"""Tests for _get_field_display_name human-readable name generation."""
|
||||
|
||||
def test_uses_description_if_present(self):
|
||||
class Model(BaseModel):
|
||||
api_key: str = Field(description="API Key for authentication")
|
||||
|
||||
name = _get_field_display_name("api_key", Model.model_fields["api_key"])
|
||||
assert name == "API Key for authentication"
|
||||
|
||||
def test_converts_snake_case_to_title(self):
|
||||
field_info = SimpleNamespace(description=None)
|
||||
name = _get_field_display_name("user_name", field_info)
|
||||
assert name == "User Name"
|
||||
|
||||
def test_adds_url_suffix(self):
|
||||
field_info = SimpleNamespace(description=None)
|
||||
name = _get_field_display_name("api_url", field_info)
|
||||
# Title case: "Api Url"
|
||||
assert "Url" in name and "Api" in name
|
||||
|
||||
def test_adds_path_suffix(self):
|
||||
field_info = SimpleNamespace(description=None)
|
||||
name = _get_field_display_name("file_path", field_info)
|
||||
assert "Path" in name and "File" in name
|
||||
|
||||
def test_adds_id_suffix(self):
|
||||
field_info = SimpleNamespace(description=None)
|
||||
name = _get_field_display_name("user_id", field_info)
|
||||
# Title case: "User Id"
|
||||
assert "Id" in name and "User" in name
|
||||
|
||||
def test_adds_key_suffix(self):
|
||||
field_info = SimpleNamespace(description=None)
|
||||
name = _get_field_display_name("api_key", field_info)
|
||||
assert "Key" in name and "Api" in name
|
||||
|
||||
def test_adds_token_suffix(self):
|
||||
field_info = SimpleNamespace(description=None)
|
||||
name = _get_field_display_name("auth_token", field_info)
|
||||
assert "Token" in name and "Auth" in name
|
||||
|
||||
def test_adds_seconds_suffix(self):
|
||||
field_info = SimpleNamespace(description=None)
|
||||
name = _get_field_display_name("timeout_s", field_info)
|
||||
# Contains "(Seconds)" with title case
|
||||
assert "(Seconds)" in name or "(seconds)" in name
|
||||
|
||||
def test_adds_ms_suffix(self):
|
||||
field_info = SimpleNamespace(description=None)
|
||||
name = _get_field_display_name("delay_ms", field_info)
|
||||
# Contains "(Ms)" or "(ms)"
|
||||
assert "(Ms)" in name or "(ms)" in name
|
||||
|
||||
|
||||
class TestFormatValue:
|
||||
"""Tests for _format_value display formatting."""
|
||||
|
||||
def test_formats_none_as_not_set(self):
|
||||
assert "not set" in _format_value(None)
|
||||
|
||||
def test_formats_empty_string_as_not_set(self):
|
||||
assert "not set" in _format_value("")
|
||||
|
||||
def test_formats_empty_dict_as_not_set(self):
|
||||
assert "not set" in _format_value({})
|
||||
|
||||
def test_formats_empty_list_as_not_set(self):
|
||||
assert "not set" in _format_value([])
|
||||
|
||||
def test_formats_string_value(self):
|
||||
result = _format_value("hello")
|
||||
assert "hello" in result
|
||||
|
||||
def test_formats_list_value(self):
|
||||
result = _format_value(["a", "b"])
|
||||
assert "a" in result or "b" in result
|
||||
|
||||
def test_formats_dict_value(self):
|
||||
result = _format_value({"key": "value"})
|
||||
assert "key" in result or "value" in result
|
||||
|
||||
def test_formats_int_value(self):
|
||||
result = _format_value(42)
|
||||
assert "42" in result
|
||||
|
||||
def test_formats_bool_true(self):
|
||||
result = _format_value(True)
|
||||
assert "true" in result.lower() or "✓" in result
|
||||
|
||||
def test_formats_bool_false(self):
|
||||
result = _format_value(False)
|
||||
assert "false" in result.lower() or "✗" in result
|
||||
|
||||
|
||||
class TestSyncWorkspaceTemplates:
|
||||
"""Tests for sync_workspace_templates file synchronization."""
|
||||
|
||||
def test_creates_missing_files(self, tmp_path):
|
||||
"""Should create template files that don't exist."""
|
||||
workspace = tmp_path / "workspace"
|
||||
|
||||
added = sync_workspace_templates(workspace, silent=True)
|
||||
|
||||
# Check that some files were created
|
||||
assert isinstance(added, list)
|
||||
# The actual files depend on the templates directory
|
||||
|
||||
def test_does_not_overwrite_existing_files(self, tmp_path):
|
||||
"""Should not overwrite files that already exist."""
|
||||
workspace = tmp_path / "workspace"
|
||||
workspace.mkdir(parents=True)
|
||||
(workspace / "AGENTS.md").write_text("existing content")
|
||||
|
||||
sync_workspace_templates(workspace, silent=True)
|
||||
|
||||
# Existing file should not be changed
|
||||
content = (workspace / "AGENTS.md").read_text()
|
||||
assert content == "existing content"
|
||||
|
||||
def test_creates_memory_directory(self, tmp_path):
|
||||
"""Should create memory directory structure."""
|
||||
workspace = tmp_path / "workspace"
|
||||
|
||||
sync_workspace_templates(workspace, silent=True)
|
||||
|
||||
assert (workspace / "memory").exists() or (workspace / "skills").exists()
|
||||
|
||||
def test_returns_list_of_added_files(self, tmp_path):
|
||||
"""Should return list of relative paths for added files."""
|
||||
workspace = tmp_path / "workspace"
|
||||
|
||||
added = sync_workspace_templates(workspace, silent=True)
|
||||
|
||||
assert isinstance(added, list)
|
||||
# All paths should be relative to workspace
|
||||
for path in added:
|
||||
assert not Path(path).is_absolute()
|
||||
|
||||
|
||||
class TestProviderChannelInfo:
|
||||
"""Tests for provider and channel info retrieval."""
|
||||
|
||||
def test_get_provider_names_returns_dict(self):
|
||||
from nanobot.cli.onboard import _get_provider_names
|
||||
|
||||
names = _get_provider_names()
|
||||
assert isinstance(names, dict)
|
||||
assert len(names) > 0
|
||||
# Should include common providers
|
||||
assert "openai" in names or "anthropic" in names
|
||||
assert "openai_codex" not in names
|
||||
assert "github_copilot" not in names
|
||||
|
||||
def test_get_channel_names_returns_dict(self):
|
||||
from nanobot.cli.onboard import _get_channel_names
|
||||
|
||||
names = _get_channel_names()
|
||||
assert isinstance(names, dict)
|
||||
# Should include at least some channels
|
||||
assert len(names) >= 0
|
||||
|
||||
def test_get_provider_info_returns_valid_structure(self):
|
||||
from nanobot.cli.onboard import _get_provider_info
|
||||
|
||||
info = _get_provider_info()
|
||||
assert isinstance(info, dict)
|
||||
# Each value should be a tuple with expected structure
|
||||
for provider_name, value in info.items():
|
||||
assert isinstance(value, tuple)
|
||||
assert len(value) == 4 # (display_name, needs_api_key, needs_api_base, env_var)
|
||||
|
||||
|
||||
class _SimpleDraftModel(BaseModel):
|
||||
api_key: str = ""
|
||||
|
||||
|
||||
class _NestedDraftModel(BaseModel):
|
||||
api_key: str = ""
|
||||
|
||||
|
||||
class _OuterDraftModel(BaseModel):
|
||||
nested: _NestedDraftModel = Field(default_factory=_NestedDraftModel)
|
||||
|
||||
|
||||
class TestConfigurePydanticModelDrafts:
|
||||
@staticmethod
|
||||
def _patch_prompt_helpers(monkeypatch, tokens, text_value="secret"):
|
||||
sequence = iter(tokens)
|
||||
|
||||
def fake_select(_prompt, choices, default=None):
|
||||
token = next(sequence)
|
||||
if token == "first":
|
||||
return choices[0]
|
||||
if token == "done":
|
||||
return "[Done]"
|
||||
if token == "back":
|
||||
return _BACK_PRESSED
|
||||
return token
|
||||
|
||||
monkeypatch.setattr(onboard_wizard, "_select_with_back", fake_select)
|
||||
monkeypatch.setattr(onboard_wizard, "_show_config_panel", lambda *_args, **_kwargs: None)
|
||||
monkeypatch.setattr(
|
||||
onboard_wizard, "_input_with_existing", lambda *_args, **_kwargs: text_value
|
||||
)
|
||||
|
||||
def test_discarding_section_keeps_original_model_unchanged(self, monkeypatch):
|
||||
model = _SimpleDraftModel()
|
||||
self._patch_prompt_helpers(monkeypatch, ["first", "back"])
|
||||
|
||||
result = _configure_pydantic_model(model, "Simple")
|
||||
|
||||
assert result is None
|
||||
assert model.api_key == ""
|
||||
|
||||
def test_completing_section_returns_updated_draft(self, monkeypatch):
|
||||
model = _SimpleDraftModel()
|
||||
self._patch_prompt_helpers(monkeypatch, ["first", "done"])
|
||||
|
||||
result = _configure_pydantic_model(model, "Simple")
|
||||
|
||||
assert result is not None
|
||||
updated = cast(_SimpleDraftModel, result)
|
||||
assert updated.api_key == "secret"
|
||||
assert model.api_key == ""
|
||||
|
||||
def test_nested_section_back_discards_nested_edits(self, monkeypatch):
|
||||
model = _OuterDraftModel()
|
||||
self._patch_prompt_helpers(monkeypatch, ["first", "first", "back", "done"])
|
||||
|
||||
result = _configure_pydantic_model(model, "Outer")
|
||||
|
||||
assert result is not None
|
||||
updated = cast(_OuterDraftModel, result)
|
||||
assert updated.nested.api_key == ""
|
||||
assert model.nested.api_key == ""
|
||||
|
||||
def test_nested_section_done_commits_nested_edits(self, monkeypatch):
|
||||
model = _OuterDraftModel()
|
||||
self._patch_prompt_helpers(monkeypatch, ["first", "first", "done", "done"])
|
||||
|
||||
result = _configure_pydantic_model(model, "Outer")
|
||||
|
||||
assert result is not None
|
||||
updated = cast(_OuterDraftModel, result)
|
||||
assert updated.nested.api_key == "secret"
|
||||
assert model.nested.api_key == ""
|
||||
|
||||
|
||||
class TestRunOnboardExitBehavior:
|
||||
def test_main_menu_interrupt_can_discard_unsaved_session_changes(self, monkeypatch):
|
||||
initial_config = Config()
|
||||
|
||||
responses = iter(
|
||||
[
|
||||
"[A] Agent Settings",
|
||||
KeyboardInterrupt(),
|
||||
"[X] Exit Without Saving",
|
||||
]
|
||||
)
|
||||
|
||||
class FakePrompt:
|
||||
def __init__(self, response):
|
||||
self.response = response
|
||||
|
||||
def ask(self):
|
||||
if isinstance(self.response, BaseException):
|
||||
raise self.response
|
||||
return self.response
|
||||
|
||||
def fake_select(*_args, **_kwargs):
|
||||
return FakePrompt(next(responses))
|
||||
|
||||
def fake_configure_general_settings(config, section):
|
||||
if section == "Agent Settings":
|
||||
config.agents.defaults.model = "test/provider-model"
|
||||
|
||||
monkeypatch.setattr(onboard_wizard, "_show_main_menu_header", lambda: None)
|
||||
monkeypatch.setattr(onboard_wizard, "questionary", SimpleNamespace(select=fake_select))
|
||||
monkeypatch.setattr(onboard_wizard, "_configure_general_settings", fake_configure_general_settings)
|
||||
|
||||
result = run_onboard(initial_config=initial_config)
|
||||
|
||||
assert result.should_save is False
|
||||
assert result.config.model_dump(by_alias=True) == initial_config.model_dump(by_alias=True)
|
||||
@@ -0,0 +1,335 @@
|
||||
"""Tests for the shared agent runner and its integration contracts."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.providers.base import LLMResponse, ToolCallRequest
|
||||
|
||||
|
||||
def _make_loop(tmp_path):
|
||||
from nanobot.agent.loop import AgentLoop
|
||||
from nanobot.bus.queue import MessageBus
|
||||
|
||||
bus = MessageBus()
|
||||
provider = MagicMock()
|
||||
provider.get_default_model.return_value = "test-model"
|
||||
|
||||
with patch("nanobot.agent.loop.ContextBuilder"), \
|
||||
patch("nanobot.agent.loop.SessionManager"), \
|
||||
patch("nanobot.agent.loop.SubagentManager") as MockSubMgr:
|
||||
MockSubMgr.return_value.cancel_by_session = AsyncMock(return_value=0)
|
||||
loop = AgentLoop(bus=bus, provider=provider, workspace=tmp_path)
|
||||
return loop
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runner_preserves_reasoning_fields_and_tool_results():
|
||||
from nanobot.agent.runner import AgentRunSpec, AgentRunner
|
||||
|
||||
provider = MagicMock()
|
||||
captured_second_call: list[dict] = []
|
||||
call_count = {"n": 0}
|
||||
|
||||
async def chat_with_retry(*, messages, **kwargs):
|
||||
call_count["n"] += 1
|
||||
if call_count["n"] == 1:
|
||||
return LLMResponse(
|
||||
content="thinking",
|
||||
tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={"path": "."})],
|
||||
reasoning_content="hidden reasoning",
|
||||
thinking_blocks=[{"type": "thinking", "thinking": "step"}],
|
||||
usage={"prompt_tokens": 5, "completion_tokens": 3},
|
||||
)
|
||||
captured_second_call[:] = messages
|
||||
return LLMResponse(content="done", tool_calls=[], usage={})
|
||||
|
||||
provider.chat_with_retry = chat_with_retry
|
||||
tools = MagicMock()
|
||||
tools.get_definitions.return_value = []
|
||||
tools.execute = AsyncMock(return_value="tool result")
|
||||
|
||||
runner = AgentRunner(provider)
|
||||
result = await runner.run(AgentRunSpec(
|
||||
initial_messages=[
|
||||
{"role": "system", "content": "system"},
|
||||
{"role": "user", "content": "do task"},
|
||||
],
|
||||
tools=tools,
|
||||
model="test-model",
|
||||
max_iterations=3,
|
||||
))
|
||||
|
||||
assert result.final_content == "done"
|
||||
assert result.tools_used == ["list_dir"]
|
||||
assert result.tool_events == [
|
||||
{"name": "list_dir", "status": "ok", "detail": "tool result"}
|
||||
]
|
||||
|
||||
assistant_messages = [
|
||||
msg for msg in captured_second_call
|
||||
if msg.get("role") == "assistant" and msg.get("tool_calls")
|
||||
]
|
||||
assert len(assistant_messages) == 1
|
||||
assert assistant_messages[0]["reasoning_content"] == "hidden reasoning"
|
||||
assert assistant_messages[0]["thinking_blocks"] == [{"type": "thinking", "thinking": "step"}]
|
||||
assert any(
|
||||
msg.get("role") == "tool" and msg.get("content") == "tool result"
|
||||
for msg in captured_second_call
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runner_calls_hooks_in_order():
|
||||
from nanobot.agent.hook import AgentHook, AgentHookContext
|
||||
from nanobot.agent.runner import AgentRunSpec, AgentRunner
|
||||
|
||||
provider = MagicMock()
|
||||
call_count = {"n": 0}
|
||||
events: list[tuple] = []
|
||||
|
||||
async def chat_with_retry(**kwargs):
|
||||
call_count["n"] += 1
|
||||
if call_count["n"] == 1:
|
||||
return LLMResponse(
|
||||
content="thinking",
|
||||
tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={"path": "."})],
|
||||
)
|
||||
return LLMResponse(content="done", tool_calls=[], usage={})
|
||||
|
||||
provider.chat_with_retry = chat_with_retry
|
||||
tools = MagicMock()
|
||||
tools.get_definitions.return_value = []
|
||||
tools.execute = AsyncMock(return_value="tool result")
|
||||
|
||||
class RecordingHook(AgentHook):
|
||||
async def before_iteration(self, context: AgentHookContext) -> None:
|
||||
events.append(("before_iteration", context.iteration))
|
||||
|
||||
async def before_execute_tools(self, context: AgentHookContext) -> None:
|
||||
events.append((
|
||||
"before_execute_tools",
|
||||
context.iteration,
|
||||
[tc.name for tc in context.tool_calls],
|
||||
))
|
||||
|
||||
async def after_iteration(self, context: AgentHookContext) -> None:
|
||||
events.append((
|
||||
"after_iteration",
|
||||
context.iteration,
|
||||
context.final_content,
|
||||
list(context.tool_results),
|
||||
list(context.tool_events),
|
||||
context.stop_reason,
|
||||
))
|
||||
|
||||
def finalize_content(self, context: AgentHookContext, content: str | None) -> str | None:
|
||||
events.append(("finalize_content", context.iteration, content))
|
||||
return content.upper() if content else content
|
||||
|
||||
runner = AgentRunner(provider)
|
||||
result = await runner.run(AgentRunSpec(
|
||||
initial_messages=[],
|
||||
tools=tools,
|
||||
model="test-model",
|
||||
max_iterations=3,
|
||||
hook=RecordingHook(),
|
||||
))
|
||||
|
||||
assert result.final_content == "DONE"
|
||||
assert events == [
|
||||
("before_iteration", 0),
|
||||
("before_execute_tools", 0, ["list_dir"]),
|
||||
(
|
||||
"after_iteration",
|
||||
0,
|
||||
None,
|
||||
["tool result"],
|
||||
[{"name": "list_dir", "status": "ok", "detail": "tool result"}],
|
||||
None,
|
||||
),
|
||||
("before_iteration", 1),
|
||||
("finalize_content", 1, "done"),
|
||||
("after_iteration", 1, "DONE", [], [], "completed"),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runner_streaming_hook_receives_deltas_and_end_signal():
|
||||
from nanobot.agent.hook import AgentHook, AgentHookContext
|
||||
from nanobot.agent.runner import AgentRunSpec, AgentRunner
|
||||
|
||||
provider = MagicMock()
|
||||
streamed: list[str] = []
|
||||
endings: list[bool] = []
|
||||
|
||||
async def chat_stream_with_retry(*, on_content_delta, **kwargs):
|
||||
await on_content_delta("he")
|
||||
await on_content_delta("llo")
|
||||
return LLMResponse(content="hello", tool_calls=[], usage={})
|
||||
|
||||
provider.chat_stream_with_retry = chat_stream_with_retry
|
||||
provider.chat_with_retry = AsyncMock()
|
||||
tools = MagicMock()
|
||||
tools.get_definitions.return_value = []
|
||||
|
||||
class StreamingHook(AgentHook):
|
||||
def wants_streaming(self) -> bool:
|
||||
return True
|
||||
|
||||
async def on_stream(self, context: AgentHookContext, delta: str) -> None:
|
||||
streamed.append(delta)
|
||||
|
||||
async def on_stream_end(self, context: AgentHookContext, *, resuming: bool) -> None:
|
||||
endings.append(resuming)
|
||||
|
||||
runner = AgentRunner(provider)
|
||||
result = await runner.run(AgentRunSpec(
|
||||
initial_messages=[],
|
||||
tools=tools,
|
||||
model="test-model",
|
||||
max_iterations=1,
|
||||
hook=StreamingHook(),
|
||||
))
|
||||
|
||||
assert result.final_content == "hello"
|
||||
assert streamed == ["he", "llo"]
|
||||
assert endings == [False]
|
||||
provider.chat_with_retry.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runner_returns_max_iterations_fallback():
|
||||
from nanobot.agent.runner import AgentRunSpec, AgentRunner
|
||||
|
||||
provider = MagicMock()
|
||||
provider.chat_with_retry = AsyncMock(return_value=LLMResponse(
|
||||
content="still working",
|
||||
tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={"path": "."})],
|
||||
))
|
||||
tools = MagicMock()
|
||||
tools.get_definitions.return_value = []
|
||||
tools.execute = AsyncMock(return_value="tool result")
|
||||
|
||||
runner = AgentRunner(provider)
|
||||
result = await runner.run(AgentRunSpec(
|
||||
initial_messages=[],
|
||||
tools=tools,
|
||||
model="test-model",
|
||||
max_iterations=2,
|
||||
))
|
||||
|
||||
assert result.stop_reason == "max_iterations"
|
||||
assert result.final_content == (
|
||||
"I reached the maximum number of tool call iterations (2) "
|
||||
"without completing the task. You can try breaking the task into smaller steps."
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runner_returns_structured_tool_error():
|
||||
from nanobot.agent.runner import AgentRunSpec, AgentRunner
|
||||
|
||||
provider = MagicMock()
|
||||
provider.chat_with_retry = AsyncMock(return_value=LLMResponse(
|
||||
content="working",
|
||||
tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={})],
|
||||
))
|
||||
tools = MagicMock()
|
||||
tools.get_definitions.return_value = []
|
||||
tools.execute = AsyncMock(side_effect=RuntimeError("boom"))
|
||||
|
||||
runner = AgentRunner(provider)
|
||||
|
||||
result = await runner.run(AgentRunSpec(
|
||||
initial_messages=[],
|
||||
tools=tools,
|
||||
model="test-model",
|
||||
max_iterations=2,
|
||||
fail_on_tool_error=True,
|
||||
))
|
||||
|
||||
assert result.stop_reason == "tool_error"
|
||||
assert result.error == "Error: RuntimeError: boom"
|
||||
assert result.tool_events == [
|
||||
{"name": "list_dir", "status": "error", "detail": "boom"}
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_loop_max_iterations_message_stays_stable(tmp_path):
|
||||
loop = _make_loop(tmp_path)
|
||||
loop.provider.chat_with_retry = AsyncMock(return_value=LLMResponse(
|
||||
content="working",
|
||||
tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={})],
|
||||
))
|
||||
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||
loop.tools.execute = AsyncMock(return_value="ok")
|
||||
loop.max_iterations = 2
|
||||
|
||||
final_content, _, _ = await loop._run_agent_loop([])
|
||||
|
||||
assert final_content == (
|
||||
"I reached the maximum number of tool call iterations (2) "
|
||||
"without completing the task. You can try breaking the task into smaller steps."
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_loop_stream_filter_handles_think_only_prefix_without_crashing(tmp_path):
|
||||
loop = _make_loop(tmp_path)
|
||||
deltas: list[str] = []
|
||||
endings: list[bool] = []
|
||||
|
||||
async def chat_stream_with_retry(*, on_content_delta, **kwargs):
|
||||
await on_content_delta("<think>hidden")
|
||||
await on_content_delta("</think>Hello")
|
||||
return LLMResponse(content="<think>hidden</think>Hello", tool_calls=[], usage={})
|
||||
|
||||
loop.provider.chat_stream_with_retry = chat_stream_with_retry
|
||||
|
||||
async def on_stream(delta: str) -> None:
|
||||
deltas.append(delta)
|
||||
|
||||
async def on_stream_end(*, resuming: bool = False) -> None:
|
||||
endings.append(resuming)
|
||||
|
||||
final_content, _, _ = await loop._run_agent_loop(
|
||||
[],
|
||||
on_stream=on_stream,
|
||||
on_stream_end=on_stream_end,
|
||||
)
|
||||
|
||||
assert final_content == "Hello"
|
||||
assert deltas == ["Hello"]
|
||||
assert endings == [False]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_subagent_max_iterations_announces_existing_fallback(tmp_path, monkeypatch):
|
||||
from nanobot.agent.subagent import SubagentManager
|
||||
from nanobot.bus.queue import MessageBus
|
||||
|
||||
bus = MessageBus()
|
||||
provider = MagicMock()
|
||||
provider.get_default_model.return_value = "test-model"
|
||||
provider.chat_with_retry = AsyncMock(return_value=LLMResponse(
|
||||
content="working",
|
||||
tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={})],
|
||||
))
|
||||
mgr = SubagentManager(provider=provider, workspace=tmp_path, bus=bus)
|
||||
mgr._announce_result = AsyncMock()
|
||||
|
||||
async def fake_execute(self, name, arguments):
|
||||
return "tool result"
|
||||
|
||||
monkeypatch.setattr("nanobot.agent.tools.registry.ToolRegistry.execute", fake_execute)
|
||||
|
||||
await mgr._run_subagent("sub-1", "do task", "label", {"channel": "test", "chat_id": "c1"})
|
||||
|
||||
mgr._announce_result.assert_awaited_once()
|
||||
args = mgr._announce_result.await_args.args
|
||||
assert args[3] == "Task completed but no final response was generated."
|
||||
assert args[5] == "ok"
|
||||
@@ -0,0 +1,198 @@
|
||||
from nanobot.session.manager import Session
|
||||
|
||||
|
||||
def _assert_no_orphans(history: list[dict]) -> None:
|
||||
"""Assert every tool result in history has a matching assistant tool_call."""
|
||||
declared = {
|
||||
tc["id"]
|
||||
for m in history if m.get("role") == "assistant"
|
||||
for tc in (m.get("tool_calls") or [])
|
||||
}
|
||||
orphans = [
|
||||
m.get("tool_call_id") for m in history
|
||||
if m.get("role") == "tool" and m.get("tool_call_id") not in declared
|
||||
]
|
||||
assert orphans == [], f"orphan tool_call_ids: {orphans}"
|
||||
|
||||
|
||||
def _tool_turn(prefix: str, idx: int) -> list[dict]:
|
||||
"""Helper: one assistant with 2 tool_calls + 2 tool results."""
|
||||
return [
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"tool_calls": [
|
||||
{"id": f"{prefix}_{idx}_a", "type": "function", "function": {"name": "x", "arguments": "{}"}},
|
||||
{"id": f"{prefix}_{idx}_b", "type": "function", "function": {"name": "y", "arguments": "{}"}},
|
||||
],
|
||||
},
|
||||
{"role": "tool", "tool_call_id": f"{prefix}_{idx}_a", "name": "x", "content": "ok"},
|
||||
{"role": "tool", "tool_call_id": f"{prefix}_{idx}_b", "name": "y", "content": "ok"},
|
||||
]
|
||||
|
||||
|
||||
# --- Original regression test (from PR 2075) ---
|
||||
|
||||
def test_get_history_drops_orphan_tool_results_when_window_cuts_tool_calls():
|
||||
session = Session(key="telegram:test")
|
||||
session.messages.append({"role": "user", "content": "old turn"})
|
||||
for i in range(20):
|
||||
session.messages.extend(_tool_turn("old", i))
|
||||
session.messages.append({"role": "user", "content": "problem turn"})
|
||||
for i in range(25):
|
||||
session.messages.extend(_tool_turn("cur", i))
|
||||
session.messages.append({"role": "user", "content": "new telegram question"})
|
||||
|
||||
history = session.get_history(max_messages=100)
|
||||
_assert_no_orphans(history)
|
||||
|
||||
|
||||
# --- Positive test: legitimate pairs survive trimming ---
|
||||
|
||||
def test_legitimate_tool_pairs_preserved_after_trim():
|
||||
"""Complete tool-call groups within the window must not be dropped."""
|
||||
session = Session(key="test:positive")
|
||||
session.messages.append({"role": "user", "content": "hello"})
|
||||
for i in range(5):
|
||||
session.messages.extend(_tool_turn("ok", i))
|
||||
session.messages.append({"role": "assistant", "content": "done"})
|
||||
|
||||
history = session.get_history(max_messages=500)
|
||||
_assert_no_orphans(history)
|
||||
tool_ids = [m["tool_call_id"] for m in history if m.get("role") == "tool"]
|
||||
assert len(tool_ids) == 10
|
||||
assert history[0]["role"] == "user"
|
||||
|
||||
|
||||
def test_retain_recent_legal_suffix_keeps_recent_messages():
|
||||
session = Session(key="test:trim")
|
||||
for i in range(10):
|
||||
session.messages.append({"role": "user", "content": f"msg{i}"})
|
||||
|
||||
session.retain_recent_legal_suffix(4)
|
||||
|
||||
assert len(session.messages) == 4
|
||||
assert session.messages[0]["content"] == "msg6"
|
||||
assert session.messages[-1]["content"] == "msg9"
|
||||
|
||||
|
||||
def test_retain_recent_legal_suffix_adjusts_last_consolidated():
|
||||
session = Session(key="test:trim-cons")
|
||||
for i in range(10):
|
||||
session.messages.append({"role": "user", "content": f"msg{i}"})
|
||||
session.last_consolidated = 7
|
||||
|
||||
session.retain_recent_legal_suffix(4)
|
||||
|
||||
assert len(session.messages) == 4
|
||||
assert session.last_consolidated == 1
|
||||
|
||||
|
||||
def test_retain_recent_legal_suffix_zero_clears_session():
|
||||
session = Session(key="test:trim-zero")
|
||||
for i in range(10):
|
||||
session.messages.append({"role": "user", "content": f"msg{i}"})
|
||||
session.last_consolidated = 5
|
||||
|
||||
session.retain_recent_legal_suffix(0)
|
||||
|
||||
assert session.messages == []
|
||||
assert session.last_consolidated == 0
|
||||
|
||||
|
||||
def test_retain_recent_legal_suffix_keeps_legal_tool_boundary():
|
||||
session = Session(key="test:trim-tools")
|
||||
session.messages.append({"role": "user", "content": "old"})
|
||||
session.messages.extend(_tool_turn("old", 0))
|
||||
session.messages.append({"role": "user", "content": "keep"})
|
||||
session.messages.extend(_tool_turn("keep", 0))
|
||||
session.messages.append({"role": "assistant", "content": "done"})
|
||||
|
||||
session.retain_recent_legal_suffix(4)
|
||||
|
||||
history = session.get_history(max_messages=500)
|
||||
_assert_no_orphans(history)
|
||||
assert history[0]["role"] == "user"
|
||||
assert history[0]["content"] == "keep"
|
||||
|
||||
|
||||
# --- last_consolidated > 0 ---
|
||||
|
||||
def test_orphan_trim_with_last_consolidated():
|
||||
"""Orphan trimming works correctly when session is partially consolidated."""
|
||||
session = Session(key="test:consolidated")
|
||||
for i in range(10):
|
||||
session.messages.append({"role": "user", "content": f"old {i}"})
|
||||
session.messages.extend(_tool_turn("cons", i))
|
||||
session.last_consolidated = 30
|
||||
|
||||
session.messages.append({"role": "user", "content": "recent"})
|
||||
for i in range(15):
|
||||
session.messages.extend(_tool_turn("new", i))
|
||||
session.messages.append({"role": "user", "content": "latest"})
|
||||
|
||||
history = session.get_history(max_messages=20)
|
||||
_assert_no_orphans(history)
|
||||
assert all(m.get("role") != "tool" or m["tool_call_id"].startswith("new_") for m in history)
|
||||
|
||||
|
||||
# --- Edge: no tool messages at all ---
|
||||
|
||||
def test_no_tool_messages_unchanged():
|
||||
session = Session(key="test:plain")
|
||||
for i in range(5):
|
||||
session.messages.append({"role": "user", "content": f"q{i}"})
|
||||
session.messages.append({"role": "assistant", "content": f"a{i}"})
|
||||
|
||||
history = session.get_history(max_messages=6)
|
||||
assert len(history) == 6
|
||||
_assert_no_orphans(history)
|
||||
|
||||
|
||||
# --- Edge: all leading messages are orphan tool results ---
|
||||
|
||||
def test_all_orphan_prefix_stripped():
|
||||
"""If the window starts with orphan tool results and nothing else, they're all dropped."""
|
||||
session = Session(key="test:all-orphan")
|
||||
session.messages.append({"role": "tool", "tool_call_id": "gone_1", "name": "x", "content": "ok"})
|
||||
session.messages.append({"role": "tool", "tool_call_id": "gone_2", "name": "y", "content": "ok"})
|
||||
session.messages.append({"role": "user", "content": "fresh start"})
|
||||
session.messages.append({"role": "assistant", "content": "hi"})
|
||||
|
||||
history = session.get_history(max_messages=500)
|
||||
_assert_no_orphans(history)
|
||||
assert history[0]["role"] == "user"
|
||||
assert len(history) == 2
|
||||
|
||||
|
||||
# --- Edge: empty session ---
|
||||
|
||||
def test_empty_session_history():
|
||||
session = Session(key="test:empty")
|
||||
history = session.get_history(max_messages=500)
|
||||
assert history == []
|
||||
|
||||
|
||||
# --- Window cuts mid-group: assistant present but some tool results orphaned ---
|
||||
|
||||
def test_window_cuts_mid_tool_group():
|
||||
"""If the window starts between an assistant's tool results, the partial group is trimmed."""
|
||||
session = Session(key="test:mid-cut")
|
||||
session.messages.append({"role": "user", "content": "setup"})
|
||||
session.messages.append({
|
||||
"role": "assistant", "content": None,
|
||||
"tool_calls": [
|
||||
{"id": "split_a", "type": "function", "function": {"name": "x", "arguments": "{}"}},
|
||||
{"id": "split_b", "type": "function", "function": {"name": "y", "arguments": "{}"}},
|
||||
],
|
||||
})
|
||||
session.messages.append({"role": "tool", "tool_call_id": "split_a", "name": "x", "content": "ok"})
|
||||
session.messages.append({"role": "tool", "tool_call_id": "split_b", "name": "y", "content": "ok"})
|
||||
session.messages.append({"role": "user", "content": "next"})
|
||||
session.messages.extend(_tool_turn("intact", 0))
|
||||
session.messages.append({"role": "assistant", "content": "final"})
|
||||
|
||||
# Window of 6 should cut off the "setup" user msg and the assistant with split_a/split_b,
|
||||
# leaving orphan tool results for split_a at the front.
|
||||
history = session.get_history(max_messages=6)
|
||||
_assert_no_orphans(history)
|
||||
@@ -0,0 +1,127 @@
|
||||
import importlib
|
||||
import shutil
|
||||
import sys
|
||||
import zipfile
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
SCRIPT_DIR = Path("nanobot/skills/skill-creator/scripts").resolve()
|
||||
if str(SCRIPT_DIR) not in sys.path:
|
||||
sys.path.insert(0, str(SCRIPT_DIR))
|
||||
|
||||
init_skill = importlib.import_module("init_skill")
|
||||
package_skill = importlib.import_module("package_skill")
|
||||
quick_validate = importlib.import_module("quick_validate")
|
||||
|
||||
|
||||
def test_init_skill_creates_expected_files(tmp_path: Path) -> None:
|
||||
skill_dir = init_skill.init_skill(
|
||||
"demo-skill",
|
||||
tmp_path,
|
||||
["scripts", "references", "assets"],
|
||||
include_examples=True,
|
||||
)
|
||||
|
||||
assert skill_dir == tmp_path / "demo-skill"
|
||||
assert (skill_dir / "SKILL.md").exists()
|
||||
assert (skill_dir / "scripts" / "example.py").exists()
|
||||
assert (skill_dir / "references" / "api_reference.md").exists()
|
||||
assert (skill_dir / "assets" / "example_asset.txt").exists()
|
||||
|
||||
|
||||
def test_validate_skill_accepts_existing_skill_creator() -> None:
|
||||
valid, message = quick_validate.validate_skill(
|
||||
Path("nanobot/skills/skill-creator").resolve()
|
||||
)
|
||||
|
||||
assert valid, message
|
||||
|
||||
|
||||
def test_validate_skill_rejects_placeholder_description(tmp_path: Path) -> None:
|
||||
skill_dir = tmp_path / "placeholder-skill"
|
||||
skill_dir.mkdir()
|
||||
(skill_dir / "SKILL.md").write_text(
|
||||
"---\n"
|
||||
"name: placeholder-skill\n"
|
||||
'description: "[TODO: fill me in]"\n'
|
||||
"---\n"
|
||||
"# Placeholder\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
valid, message = quick_validate.validate_skill(skill_dir)
|
||||
|
||||
assert not valid
|
||||
assert "TODO placeholder" in message
|
||||
|
||||
|
||||
def test_validate_skill_rejects_root_files_outside_allowed_dirs(tmp_path: Path) -> None:
|
||||
skill_dir = tmp_path / "bad-root-skill"
|
||||
skill_dir.mkdir()
|
||||
(skill_dir / "SKILL.md").write_text(
|
||||
"---\n"
|
||||
"name: bad-root-skill\n"
|
||||
"description: Valid description\n"
|
||||
"---\n"
|
||||
"# Skill\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
(skill_dir / "README.md").write_text("extra\n", encoding="utf-8")
|
||||
|
||||
valid, message = quick_validate.validate_skill(skill_dir)
|
||||
|
||||
assert not valid
|
||||
assert "Unexpected file or directory in skill root" in message
|
||||
|
||||
|
||||
def test_package_skill_creates_archive(tmp_path: Path) -> None:
|
||||
skill_dir = tmp_path / "package-me"
|
||||
skill_dir.mkdir()
|
||||
(skill_dir / "SKILL.md").write_text(
|
||||
"---\n"
|
||||
"name: package-me\n"
|
||||
"description: Package this skill.\n"
|
||||
"---\n"
|
||||
"# Skill\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
scripts_dir = skill_dir / "scripts"
|
||||
scripts_dir.mkdir()
|
||||
(scripts_dir / "helper.py").write_text("print('ok')\n", encoding="utf-8")
|
||||
|
||||
archive_path = package_skill.package_skill(skill_dir, tmp_path / "dist")
|
||||
|
||||
assert archive_path == (tmp_path / "dist" / "package-me.skill")
|
||||
assert archive_path.exists()
|
||||
with zipfile.ZipFile(archive_path, "r") as archive:
|
||||
names = set(archive.namelist())
|
||||
assert "package-me/SKILL.md" in names
|
||||
assert "package-me/scripts/helper.py" in names
|
||||
|
||||
|
||||
def test_package_skill_rejects_symlink(tmp_path: Path) -> None:
|
||||
skill_dir = tmp_path / "symlink-skill"
|
||||
skill_dir.mkdir()
|
||||
(skill_dir / "SKILL.md").write_text(
|
||||
"---\n"
|
||||
"name: symlink-skill\n"
|
||||
"description: Reject symlinks during packaging.\n"
|
||||
"---\n"
|
||||
"# Skill\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
scripts_dir = skill_dir / "scripts"
|
||||
scripts_dir.mkdir()
|
||||
target = tmp_path / "outside.txt"
|
||||
target.write_text("secret\n", encoding="utf-8")
|
||||
link = scripts_dir / "outside.txt"
|
||||
|
||||
try:
|
||||
link.symlink_to(target)
|
||||
except (OSError, NotImplementedError):
|
||||
return
|
||||
|
||||
archive_path = package_skill.package_skill(skill_dir, tmp_path / "dist")
|
||||
|
||||
assert archive_path is None
|
||||
assert not (tmp_path / "dist" / "symlink-skill.skill").exists()
|
||||
@@ -0,0 +1,303 @@
|
||||
"""Tests for /stop task cancellation."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
def _make_loop(*, exec_config=None):
|
||||
"""Create a minimal AgentLoop with mocked dependencies."""
|
||||
from nanobot.agent.loop import AgentLoop
|
||||
from nanobot.bus.queue import MessageBus
|
||||
|
||||
bus = MessageBus()
|
||||
provider = MagicMock()
|
||||
provider.get_default_model.return_value = "test-model"
|
||||
workspace = MagicMock()
|
||||
workspace.__truediv__ = MagicMock(return_value=MagicMock())
|
||||
|
||||
with patch("nanobot.agent.loop.ContextBuilder"), \
|
||||
patch("nanobot.agent.loop.SessionManager"), \
|
||||
patch("nanobot.agent.loop.SubagentManager") as MockSubMgr:
|
||||
MockSubMgr.return_value.cancel_by_session = AsyncMock(return_value=0)
|
||||
loop = AgentLoop(bus=bus, provider=provider, workspace=workspace, exec_config=exec_config)
|
||||
return loop, bus
|
||||
|
||||
|
||||
class TestHandleStop:
|
||||
@pytest.mark.asyncio
|
||||
async def test_stop_no_active_task(self):
|
||||
from nanobot.bus.events import InboundMessage
|
||||
from nanobot.command.builtin import cmd_stop
|
||||
from nanobot.command.router import CommandContext
|
||||
|
||||
loop, bus = _make_loop()
|
||||
msg = InboundMessage(channel="test", sender_id="u1", chat_id="c1", content="/stop")
|
||||
ctx = CommandContext(msg=msg, session=None, key=msg.session_key, raw="/stop", loop=loop)
|
||||
out = await cmd_stop(ctx)
|
||||
assert "No active task" in out.content
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stop_cancels_active_task(self):
|
||||
from nanobot.bus.events import InboundMessage
|
||||
from nanobot.command.builtin import cmd_stop
|
||||
from nanobot.command.router import CommandContext
|
||||
|
||||
loop, bus = _make_loop()
|
||||
cancelled = asyncio.Event()
|
||||
|
||||
async def slow_task():
|
||||
try:
|
||||
await asyncio.sleep(60)
|
||||
except asyncio.CancelledError:
|
||||
cancelled.set()
|
||||
raise
|
||||
|
||||
task = asyncio.create_task(slow_task())
|
||||
await asyncio.sleep(0)
|
||||
loop._active_tasks["test:c1"] = [task]
|
||||
|
||||
msg = InboundMessage(channel="test", sender_id="u1", chat_id="c1", content="/stop")
|
||||
ctx = CommandContext(msg=msg, session=None, key=msg.session_key, raw="/stop", loop=loop)
|
||||
out = await cmd_stop(ctx)
|
||||
|
||||
assert cancelled.is_set()
|
||||
assert "stopped" in out.content.lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stop_cancels_multiple_tasks(self):
|
||||
from nanobot.bus.events import InboundMessage
|
||||
from nanobot.command.builtin import cmd_stop
|
||||
from nanobot.command.router import CommandContext
|
||||
|
||||
loop, bus = _make_loop()
|
||||
events = [asyncio.Event(), asyncio.Event()]
|
||||
|
||||
async def slow(idx):
|
||||
try:
|
||||
await asyncio.sleep(60)
|
||||
except asyncio.CancelledError:
|
||||
events[idx].set()
|
||||
raise
|
||||
|
||||
tasks = [asyncio.create_task(slow(i)) for i in range(2)]
|
||||
await asyncio.sleep(0)
|
||||
loop._active_tasks["test:c1"] = tasks
|
||||
|
||||
msg = InboundMessage(channel="test", sender_id="u1", chat_id="c1", content="/stop")
|
||||
ctx = CommandContext(msg=msg, session=None, key=msg.session_key, raw="/stop", loop=loop)
|
||||
out = await cmd_stop(ctx)
|
||||
|
||||
assert all(e.is_set() for e in events)
|
||||
assert "2 task" in out.content
|
||||
|
||||
|
||||
class TestDispatch:
|
||||
def test_exec_tool_not_registered_when_disabled(self):
|
||||
from nanobot.config.schema import ExecToolConfig
|
||||
|
||||
loop, _bus = _make_loop(exec_config=ExecToolConfig(enable=False))
|
||||
|
||||
assert loop.tools.get("exec") is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dispatch_processes_and_publishes(self):
|
||||
from nanobot.bus.events import InboundMessage, OutboundMessage
|
||||
|
||||
loop, bus = _make_loop()
|
||||
msg = InboundMessage(channel="test", sender_id="u1", chat_id="c1", content="hello")
|
||||
loop._process_message = AsyncMock(
|
||||
return_value=OutboundMessage(channel="test", chat_id="c1", content="hi")
|
||||
)
|
||||
await loop._dispatch(msg)
|
||||
out = await asyncio.wait_for(bus.consume_outbound(), timeout=1.0)
|
||||
assert out.content == "hi"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_processing_lock_serializes(self):
|
||||
from nanobot.bus.events import InboundMessage, OutboundMessage
|
||||
|
||||
loop, bus = _make_loop()
|
||||
order = []
|
||||
|
||||
async def mock_process(m, **kwargs):
|
||||
order.append(f"start-{m.content}")
|
||||
await asyncio.sleep(0.05)
|
||||
order.append(f"end-{m.content}")
|
||||
return OutboundMessage(channel="test", chat_id="c1", content=m.content)
|
||||
|
||||
loop._process_message = mock_process
|
||||
msg1 = InboundMessage(channel="test", sender_id="u1", chat_id="c1", content="a")
|
||||
msg2 = InboundMessage(channel="test", sender_id="u1", chat_id="c1", content="b")
|
||||
|
||||
t1 = asyncio.create_task(loop._dispatch(msg1))
|
||||
t2 = asyncio.create_task(loop._dispatch(msg2))
|
||||
await asyncio.gather(t1, t2)
|
||||
assert order == ["start-a", "end-a", "start-b", "end-b"]
|
||||
|
||||
|
||||
class TestSubagentCancellation:
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancel_by_session(self):
|
||||
from nanobot.agent.subagent import SubagentManager
|
||||
from nanobot.bus.queue import MessageBus
|
||||
|
||||
bus = MessageBus()
|
||||
provider = MagicMock()
|
||||
provider.get_default_model.return_value = "test-model"
|
||||
mgr = SubagentManager(provider=provider, workspace=MagicMock(), bus=bus)
|
||||
|
||||
cancelled = asyncio.Event()
|
||||
|
||||
async def slow():
|
||||
try:
|
||||
await asyncio.sleep(60)
|
||||
except asyncio.CancelledError:
|
||||
cancelled.set()
|
||||
raise
|
||||
|
||||
task = asyncio.create_task(slow())
|
||||
await asyncio.sleep(0)
|
||||
mgr._running_tasks["sub-1"] = task
|
||||
mgr._session_tasks["test:c1"] = {"sub-1"}
|
||||
|
||||
count = await mgr.cancel_by_session("test:c1")
|
||||
assert count == 1
|
||||
assert cancelled.is_set()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancel_by_session_no_tasks(self):
|
||||
from nanobot.agent.subagent import SubagentManager
|
||||
from nanobot.bus.queue import MessageBus
|
||||
|
||||
bus = MessageBus()
|
||||
provider = MagicMock()
|
||||
provider.get_default_model.return_value = "test-model"
|
||||
mgr = SubagentManager(provider=provider, workspace=MagicMock(), bus=bus)
|
||||
assert await mgr.cancel_by_session("nonexistent") == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_subagent_preserves_reasoning_fields_in_tool_turn(self, monkeypatch, tmp_path):
|
||||
from nanobot.agent.subagent import SubagentManager
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.providers.base import LLMResponse, ToolCallRequest
|
||||
|
||||
bus = MessageBus()
|
||||
provider = MagicMock()
|
||||
provider.get_default_model.return_value = "test-model"
|
||||
|
||||
captured_second_call: list[dict] = []
|
||||
|
||||
call_count = {"n": 0}
|
||||
|
||||
async def scripted_chat_with_retry(*, messages, **kwargs):
|
||||
call_count["n"] += 1
|
||||
if call_count["n"] == 1:
|
||||
return LLMResponse(
|
||||
content="thinking",
|
||||
tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={})],
|
||||
reasoning_content="hidden reasoning",
|
||||
thinking_blocks=[{"type": "thinking", "thinking": "step"}],
|
||||
)
|
||||
captured_second_call[:] = messages
|
||||
return LLMResponse(content="done", tool_calls=[])
|
||||
provider.chat_with_retry = scripted_chat_with_retry
|
||||
mgr = SubagentManager(provider=provider, workspace=tmp_path, bus=bus)
|
||||
|
||||
async def fake_execute(self, name, arguments):
|
||||
return "tool result"
|
||||
|
||||
monkeypatch.setattr("nanobot.agent.tools.registry.ToolRegistry.execute", fake_execute)
|
||||
|
||||
await mgr._run_subagent("sub-1", "do task", "label", {"channel": "test", "chat_id": "c1"})
|
||||
|
||||
assistant_messages = [
|
||||
msg for msg in captured_second_call
|
||||
if msg.get("role") == "assistant" and msg.get("tool_calls")
|
||||
]
|
||||
assert len(assistant_messages) == 1
|
||||
assert assistant_messages[0]["reasoning_content"] == "hidden reasoning"
|
||||
assert assistant_messages[0]["thinking_blocks"] == [{"type": "thinking", "thinking": "step"}]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_subagent_announces_error_when_tool_execution_fails(self, monkeypatch, tmp_path):
|
||||
from nanobot.agent.subagent import SubagentManager
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.providers.base import LLMResponse, ToolCallRequest
|
||||
|
||||
bus = MessageBus()
|
||||
provider = MagicMock()
|
||||
provider.get_default_model.return_value = "test-model"
|
||||
provider.chat_with_retry = AsyncMock(return_value=LLMResponse(
|
||||
content="thinking",
|
||||
tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={})],
|
||||
))
|
||||
mgr = SubagentManager(provider=provider, workspace=tmp_path, bus=bus)
|
||||
mgr._announce_result = AsyncMock()
|
||||
|
||||
calls = {"n": 0}
|
||||
|
||||
async def fake_execute(self, name, arguments):
|
||||
calls["n"] += 1
|
||||
if calls["n"] == 1:
|
||||
return "first result"
|
||||
raise RuntimeError("boom")
|
||||
|
||||
monkeypatch.setattr("nanobot.agent.tools.registry.ToolRegistry.execute", fake_execute)
|
||||
|
||||
await mgr._run_subagent("sub-1", "do task", "label", {"channel": "test", "chat_id": "c1"})
|
||||
|
||||
mgr._announce_result.assert_awaited_once()
|
||||
args = mgr._announce_result.await_args.args
|
||||
assert "Completed steps:" in args[3]
|
||||
assert "- list_dir: first result" in args[3]
|
||||
assert "Failure:" in args[3]
|
||||
assert "- list_dir: boom" in args[3]
|
||||
assert args[5] == "error"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancel_by_session_cancels_running_subagent_tool(self, monkeypatch, tmp_path):
|
||||
from nanobot.agent.subagent import SubagentManager
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.providers.base import LLMResponse, ToolCallRequest
|
||||
|
||||
bus = MessageBus()
|
||||
provider = MagicMock()
|
||||
provider.get_default_model.return_value = "test-model"
|
||||
provider.chat_with_retry = AsyncMock(return_value=LLMResponse(
|
||||
content="thinking",
|
||||
tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={})],
|
||||
))
|
||||
mgr = SubagentManager(provider=provider, workspace=tmp_path, bus=bus)
|
||||
mgr._announce_result = AsyncMock()
|
||||
|
||||
started = asyncio.Event()
|
||||
cancelled = asyncio.Event()
|
||||
|
||||
async def fake_execute(self, name, arguments):
|
||||
started.set()
|
||||
try:
|
||||
await asyncio.sleep(60)
|
||||
except asyncio.CancelledError:
|
||||
cancelled.set()
|
||||
raise
|
||||
|
||||
monkeypatch.setattr("nanobot.agent.tools.registry.ToolRegistry.execute", fake_execute)
|
||||
|
||||
task = asyncio.create_task(
|
||||
mgr._run_subagent("sub-1", "do task", "label", {"channel": "test", "chat_id": "c1"})
|
||||
)
|
||||
mgr._running_tasks["sub-1"] = task
|
||||
mgr._session_tasks["test:c1"] = {"sub-1"}
|
||||
|
||||
await started.wait()
|
||||
|
||||
count = await mgr.cancel_by_session("test:c1")
|
||||
|
||||
assert count == 1
|
||||
assert cancelled.is_set()
|
||||
assert task.cancelled()
|
||||
mgr._announce_result.assert_not_awaited()
|
||||
@@ -0,0 +1,25 @@
|
||||
from types import SimpleNamespace
|
||||
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.channels.base import BaseChannel
|
||||
|
||||
|
||||
class _DummyChannel(BaseChannel):
|
||||
name = "dummy"
|
||||
|
||||
async def start(self) -> None:
|
||||
return None
|
||||
|
||||
async def stop(self) -> None:
|
||||
return None
|
||||
|
||||
async def send(self, msg: OutboundMessage) -> None:
|
||||
return None
|
||||
|
||||
|
||||
def test_is_allowed_requires_exact_match() -> None:
|
||||
channel = _DummyChannel(SimpleNamespace(allow_from=["allow@email.com"]), MessageBus())
|
||||
|
||||
assert channel.is_allowed("allow@email.com") is True
|
||||
assert channel.is_allowed("attacker|allow@email.com") is False
|
||||
@@ -0,0 +1,298 @@
|
||||
"""Tests for ChannelManager delta coalescing to reduce streaming latency."""
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.channels.base import BaseChannel
|
||||
from nanobot.channels.manager import ChannelManager
|
||||
from nanobot.config.schema import Config
|
||||
|
||||
|
||||
class MockChannel(BaseChannel):
|
||||
"""Mock channel for testing."""
|
||||
|
||||
name = "mock"
|
||||
display_name = "Mock"
|
||||
|
||||
def __init__(self, config, bus):
|
||||
super().__init__(config, bus)
|
||||
self._send_delta_mock = AsyncMock()
|
||||
self._send_mock = AsyncMock()
|
||||
|
||||
async def start(self):
|
||||
pass
|
||||
|
||||
async def stop(self):
|
||||
pass
|
||||
|
||||
async def send(self, msg):
|
||||
"""Implement abstract method."""
|
||||
return await self._send_mock(msg)
|
||||
|
||||
async def send_delta(self, chat_id, delta, metadata=None):
|
||||
"""Override send_delta for testing."""
|
||||
return await self._send_delta_mock(chat_id, delta, metadata)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def config():
|
||||
"""Create a minimal config for testing."""
|
||||
return Config()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def bus():
|
||||
"""Create a message bus for testing."""
|
||||
return MessageBus()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def manager(config, bus):
|
||||
"""Create a channel manager with a mock channel."""
|
||||
manager = ChannelManager(config, bus)
|
||||
manager.channels["mock"] = MockChannel({}, bus)
|
||||
return manager
|
||||
|
||||
|
||||
class TestDeltaCoalescing:
|
||||
"""Tests for _stream_delta message coalescing."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_single_delta_not_coalesced(self, manager, bus):
|
||||
"""A single delta should be sent as-is."""
|
||||
msg = OutboundMessage(
|
||||
channel="mock",
|
||||
chat_id="chat1",
|
||||
content="Hello",
|
||||
metadata={"_stream_delta": True},
|
||||
)
|
||||
await bus.publish_outbound(msg)
|
||||
|
||||
# Process one message
|
||||
async def process_one():
|
||||
try:
|
||||
m = await asyncio.wait_for(bus.consume_outbound(), timeout=0.1)
|
||||
if m.metadata.get("_stream_delta"):
|
||||
m, pending = manager._coalesce_stream_deltas(m)
|
||||
# Put pending back (none expected)
|
||||
for p in pending:
|
||||
await bus.publish_outbound(p)
|
||||
channel = manager.channels.get(m.channel)
|
||||
if channel:
|
||||
await channel.send_delta(m.chat_id, m.content, m.metadata)
|
||||
except asyncio.TimeoutError:
|
||||
pass
|
||||
|
||||
await process_one()
|
||||
|
||||
manager.channels["mock"]._send_delta_mock.assert_called_once_with(
|
||||
"chat1", "Hello", {"_stream_delta": True}
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_deltas_coalesced(self, manager, bus):
|
||||
"""Multiple consecutive deltas for same chat should be merged."""
|
||||
# Put multiple deltas in queue
|
||||
for text in ["Hello", " ", "world", "!"]:
|
||||
await bus.publish_outbound(OutboundMessage(
|
||||
channel="mock",
|
||||
chat_id="chat1",
|
||||
content=text,
|
||||
metadata={"_stream_delta": True},
|
||||
))
|
||||
|
||||
# Process using coalescing logic
|
||||
first_msg = await bus.consume_outbound()
|
||||
merged, pending = manager._coalesce_stream_deltas(first_msg)
|
||||
|
||||
# Should have merged all deltas
|
||||
assert merged.content == "Hello world!"
|
||||
assert merged.metadata.get("_stream_delta") is True
|
||||
# No pending messages (all were coalesced)
|
||||
assert len(pending) == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_deltas_different_chats_not_coalesced(self, manager, bus):
|
||||
"""Deltas for different chats should not be merged."""
|
||||
# Put deltas for different chats
|
||||
await bus.publish_outbound(OutboundMessage(
|
||||
channel="mock",
|
||||
chat_id="chat1",
|
||||
content="Hello",
|
||||
metadata={"_stream_delta": True},
|
||||
))
|
||||
await bus.publish_outbound(OutboundMessage(
|
||||
channel="mock",
|
||||
chat_id="chat2",
|
||||
content="World",
|
||||
metadata={"_stream_delta": True},
|
||||
))
|
||||
|
||||
first_msg = await bus.consume_outbound()
|
||||
merged, pending = manager._coalesce_stream_deltas(first_msg)
|
||||
|
||||
# First chat should not include second chat's content
|
||||
assert merged.content == "Hello"
|
||||
assert merged.chat_id == "chat1"
|
||||
# Second chat should be in pending
|
||||
assert len(pending) == 1
|
||||
assert pending[0].chat_id == "chat2"
|
||||
assert pending[0].content == "World"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_end_terminates_coalescing(self, manager, bus):
|
||||
"""_stream_end should stop coalescing and be included in final message."""
|
||||
# Put deltas with stream_end at the end
|
||||
await bus.publish_outbound(OutboundMessage(
|
||||
channel="mock",
|
||||
chat_id="chat1",
|
||||
content="Hello",
|
||||
metadata={"_stream_delta": True},
|
||||
))
|
||||
await bus.publish_outbound(OutboundMessage(
|
||||
channel="mock",
|
||||
chat_id="chat1",
|
||||
content=" world",
|
||||
metadata={"_stream_delta": True, "_stream_end": True},
|
||||
))
|
||||
|
||||
first_msg = await bus.consume_outbound()
|
||||
merged, pending = manager._coalesce_stream_deltas(first_msg)
|
||||
|
||||
# Should have merged content
|
||||
assert merged.content == "Hello world"
|
||||
# Should have stream_end flag
|
||||
assert merged.metadata.get("_stream_end") is True
|
||||
# No pending
|
||||
assert len(pending) == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_coalescing_stops_at_first_non_matching_boundary(self, manager, bus):
|
||||
"""Only consecutive deltas should be merged; later deltas stay queued."""
|
||||
await bus.publish_outbound(OutboundMessage(
|
||||
channel="mock",
|
||||
chat_id="chat1",
|
||||
content="Hello",
|
||||
metadata={"_stream_delta": True, "_stream_id": "seg-1"},
|
||||
))
|
||||
await bus.publish_outbound(OutboundMessage(
|
||||
channel="mock",
|
||||
chat_id="chat1",
|
||||
content="",
|
||||
metadata={"_stream_end": True, "_stream_id": "seg-1"},
|
||||
))
|
||||
await bus.publish_outbound(OutboundMessage(
|
||||
channel="mock",
|
||||
chat_id="chat1",
|
||||
content="world",
|
||||
metadata={"_stream_delta": True, "_stream_id": "seg-2"},
|
||||
))
|
||||
|
||||
first_msg = await bus.consume_outbound()
|
||||
merged, pending = manager._coalesce_stream_deltas(first_msg)
|
||||
|
||||
assert merged.content == "Hello"
|
||||
assert merged.metadata.get("_stream_end") is None
|
||||
assert len(pending) == 1
|
||||
assert pending[0].metadata.get("_stream_end") is True
|
||||
assert pending[0].metadata.get("_stream_id") == "seg-1"
|
||||
|
||||
# The next stream segment must remain in queue order for later dispatch.
|
||||
remaining = await bus.consume_outbound()
|
||||
assert remaining.content == "world"
|
||||
assert remaining.metadata.get("_stream_id") == "seg-2"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_delta_message_preserved(self, manager, bus):
|
||||
"""Non-delta messages should be preserved in pending list."""
|
||||
await bus.publish_outbound(OutboundMessage(
|
||||
channel="mock",
|
||||
chat_id="chat1",
|
||||
content="Delta",
|
||||
metadata={"_stream_delta": True},
|
||||
))
|
||||
await bus.publish_outbound(OutboundMessage(
|
||||
channel="mock",
|
||||
chat_id="chat1",
|
||||
content="Final message",
|
||||
metadata={}, # Not a delta
|
||||
))
|
||||
|
||||
first_msg = await bus.consume_outbound()
|
||||
merged, pending = manager._coalesce_stream_deltas(first_msg)
|
||||
|
||||
assert merged.content == "Delta"
|
||||
assert len(pending) == 1
|
||||
assert pending[0].content == "Final message"
|
||||
assert pending[0].metadata.get("_stream_delta") is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_queue_stops_coalescing(self, manager, bus):
|
||||
"""Coalescing should stop when queue is empty."""
|
||||
await bus.publish_outbound(OutboundMessage(
|
||||
channel="mock",
|
||||
chat_id="chat1",
|
||||
content="Only message",
|
||||
metadata={"_stream_delta": True},
|
||||
))
|
||||
|
||||
first_msg = await bus.consume_outbound()
|
||||
merged, pending = manager._coalesce_stream_deltas(first_msg)
|
||||
|
||||
assert merged.content == "Only message"
|
||||
assert len(pending) == 0
|
||||
|
||||
|
||||
class TestDispatchOutboundWithCoalescing:
|
||||
"""Tests for the full _dispatch_outbound flow with coalescing."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dispatch_coalesces_and_processes_pending(self, manager, bus):
|
||||
"""_dispatch_outbound should coalesce deltas and process pending messages."""
|
||||
# Put multiple deltas followed by a regular message
|
||||
await bus.publish_outbound(OutboundMessage(
|
||||
channel="mock",
|
||||
chat_id="chat1",
|
||||
content="A",
|
||||
metadata={"_stream_delta": True},
|
||||
))
|
||||
await bus.publish_outbound(OutboundMessage(
|
||||
channel="mock",
|
||||
chat_id="chat1",
|
||||
content="B",
|
||||
metadata={"_stream_delta": True},
|
||||
))
|
||||
await bus.publish_outbound(OutboundMessage(
|
||||
channel="mock",
|
||||
chat_id="chat1",
|
||||
content="Final",
|
||||
metadata={}, # Regular message
|
||||
))
|
||||
|
||||
# Run one iteration of dispatch logic manually
|
||||
pending = []
|
||||
processed = []
|
||||
|
||||
# First iteration: should coalesce A+B
|
||||
if pending:
|
||||
msg = pending.pop(0)
|
||||
else:
|
||||
msg = await bus.consume_outbound()
|
||||
|
||||
if msg.metadata.get("_stream_delta") and not msg.metadata.get("_stream_end"):
|
||||
msg, extra_pending = manager._coalesce_stream_deltas(msg)
|
||||
pending.extend(extra_pending)
|
||||
|
||||
channel = manager.channels.get(msg.channel)
|
||||
if channel:
|
||||
await channel.send_delta(msg.chat_id, msg.content, msg.metadata)
|
||||
processed.append(("delta", msg.content))
|
||||
|
||||
# Should have sent coalesced delta
|
||||
assert processed == [("delta", "AB")]
|
||||
# Should have pending regular message
|
||||
assert len(pending) == 1
|
||||
assert pending[0].content == "Final"
|
||||
@@ -0,0 +1,880 @@
|
||||
"""Tests for channel plugin discovery, merging, and config compatibility."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.channels.base import BaseChannel
|
||||
from nanobot.channels.manager import ChannelManager
|
||||
from nanobot.config.schema import ChannelsConfig
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class _FakePlugin(BaseChannel):
|
||||
name = "fakeplugin"
|
||||
display_name = "Fake Plugin"
|
||||
|
||||
def __init__(self, config, bus):
|
||||
super().__init__(config, bus)
|
||||
self.login_calls: list[bool] = []
|
||||
|
||||
async def start(self) -> None:
|
||||
pass
|
||||
|
||||
async def stop(self) -> None:
|
||||
pass
|
||||
|
||||
async def send(self, msg: OutboundMessage) -> None:
|
||||
pass
|
||||
|
||||
async def login(self, force: bool = False) -> bool:
|
||||
self.login_calls.append(force)
|
||||
return True
|
||||
|
||||
|
||||
class _FakeTelegram(BaseChannel):
|
||||
"""Plugin that tries to shadow built-in telegram."""
|
||||
name = "telegram"
|
||||
display_name = "Fake Telegram"
|
||||
|
||||
async def start(self) -> None:
|
||||
pass
|
||||
|
||||
async def stop(self) -> None:
|
||||
pass
|
||||
|
||||
async def send(self, msg: OutboundMessage) -> None:
|
||||
pass
|
||||
|
||||
|
||||
def _make_entry_point(name: str, cls: type):
|
||||
"""Create a mock entry point that returns *cls* on load()."""
|
||||
ep = SimpleNamespace(name=name, load=lambda _cls=cls: _cls)
|
||||
return ep
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ChannelsConfig extra="allow"
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_channels_config_accepts_unknown_keys():
|
||||
cfg = ChannelsConfig.model_validate({
|
||||
"myplugin": {"enabled": True, "token": "abc"},
|
||||
})
|
||||
extra = cfg.model_extra
|
||||
assert extra is not None
|
||||
assert extra["myplugin"]["enabled"] is True
|
||||
assert extra["myplugin"]["token"] == "abc"
|
||||
|
||||
|
||||
def test_channels_config_getattr_returns_extra():
|
||||
cfg = ChannelsConfig.model_validate({"myplugin": {"enabled": True}})
|
||||
section = getattr(cfg, "myplugin", None)
|
||||
assert isinstance(section, dict)
|
||||
assert section["enabled"] is True
|
||||
|
||||
|
||||
def test_channels_config_builtin_fields_removed():
|
||||
"""After decoupling, ChannelsConfig has no explicit channel fields."""
|
||||
cfg = ChannelsConfig()
|
||||
assert not hasattr(cfg, "telegram")
|
||||
assert cfg.send_progress is True
|
||||
assert cfg.send_tool_hints is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# discover_plugins
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_EP_TARGET = "importlib.metadata.entry_points"
|
||||
|
||||
|
||||
def test_discover_plugins_loads_entry_points():
|
||||
from nanobot.channels.registry import discover_plugins
|
||||
|
||||
ep = _make_entry_point("line", _FakePlugin)
|
||||
with patch(_EP_TARGET, return_value=[ep]):
|
||||
result = discover_plugins()
|
||||
|
||||
assert "line" in result
|
||||
assert result["line"] is _FakePlugin
|
||||
|
||||
|
||||
def test_discover_plugins_handles_load_error():
|
||||
from nanobot.channels.registry import discover_plugins
|
||||
|
||||
def _boom():
|
||||
raise RuntimeError("broken")
|
||||
|
||||
ep = SimpleNamespace(name="broken", load=_boom)
|
||||
with patch(_EP_TARGET, return_value=[ep]):
|
||||
result = discover_plugins()
|
||||
|
||||
assert "broken" not in result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# discover_all — merge & priority
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_discover_all_includes_builtins():
|
||||
from nanobot.channels.registry import discover_all, discover_channel_names
|
||||
|
||||
with patch(_EP_TARGET, return_value=[]):
|
||||
result = discover_all()
|
||||
|
||||
# discover_all() only returns channels that are actually available (dependencies installed)
|
||||
# discover_channel_names() returns all built-in channel names
|
||||
# So we check that all actually loaded channels are in the result
|
||||
for name in result:
|
||||
assert name in discover_channel_names()
|
||||
|
||||
|
||||
def test_discover_all_includes_external_plugin():
|
||||
from nanobot.channels.registry import discover_all
|
||||
|
||||
ep = _make_entry_point("line", _FakePlugin)
|
||||
with patch(_EP_TARGET, return_value=[ep]):
|
||||
result = discover_all()
|
||||
|
||||
assert "line" in result
|
||||
assert result["line"] is _FakePlugin
|
||||
|
||||
|
||||
def test_discover_all_builtin_shadows_plugin():
|
||||
from nanobot.channels.registry import discover_all
|
||||
|
||||
ep = _make_entry_point("telegram", _FakeTelegram)
|
||||
with patch(_EP_TARGET, return_value=[ep]):
|
||||
result = discover_all()
|
||||
|
||||
assert "telegram" in result
|
||||
assert result["telegram"] is not _FakeTelegram
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Manager _init_channels with dict config (plugin scenario)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_manager_loads_plugin_from_dict_config():
|
||||
"""ChannelManager should instantiate a plugin channel from a raw dict config."""
|
||||
from nanobot.channels.manager import ChannelManager
|
||||
|
||||
fake_config = SimpleNamespace(
|
||||
channels=ChannelsConfig.model_validate({
|
||||
"fakeplugin": {"enabled": True, "allowFrom": ["*"]},
|
||||
}),
|
||||
providers=SimpleNamespace(groq=SimpleNamespace(api_key="")),
|
||||
)
|
||||
|
||||
with patch(
|
||||
"nanobot.channels.registry.discover_all",
|
||||
return_value={"fakeplugin": _FakePlugin},
|
||||
):
|
||||
mgr = ChannelManager.__new__(ChannelManager)
|
||||
mgr.config = fake_config
|
||||
mgr.bus = MessageBus()
|
||||
mgr.channels = {}
|
||||
mgr._dispatch_task = None
|
||||
mgr._init_channels()
|
||||
|
||||
assert "fakeplugin" in mgr.channels
|
||||
assert isinstance(mgr.channels["fakeplugin"], _FakePlugin)
|
||||
|
||||
|
||||
def test_channels_login_uses_discovered_plugin_class(monkeypatch):
|
||||
from nanobot.cli.commands import app
|
||||
from nanobot.config.schema import Config
|
||||
from typer.testing import CliRunner
|
||||
|
||||
runner = CliRunner()
|
||||
seen: dict[str, object] = {}
|
||||
|
||||
class _LoginPlugin(_FakePlugin):
|
||||
display_name = "Login Plugin"
|
||||
|
||||
async def login(self, force: bool = False) -> bool:
|
||||
seen["force"] = force
|
||||
seen["config"] = self.config
|
||||
return True
|
||||
|
||||
monkeypatch.setattr("nanobot.config.loader.load_config", lambda: Config())
|
||||
monkeypatch.setattr(
|
||||
"nanobot.channels.registry.discover_all",
|
||||
lambda: {"fakeplugin": _LoginPlugin},
|
||||
)
|
||||
|
||||
result = runner.invoke(app, ["channels", "login", "fakeplugin", "--force"])
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert seen["force"] is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_manager_skips_disabled_plugin():
|
||||
fake_config = SimpleNamespace(
|
||||
channels=ChannelsConfig.model_validate({
|
||||
"fakeplugin": {"enabled": False},
|
||||
}),
|
||||
providers=SimpleNamespace(groq=SimpleNamespace(api_key="")),
|
||||
)
|
||||
|
||||
with patch(
|
||||
"nanobot.channels.registry.discover_all",
|
||||
return_value={"fakeplugin": _FakePlugin},
|
||||
):
|
||||
mgr = ChannelManager.__new__(ChannelManager)
|
||||
mgr.config = fake_config
|
||||
mgr.bus = MessageBus()
|
||||
mgr.channels = {}
|
||||
mgr._dispatch_task = None
|
||||
mgr._init_channels()
|
||||
|
||||
assert "fakeplugin" not in mgr.channels
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Built-in channel default_config() and dict->Pydantic conversion
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_builtin_channel_default_config():
|
||||
"""Built-in channels expose default_config() returning a dict with 'enabled': False."""
|
||||
from nanobot.channels.telegram import TelegramChannel
|
||||
cfg = TelegramChannel.default_config()
|
||||
assert isinstance(cfg, dict)
|
||||
assert cfg["enabled"] is False
|
||||
assert "token" in cfg
|
||||
|
||||
|
||||
def test_builtin_channel_init_from_dict():
|
||||
"""Built-in channels accept a raw dict and convert to Pydantic internally."""
|
||||
from nanobot.channels.telegram import TelegramChannel
|
||||
bus = MessageBus()
|
||||
ch = TelegramChannel({"enabled": False, "token": "test-tok", "allowFrom": ["*"]}, bus)
|
||||
assert ch.config.token == "test-tok"
|
||||
assert ch.config.allow_from == ["*"]
|
||||
|
||||
|
||||
def test_channels_config_send_max_retries_default():
|
||||
"""ChannelsConfig should have send_max_retries with default value of 3."""
|
||||
cfg = ChannelsConfig()
|
||||
assert hasattr(cfg, 'send_max_retries')
|
||||
assert cfg.send_max_retries == 3
|
||||
|
||||
|
||||
def test_channels_config_send_max_retries_upper_bound():
|
||||
"""send_max_retries should be bounded to prevent resource exhaustion."""
|
||||
from pydantic import ValidationError
|
||||
|
||||
# Value too high should be rejected
|
||||
with pytest.raises(ValidationError):
|
||||
ChannelsConfig(send_max_retries=100)
|
||||
|
||||
# Negative should be rejected
|
||||
with pytest.raises(ValidationError):
|
||||
ChannelsConfig(send_max_retries=-1)
|
||||
|
||||
# Boundary values should be allowed
|
||||
cfg_min = ChannelsConfig(send_max_retries=0)
|
||||
assert cfg_min.send_max_retries == 0
|
||||
|
||||
cfg_max = ChannelsConfig(send_max_retries=10)
|
||||
assert cfg_max.send_max_retries == 10
|
||||
|
||||
# Value above upper bound should be rejected
|
||||
with pytest.raises(ValidationError):
|
||||
ChannelsConfig(send_max_retries=11)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _send_with_retry
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_with_retry_succeeds_first_try():
|
||||
"""_send_with_retry should succeed on first try and not retry."""
|
||||
call_count = 0
|
||||
|
||||
class _FailingChannel(BaseChannel):
|
||||
name = "failing"
|
||||
display_name = "Failing"
|
||||
|
||||
async def start(self) -> None:
|
||||
pass
|
||||
|
||||
async def stop(self) -> None:
|
||||
pass
|
||||
|
||||
async def send(self, msg: OutboundMessage) -> None:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
# Succeeds on first try
|
||||
|
||||
fake_config = SimpleNamespace(
|
||||
channels=ChannelsConfig(send_max_retries=3),
|
||||
providers=SimpleNamespace(groq=SimpleNamespace(api_key="")),
|
||||
)
|
||||
|
||||
mgr = ChannelManager.__new__(ChannelManager)
|
||||
mgr.config = fake_config
|
||||
mgr.bus = MessageBus()
|
||||
mgr.channels = {"failing": _FailingChannel(fake_config, mgr.bus)}
|
||||
mgr._dispatch_task = None
|
||||
|
||||
msg = OutboundMessage(channel="failing", chat_id="123", content="test")
|
||||
await mgr._send_with_retry(mgr.channels["failing"], msg)
|
||||
|
||||
assert call_count == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_with_retry_retries_on_failure():
|
||||
"""_send_with_retry should retry on failure up to max_retries times."""
|
||||
call_count = 0
|
||||
|
||||
class _FailingChannel(BaseChannel):
|
||||
name = "failing"
|
||||
display_name = "Failing"
|
||||
|
||||
async def start(self) -> None:
|
||||
pass
|
||||
|
||||
async def stop(self) -> None:
|
||||
pass
|
||||
|
||||
async def send(self, msg: OutboundMessage) -> None:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
raise RuntimeError("simulated failure")
|
||||
|
||||
fake_config = SimpleNamespace(
|
||||
channels=ChannelsConfig(send_max_retries=3),
|
||||
providers=SimpleNamespace(groq=SimpleNamespace(api_key="")),
|
||||
)
|
||||
|
||||
mgr = ChannelManager.__new__(ChannelManager)
|
||||
mgr.config = fake_config
|
||||
mgr.bus = MessageBus()
|
||||
mgr.channels = {"failing": _FailingChannel(fake_config, mgr.bus)}
|
||||
mgr._dispatch_task = None
|
||||
|
||||
msg = OutboundMessage(channel="failing", chat_id="123", content="test")
|
||||
|
||||
# Patch asyncio.sleep to avoid actual delays
|
||||
with patch("nanobot.channels.manager.asyncio.sleep", new_callable=AsyncMock) as mock_sleep:
|
||||
await mgr._send_with_retry(mgr.channels["failing"], msg)
|
||||
|
||||
assert call_count == 3 # 3 total attempts (initial + 2 retries)
|
||||
assert mock_sleep.call_count == 2 # 2 sleeps between retries
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_with_retry_no_retry_when_max_is_zero():
|
||||
"""_send_with_retry should not retry when send_max_retries is 0."""
|
||||
call_count = 0
|
||||
|
||||
class _FailingChannel(BaseChannel):
|
||||
name = "failing"
|
||||
display_name = "Failing"
|
||||
|
||||
async def start(self) -> None:
|
||||
pass
|
||||
|
||||
async def stop(self) -> None:
|
||||
pass
|
||||
|
||||
async def send(self, msg: OutboundMessage) -> None:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
raise RuntimeError("simulated failure")
|
||||
|
||||
fake_config = SimpleNamespace(
|
||||
channels=ChannelsConfig(send_max_retries=0),
|
||||
providers=SimpleNamespace(groq=SimpleNamespace(api_key="")),
|
||||
)
|
||||
|
||||
mgr = ChannelManager.__new__(ChannelManager)
|
||||
mgr.config = fake_config
|
||||
mgr.bus = MessageBus()
|
||||
mgr.channels = {"failing": _FailingChannel(fake_config, mgr.bus)}
|
||||
mgr._dispatch_task = None
|
||||
|
||||
msg = OutboundMessage(channel="failing", chat_id="123", content="test")
|
||||
|
||||
with patch("nanobot.channels.manager.asyncio.sleep", new_callable=AsyncMock):
|
||||
await mgr._send_with_retry(mgr.channels["failing"], msg)
|
||||
|
||||
assert call_count == 1 # Called once but no retry (max(0, 1) = 1)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_with_retry_calls_send_delta():
|
||||
"""_send_with_retry should call send_delta when metadata has _stream_delta."""
|
||||
send_delta_called = False
|
||||
|
||||
class _StreamingChannel(BaseChannel):
|
||||
name = "streaming"
|
||||
display_name = "Streaming"
|
||||
|
||||
async def start(self) -> None:
|
||||
pass
|
||||
|
||||
async def stop(self) -> None:
|
||||
pass
|
||||
|
||||
async def send(self, msg: OutboundMessage) -> None:
|
||||
pass # Should not be called
|
||||
|
||||
async def send_delta(self, chat_id: str, delta: str, metadata: dict | None = None) -> None:
|
||||
nonlocal send_delta_called
|
||||
send_delta_called = True
|
||||
|
||||
fake_config = SimpleNamespace(
|
||||
channels=ChannelsConfig(send_max_retries=3),
|
||||
providers=SimpleNamespace(groq=SimpleNamespace(api_key="")),
|
||||
)
|
||||
|
||||
mgr = ChannelManager.__new__(ChannelManager)
|
||||
mgr.config = fake_config
|
||||
mgr.bus = MessageBus()
|
||||
mgr.channels = {"streaming": _StreamingChannel(fake_config, mgr.bus)}
|
||||
mgr._dispatch_task = None
|
||||
|
||||
msg = OutboundMessage(
|
||||
channel="streaming", chat_id="123", content="test delta",
|
||||
metadata={"_stream_delta": True}
|
||||
)
|
||||
await mgr._send_with_retry(mgr.channels["streaming"], msg)
|
||||
|
||||
assert send_delta_called is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_with_retry_skips_send_when_streamed():
|
||||
"""_send_with_retry should not call send when metadata has _streamed flag."""
|
||||
send_called = False
|
||||
send_delta_called = False
|
||||
|
||||
class _StreamedChannel(BaseChannel):
|
||||
name = "streamed"
|
||||
display_name = "Streamed"
|
||||
|
||||
async def start(self) -> None:
|
||||
pass
|
||||
|
||||
async def stop(self) -> None:
|
||||
pass
|
||||
|
||||
async def send(self, msg: OutboundMessage) -> None:
|
||||
nonlocal send_called
|
||||
send_called = True
|
||||
|
||||
async def send_delta(self, chat_id: str, delta: str, metadata: dict | None = None) -> None:
|
||||
nonlocal send_delta_called
|
||||
send_delta_called = True
|
||||
|
||||
fake_config = SimpleNamespace(
|
||||
channels=ChannelsConfig(send_max_retries=3),
|
||||
providers=SimpleNamespace(groq=SimpleNamespace(api_key="")),
|
||||
)
|
||||
|
||||
mgr = ChannelManager.__new__(ChannelManager)
|
||||
mgr.config = fake_config
|
||||
mgr.bus = MessageBus()
|
||||
mgr.channels = {"streamed": _StreamedChannel(fake_config, mgr.bus)}
|
||||
mgr._dispatch_task = None
|
||||
|
||||
# _streamed means message was already sent via send_delta, so skip send
|
||||
msg = OutboundMessage(
|
||||
channel="streamed", chat_id="123", content="test",
|
||||
metadata={"_streamed": True}
|
||||
)
|
||||
await mgr._send_with_retry(mgr.channels["streamed"], msg)
|
||||
|
||||
assert send_called is False
|
||||
assert send_delta_called is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_with_retry_propagates_cancelled_error():
|
||||
"""_send_with_retry should re-raise CancelledError for graceful shutdown."""
|
||||
class _CancellingChannel(BaseChannel):
|
||||
name = "cancelling"
|
||||
display_name = "Cancelling"
|
||||
|
||||
async def start(self) -> None:
|
||||
pass
|
||||
|
||||
async def stop(self) -> None:
|
||||
pass
|
||||
|
||||
async def send(self, msg: OutboundMessage) -> None:
|
||||
raise asyncio.CancelledError("simulated cancellation")
|
||||
|
||||
fake_config = SimpleNamespace(
|
||||
channels=ChannelsConfig(send_max_retries=3),
|
||||
providers=SimpleNamespace(groq=SimpleNamespace(api_key="")),
|
||||
)
|
||||
|
||||
mgr = ChannelManager.__new__(ChannelManager)
|
||||
mgr.config = fake_config
|
||||
mgr.bus = MessageBus()
|
||||
mgr.channels = {"cancelling": _CancellingChannel(fake_config, mgr.bus)}
|
||||
mgr._dispatch_task = None
|
||||
|
||||
msg = OutboundMessage(channel="cancelling", chat_id="123", content="test")
|
||||
|
||||
with pytest.raises(asyncio.CancelledError):
|
||||
await mgr._send_with_retry(mgr.channels["cancelling"], msg)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_with_retry_propagates_cancelled_error_during_sleep():
|
||||
"""_send_with_retry should re-raise CancelledError during sleep."""
|
||||
call_count = 0
|
||||
|
||||
class _FailingChannel(BaseChannel):
|
||||
name = "failing"
|
||||
display_name = "Failing"
|
||||
|
||||
async def start(self) -> None:
|
||||
pass
|
||||
|
||||
async def stop(self) -> None:
|
||||
pass
|
||||
|
||||
async def send(self, msg: OutboundMessage) -> None:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
raise RuntimeError("simulated failure")
|
||||
|
||||
fake_config = SimpleNamespace(
|
||||
channels=ChannelsConfig(send_max_retries=3),
|
||||
providers=SimpleNamespace(groq=SimpleNamespace(api_key="")),
|
||||
)
|
||||
|
||||
mgr = ChannelManager.__new__(ChannelManager)
|
||||
mgr.config = fake_config
|
||||
mgr.bus = MessageBus()
|
||||
mgr.channels = {"failing": _FailingChannel(fake_config, mgr.bus)}
|
||||
mgr._dispatch_task = None
|
||||
|
||||
msg = OutboundMessage(channel="failing", chat_id="123", content="test")
|
||||
|
||||
# Mock sleep to raise CancelledError
|
||||
async def cancel_during_sleep(_):
|
||||
raise asyncio.CancelledError("cancelled during sleep")
|
||||
|
||||
with patch("nanobot.channels.manager.asyncio.sleep", side_effect=cancel_during_sleep):
|
||||
with pytest.raises(asyncio.CancelledError):
|
||||
await mgr._send_with_retry(mgr.channels["failing"], msg)
|
||||
|
||||
# Should have attempted once before sleep was cancelled
|
||||
assert call_count == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ChannelManager - lifecycle and getters
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class _ChannelWithAllowFrom(BaseChannel):
|
||||
"""Channel with configurable allow_from."""
|
||||
name = "withallow"
|
||||
display_name = "With Allow"
|
||||
|
||||
def __init__(self, config, bus, allow_from):
|
||||
super().__init__(config, bus)
|
||||
self.config.allow_from = allow_from
|
||||
|
||||
async def start(self) -> None:
|
||||
pass
|
||||
|
||||
async def stop(self) -> None:
|
||||
pass
|
||||
|
||||
async def send(self, msg: OutboundMessage) -> None:
|
||||
pass
|
||||
|
||||
|
||||
class _StartableChannel(BaseChannel):
|
||||
"""Channel that tracks start/stop calls."""
|
||||
name = "startable"
|
||||
display_name = "Startable"
|
||||
|
||||
def __init__(self, config, bus):
|
||||
super().__init__(config, bus)
|
||||
self.started = False
|
||||
self.stopped = False
|
||||
|
||||
async def start(self) -> None:
|
||||
self.started = True
|
||||
|
||||
async def stop(self) -> None:
|
||||
self.stopped = True
|
||||
|
||||
async def send(self, msg: OutboundMessage) -> None:
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validate_allow_from_raises_on_empty_list():
|
||||
"""_validate_allow_from should raise SystemExit when allow_from is empty list."""
|
||||
fake_config = SimpleNamespace(
|
||||
channels=ChannelsConfig(),
|
||||
providers=SimpleNamespace(groq=SimpleNamespace(api_key="")),
|
||||
)
|
||||
|
||||
mgr = ChannelManager.__new__(ChannelManager)
|
||||
mgr.config = fake_config
|
||||
mgr.channels = {"test": _ChannelWithAllowFrom(fake_config, None, [])}
|
||||
mgr._dispatch_task = None
|
||||
|
||||
with pytest.raises(SystemExit) as exc_info:
|
||||
mgr._validate_allow_from()
|
||||
|
||||
assert "empty allowFrom" in str(exc_info.value)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validate_allow_from_passes_with_asterisk():
|
||||
"""_validate_allow_from should not raise when allow_from contains '*'."""
|
||||
fake_config = SimpleNamespace(
|
||||
channels=ChannelsConfig(),
|
||||
providers=SimpleNamespace(groq=SimpleNamespace(api_key="")),
|
||||
)
|
||||
|
||||
mgr = ChannelManager.__new__(ChannelManager)
|
||||
mgr.config = fake_config
|
||||
mgr.channels = {"test": _ChannelWithAllowFrom(fake_config, None, ["*"])}
|
||||
mgr._dispatch_task = None
|
||||
|
||||
# Should not raise
|
||||
mgr._validate_allow_from()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_channel_returns_channel_if_exists():
|
||||
"""get_channel should return the channel if it exists."""
|
||||
fake_config = SimpleNamespace(
|
||||
channels=ChannelsConfig(),
|
||||
providers=SimpleNamespace(groq=SimpleNamespace(api_key="")),
|
||||
)
|
||||
|
||||
mgr = ChannelManager.__new__(ChannelManager)
|
||||
mgr.config = fake_config
|
||||
mgr.bus = MessageBus()
|
||||
mgr.channels = {"telegram": _StartableChannel(fake_config, mgr.bus)}
|
||||
mgr._dispatch_task = None
|
||||
|
||||
assert mgr.get_channel("telegram") is not None
|
||||
assert mgr.get_channel("nonexistent") is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_status_returns_running_state():
|
||||
"""get_status should return enabled and running state for each channel."""
|
||||
fake_config = SimpleNamespace(
|
||||
channels=ChannelsConfig(),
|
||||
providers=SimpleNamespace(groq=SimpleNamespace(api_key="")),
|
||||
)
|
||||
|
||||
mgr = ChannelManager.__new__(ChannelManager)
|
||||
mgr.config = fake_config
|
||||
mgr.bus = MessageBus()
|
||||
ch = _StartableChannel(fake_config, mgr.bus)
|
||||
mgr.channels = {"startable": ch}
|
||||
mgr._dispatch_task = None
|
||||
|
||||
status = mgr.get_status()
|
||||
|
||||
assert status["startable"]["enabled"] is True
|
||||
assert status["startable"]["running"] is False # Not started yet
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_enabled_channels_returns_channel_names():
|
||||
"""enabled_channels should return list of enabled channel names."""
|
||||
fake_config = SimpleNamespace(
|
||||
channels=ChannelsConfig(),
|
||||
providers=SimpleNamespace(groq=SimpleNamespace(api_key="")),
|
||||
)
|
||||
|
||||
mgr = ChannelManager.__new__(ChannelManager)
|
||||
mgr.config = fake_config
|
||||
mgr.bus = MessageBus()
|
||||
mgr.channels = {
|
||||
"telegram": _StartableChannel(fake_config, mgr.bus),
|
||||
"slack": _StartableChannel(fake_config, mgr.bus),
|
||||
}
|
||||
mgr._dispatch_task = None
|
||||
|
||||
enabled = mgr.enabled_channels
|
||||
|
||||
assert "telegram" in enabled
|
||||
assert "slack" in enabled
|
||||
assert len(enabled) == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stop_all_cancels_dispatcher_and_stops_channels():
|
||||
"""stop_all should cancel the dispatch task and stop all channels."""
|
||||
fake_config = SimpleNamespace(
|
||||
channels=ChannelsConfig(),
|
||||
providers=SimpleNamespace(groq=SimpleNamespace(api_key="")),
|
||||
)
|
||||
|
||||
mgr = ChannelManager.__new__(ChannelManager)
|
||||
mgr.config = fake_config
|
||||
mgr.bus = MessageBus()
|
||||
|
||||
ch = _StartableChannel(fake_config, mgr.bus)
|
||||
mgr.channels = {"startable": ch}
|
||||
|
||||
# Create a real cancelled task
|
||||
async def dummy_task():
|
||||
while True:
|
||||
await asyncio.sleep(1)
|
||||
|
||||
dispatch_task = asyncio.create_task(dummy_task())
|
||||
mgr._dispatch_task = dispatch_task
|
||||
|
||||
await mgr.stop_all()
|
||||
|
||||
# Task should be cancelled
|
||||
assert dispatch_task.cancelled()
|
||||
# Channel should be stopped
|
||||
assert ch.stopped is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_channel_logs_error_on_failure():
|
||||
"""_start_channel should log error when channel start fails."""
|
||||
class _FailingChannel(BaseChannel):
|
||||
name = "failing"
|
||||
display_name = "Failing"
|
||||
|
||||
async def start(self) -> None:
|
||||
raise RuntimeError("connection failed")
|
||||
|
||||
async def stop(self) -> None:
|
||||
pass
|
||||
|
||||
async def send(self, msg: OutboundMessage) -> None:
|
||||
pass
|
||||
|
||||
fake_config = SimpleNamespace(
|
||||
channels=ChannelsConfig(),
|
||||
providers=SimpleNamespace(groq=SimpleNamespace(api_key="")),
|
||||
)
|
||||
|
||||
mgr = ChannelManager.__new__(ChannelManager)
|
||||
mgr.config = fake_config
|
||||
mgr.bus = MessageBus()
|
||||
mgr.channels = {}
|
||||
mgr._dispatch_task = None
|
||||
|
||||
ch = _FailingChannel(fake_config, mgr.bus)
|
||||
|
||||
# Should not raise, just log error
|
||||
await mgr._start_channel("failing", ch)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stop_all_handles_channel_exception():
|
||||
"""stop_all should handle exceptions when stopping channels gracefully."""
|
||||
class _StopFailingChannel(BaseChannel):
|
||||
name = "stopfailing"
|
||||
display_name = "Stop Failing"
|
||||
|
||||
async def start(self) -> None:
|
||||
pass
|
||||
|
||||
async def stop(self) -> None:
|
||||
raise RuntimeError("stop failed")
|
||||
|
||||
async def send(self, msg: OutboundMessage) -> None:
|
||||
pass
|
||||
|
||||
fake_config = SimpleNamespace(
|
||||
channels=ChannelsConfig(),
|
||||
providers=SimpleNamespace(groq=SimpleNamespace(api_key="")),
|
||||
)
|
||||
|
||||
mgr = ChannelManager.__new__(ChannelManager)
|
||||
mgr.config = fake_config
|
||||
mgr.bus = MessageBus()
|
||||
mgr.channels = {"stopfailing": _StopFailingChannel(fake_config, mgr.bus)}
|
||||
mgr._dispatch_task = None
|
||||
|
||||
# Should not raise even if channel.stop() raises
|
||||
await mgr.stop_all()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_all_no_channels_logs_warning():
|
||||
"""start_all should log warning when no channels are enabled."""
|
||||
fake_config = SimpleNamespace(
|
||||
channels=ChannelsConfig(),
|
||||
providers=SimpleNamespace(groq=SimpleNamespace(api_key="")),
|
||||
)
|
||||
|
||||
mgr = ChannelManager.__new__(ChannelManager)
|
||||
mgr.config = fake_config
|
||||
mgr.bus = MessageBus()
|
||||
mgr.channels = {} # No channels
|
||||
mgr._dispatch_task = None
|
||||
|
||||
# Should return early without creating dispatch task
|
||||
await mgr.start_all()
|
||||
|
||||
assert mgr._dispatch_task is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_all_creates_dispatch_task():
|
||||
"""start_all should create the dispatch task when channels exist."""
|
||||
fake_config = SimpleNamespace(
|
||||
channels=ChannelsConfig(),
|
||||
providers=SimpleNamespace(groq=SimpleNamespace(api_key="")),
|
||||
)
|
||||
|
||||
mgr = ChannelManager.__new__(ChannelManager)
|
||||
mgr.config = fake_config
|
||||
mgr.bus = MessageBus()
|
||||
|
||||
ch = _StartableChannel(fake_config, mgr.bus)
|
||||
mgr.channels = {"startable": ch}
|
||||
mgr._dispatch_task = None
|
||||
|
||||
# Cancel immediately after start to avoid running forever
|
||||
async def cancel_after_start():
|
||||
await asyncio.sleep(0.01)
|
||||
if mgr._dispatch_task:
|
||||
mgr._dispatch_task.cancel()
|
||||
|
||||
cancel_task = asyncio.create_task(cancel_after_start())
|
||||
|
||||
try:
|
||||
await mgr.start_all()
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
finally:
|
||||
cancel_task.cancel()
|
||||
try:
|
||||
await cancel_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
# Dispatch task should have been created
|
||||
assert mgr._dispatch_task is not None
|
||||
|
||||
@@ -0,0 +1,223 @@
|
||||
import asyncio
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
# Check optional dingtalk dependencies before running tests
|
||||
try:
|
||||
from nanobot.channels import dingtalk
|
||||
DINGTALK_AVAILABLE = getattr(dingtalk, "DINGTALK_AVAILABLE", False)
|
||||
except ImportError:
|
||||
DINGTALK_AVAILABLE = False
|
||||
|
||||
if not DINGTALK_AVAILABLE:
|
||||
pytest.skip("DingTalk dependencies not installed (dingtalk-stream)", allow_module_level=True)
|
||||
|
||||
from nanobot.bus.queue import MessageBus
|
||||
import nanobot.channels.dingtalk as dingtalk_module
|
||||
from nanobot.channels.dingtalk import DingTalkChannel, NanobotDingTalkHandler
|
||||
from nanobot.channels.dingtalk import DingTalkConfig
|
||||
|
||||
|
||||
class _FakeResponse:
|
||||
def __init__(self, status_code: int = 200, json_body: dict | None = None) -> None:
|
||||
self.status_code = status_code
|
||||
self._json_body = json_body or {}
|
||||
self.text = "{}"
|
||||
self.content = b""
|
||||
self.headers = {"content-type": "application/json"}
|
||||
|
||||
def json(self) -> dict:
|
||||
return self._json_body
|
||||
|
||||
|
||||
class _FakeHttp:
|
||||
def __init__(self, responses: list[_FakeResponse] | None = None) -> None:
|
||||
self.calls: list[dict] = []
|
||||
self._responses = list(responses) if responses else []
|
||||
|
||||
def _next_response(self) -> _FakeResponse:
|
||||
if self._responses:
|
||||
return self._responses.pop(0)
|
||||
return _FakeResponse()
|
||||
|
||||
async def post(self, url: str, json=None, headers=None, **kwargs):
|
||||
self.calls.append({"method": "POST", "url": url, "json": json, "headers": headers})
|
||||
return self._next_response()
|
||||
|
||||
async def get(self, url: str, **kwargs):
|
||||
self.calls.append({"method": "GET", "url": url})
|
||||
return self._next_response()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_group_message_keeps_sender_id_and_routes_chat_id() -> None:
|
||||
config = DingTalkConfig(client_id="app", client_secret="secret", allow_from=["user1"])
|
||||
bus = MessageBus()
|
||||
channel = DingTalkChannel(config, bus)
|
||||
|
||||
await channel._on_message(
|
||||
"hello",
|
||||
sender_id="user1",
|
||||
sender_name="Alice",
|
||||
conversation_type="2",
|
||||
conversation_id="conv123",
|
||||
)
|
||||
|
||||
msg = await bus.consume_inbound()
|
||||
assert msg.sender_id == "user1"
|
||||
assert msg.chat_id == "group:conv123"
|
||||
assert msg.metadata["conversation_type"] == "2"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_group_send_uses_group_messages_api() -> None:
|
||||
config = DingTalkConfig(client_id="app", client_secret="secret", allow_from=["*"])
|
||||
channel = DingTalkChannel(config, MessageBus())
|
||||
channel._http = _FakeHttp()
|
||||
|
||||
ok = await channel._send_batch_message(
|
||||
"token",
|
||||
"group:conv123",
|
||||
"sampleMarkdown",
|
||||
{"text": "hello", "title": "Nanobot Reply"},
|
||||
)
|
||||
|
||||
assert ok is True
|
||||
call = channel._http.calls[0]
|
||||
assert call["url"] == "https://api.dingtalk.com/v1.0/robot/groupMessages/send"
|
||||
assert call["json"]["openConversationId"] == "conv123"
|
||||
assert call["json"]["msgKey"] == "sampleMarkdown"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handler_uses_voice_recognition_text_when_text_is_empty(monkeypatch) -> None:
|
||||
bus = MessageBus()
|
||||
channel = DingTalkChannel(
|
||||
DingTalkConfig(client_id="app", client_secret="secret", allow_from=["user1"]),
|
||||
bus,
|
||||
)
|
||||
handler = NanobotDingTalkHandler(channel)
|
||||
|
||||
class _FakeChatbotMessage:
|
||||
text = None
|
||||
extensions = {"content": {"recognition": "voice transcript"}}
|
||||
sender_staff_id = "user1"
|
||||
sender_id = "fallback-user"
|
||||
sender_nick = "Alice"
|
||||
message_type = "audio"
|
||||
|
||||
@staticmethod
|
||||
def from_dict(_data):
|
||||
return _FakeChatbotMessage()
|
||||
|
||||
monkeypatch.setattr(dingtalk_module, "ChatbotMessage", _FakeChatbotMessage)
|
||||
monkeypatch.setattr(dingtalk_module, "AckMessage", SimpleNamespace(STATUS_OK="OK"))
|
||||
|
||||
status, body = await handler.process(
|
||||
SimpleNamespace(
|
||||
data={
|
||||
"conversationType": "2",
|
||||
"conversationId": "conv123",
|
||||
"text": {"content": ""},
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
await asyncio.gather(*list(channel._background_tasks))
|
||||
msg = await bus.consume_inbound()
|
||||
|
||||
assert (status, body) == ("OK", "OK")
|
||||
assert msg.content == "voice transcript"
|
||||
assert msg.sender_id == "user1"
|
||||
assert msg.chat_id == "group:conv123"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handler_processes_file_message(monkeypatch) -> None:
|
||||
"""Test that file messages are handled and forwarded with downloaded path."""
|
||||
bus = MessageBus()
|
||||
channel = DingTalkChannel(
|
||||
DingTalkConfig(client_id="app", client_secret="secret", allow_from=["user1"]),
|
||||
bus,
|
||||
)
|
||||
handler = NanobotDingTalkHandler(channel)
|
||||
|
||||
class _FakeFileChatbotMessage:
|
||||
text = None
|
||||
extensions = {}
|
||||
image_content = None
|
||||
rich_text_content = None
|
||||
sender_staff_id = "user1"
|
||||
sender_id = "fallback-user"
|
||||
sender_nick = "Alice"
|
||||
message_type = "file"
|
||||
|
||||
@staticmethod
|
||||
def from_dict(_data):
|
||||
return _FakeFileChatbotMessage()
|
||||
|
||||
async def fake_download(download_code, filename, sender_id):
|
||||
return f"/tmp/nanobot_dingtalk/{sender_id}/{filename}"
|
||||
|
||||
monkeypatch.setattr(dingtalk_module, "ChatbotMessage", _FakeFileChatbotMessage)
|
||||
monkeypatch.setattr(dingtalk_module, "AckMessage", SimpleNamespace(STATUS_OK="OK"))
|
||||
monkeypatch.setattr(channel, "_download_dingtalk_file", fake_download)
|
||||
|
||||
status, body = await handler.process(
|
||||
SimpleNamespace(
|
||||
data={
|
||||
"conversationType": "1",
|
||||
"content": {"downloadCode": "abc123", "fileName": "report.xlsx"},
|
||||
"text": {"content": ""},
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
await asyncio.gather(*list(channel._background_tasks))
|
||||
msg = await bus.consume_inbound()
|
||||
|
||||
assert (status, body) == ("OK", "OK")
|
||||
assert "[File]" in msg.content
|
||||
assert "/tmp/nanobot_dingtalk/user1/report.xlsx" in msg.content
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_download_dingtalk_file(tmp_path, monkeypatch) -> None:
|
||||
"""Test the two-step file download flow (get URL then download content)."""
|
||||
channel = DingTalkChannel(
|
||||
DingTalkConfig(client_id="app", client_secret="secret", allow_from=["*"]),
|
||||
MessageBus(),
|
||||
)
|
||||
|
||||
# Mock access token
|
||||
async def fake_get_token():
|
||||
return "test-token"
|
||||
|
||||
monkeypatch.setattr(channel, "_get_access_token", fake_get_token)
|
||||
|
||||
# Mock HTTP: first POST returns downloadUrl, then GET returns file bytes
|
||||
file_content = b"fake file content"
|
||||
channel._http = _FakeHttp(responses=[
|
||||
_FakeResponse(200, {"downloadUrl": "https://example.com/tmpfile"}),
|
||||
_FakeResponse(200),
|
||||
])
|
||||
channel._http._responses[1].content = file_content
|
||||
|
||||
# Redirect media dir to tmp_path
|
||||
monkeypatch.setattr(
|
||||
"nanobot.config.paths.get_media_dir",
|
||||
lambda channel_name=None: tmp_path / channel_name if channel_name else tmp_path,
|
||||
)
|
||||
|
||||
result = await channel._download_dingtalk_file("code123", "test.xlsx", "user1")
|
||||
|
||||
assert result is not None
|
||||
assert result.endswith("test.xlsx")
|
||||
assert (tmp_path / "dingtalk" / "user1" / "test.xlsx").read_bytes() == file_content
|
||||
|
||||
# Verify API calls
|
||||
assert channel._http.calls[0]["method"] == "POST"
|
||||
assert "messageFiles/download" in channel._http.calls[0]["url"]
|
||||
assert channel._http.calls[0]["json"]["downloadCode"] == "code123"
|
||||
assert channel._http.calls[1]["method"] == "GET"
|
||||
@@ -0,0 +1,652 @@
|
||||
from email.message import EmailMessage
|
||||
from datetime import date
|
||||
import imaplib
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.channels.email import EmailChannel
|
||||
from nanobot.channels.email import EmailConfig
|
||||
|
||||
|
||||
def _make_config(**overrides) -> EmailConfig:
|
||||
defaults = dict(
|
||||
enabled=True,
|
||||
consent_granted=True,
|
||||
imap_host="imap.example.com",
|
||||
imap_port=993,
|
||||
imap_username="bot@example.com",
|
||||
imap_password="secret",
|
||||
smtp_host="smtp.example.com",
|
||||
smtp_port=587,
|
||||
smtp_username="bot@example.com",
|
||||
smtp_password="secret",
|
||||
mark_seen=True,
|
||||
# Disable auth verification by default so existing tests are unaffected
|
||||
verify_dkim=False,
|
||||
verify_spf=False,
|
||||
)
|
||||
defaults.update(overrides)
|
||||
return EmailConfig(**defaults)
|
||||
|
||||
|
||||
def _make_raw_email(
|
||||
from_addr: str = "alice@example.com",
|
||||
subject: str = "Hello",
|
||||
body: str = "This is the body.",
|
||||
auth_results: str | None = None,
|
||||
) -> bytes:
|
||||
msg = EmailMessage()
|
||||
msg["From"] = from_addr
|
||||
msg["To"] = "bot@example.com"
|
||||
msg["Subject"] = subject
|
||||
msg["Message-ID"] = "<m1@example.com>"
|
||||
if auth_results:
|
||||
msg["Authentication-Results"] = auth_results
|
||||
msg.set_content(body)
|
||||
return msg.as_bytes()
|
||||
|
||||
|
||||
def test_fetch_new_messages_parses_unseen_and_marks_seen(monkeypatch) -> None:
|
||||
raw = _make_raw_email(subject="Invoice", body="Please pay")
|
||||
|
||||
class FakeIMAP:
|
||||
def __init__(self) -> None:
|
||||
self.store_calls: list[tuple[bytes, str, str]] = []
|
||||
|
||||
def login(self, _user: str, _pw: str):
|
||||
return "OK", [b"logged in"]
|
||||
|
||||
def select(self, _mailbox: str):
|
||||
return "OK", [b"1"]
|
||||
|
||||
def search(self, *_args):
|
||||
return "OK", [b"1"]
|
||||
|
||||
def fetch(self, _imap_id: bytes, _parts: str):
|
||||
return "OK", [(b"1 (UID 123 BODY[] {200})", raw), b")"]
|
||||
|
||||
def store(self, imap_id: bytes, op: str, flags: str):
|
||||
self.store_calls.append((imap_id, op, flags))
|
||||
return "OK", [b""]
|
||||
|
||||
def logout(self):
|
||||
return "BYE", [b""]
|
||||
|
||||
fake = FakeIMAP()
|
||||
monkeypatch.setattr("nanobot.channels.email.imaplib.IMAP4_SSL", lambda _h, _p: fake)
|
||||
|
||||
channel = EmailChannel(_make_config(), MessageBus())
|
||||
items = channel._fetch_new_messages()
|
||||
|
||||
assert len(items) == 1
|
||||
assert items[0]["sender"] == "alice@example.com"
|
||||
assert items[0]["subject"] == "Invoice"
|
||||
assert "Please pay" in items[0]["content"]
|
||||
assert fake.store_calls == [(b"1", "+FLAGS", "\\Seen")]
|
||||
|
||||
# Same UID should be deduped in-process.
|
||||
items_again = channel._fetch_new_messages()
|
||||
assert items_again == []
|
||||
|
||||
|
||||
def test_fetch_new_messages_retries_once_when_imap_connection_goes_stale(monkeypatch) -> None:
|
||||
raw = _make_raw_email(subject="Invoice", body="Please pay")
|
||||
fail_once = {"pending": True}
|
||||
|
||||
class FlakyIMAP:
|
||||
def __init__(self) -> None:
|
||||
self.store_calls: list[tuple[bytes, str, str]] = []
|
||||
self.search_calls = 0
|
||||
|
||||
def login(self, _user: str, _pw: str):
|
||||
return "OK", [b"logged in"]
|
||||
|
||||
def select(self, _mailbox: str):
|
||||
return "OK", [b"1"]
|
||||
|
||||
def search(self, *_args):
|
||||
self.search_calls += 1
|
||||
if fail_once["pending"]:
|
||||
fail_once["pending"] = False
|
||||
raise imaplib.IMAP4.abort("socket error")
|
||||
return "OK", [b"1"]
|
||||
|
||||
def fetch(self, _imap_id: bytes, _parts: str):
|
||||
return "OK", [(b"1 (UID 123 BODY[] {200})", raw), b")"]
|
||||
|
||||
def store(self, imap_id: bytes, op: str, flags: str):
|
||||
self.store_calls.append((imap_id, op, flags))
|
||||
return "OK", [b""]
|
||||
|
||||
def logout(self):
|
||||
return "BYE", [b""]
|
||||
|
||||
fake_instances: list[FlakyIMAP] = []
|
||||
|
||||
def _factory(_host: str, _port: int):
|
||||
instance = FlakyIMAP()
|
||||
fake_instances.append(instance)
|
||||
return instance
|
||||
|
||||
monkeypatch.setattr("nanobot.channels.email.imaplib.IMAP4_SSL", _factory)
|
||||
|
||||
channel = EmailChannel(_make_config(), MessageBus())
|
||||
items = channel._fetch_new_messages()
|
||||
|
||||
assert len(items) == 1
|
||||
assert len(fake_instances) == 2
|
||||
assert fake_instances[0].search_calls == 1
|
||||
assert fake_instances[1].search_calls == 1
|
||||
|
||||
|
||||
def test_fetch_new_messages_keeps_messages_collected_before_stale_retry(monkeypatch) -> None:
|
||||
raw_first = _make_raw_email(subject="First", body="First body")
|
||||
raw_second = _make_raw_email(subject="Second", body="Second body")
|
||||
mailbox_state = {
|
||||
b"1": {"uid": b"123", "raw": raw_first, "seen": False},
|
||||
b"2": {"uid": b"124", "raw": raw_second, "seen": False},
|
||||
}
|
||||
fail_once = {"pending": True}
|
||||
|
||||
class FlakyIMAP:
|
||||
def login(self, _user: str, _pw: str):
|
||||
return "OK", [b"logged in"]
|
||||
|
||||
def select(self, _mailbox: str):
|
||||
return "OK", [b"2"]
|
||||
|
||||
def search(self, *_args):
|
||||
unseen_ids = [imap_id for imap_id, item in mailbox_state.items() if not item["seen"]]
|
||||
return "OK", [b" ".join(unseen_ids)]
|
||||
|
||||
def fetch(self, imap_id: bytes, _parts: str):
|
||||
if imap_id == b"2" and fail_once["pending"]:
|
||||
fail_once["pending"] = False
|
||||
raise imaplib.IMAP4.abort("socket error")
|
||||
item = mailbox_state[imap_id]
|
||||
header = b"%s (UID %s BODY[] {200})" % (imap_id, item["uid"])
|
||||
return "OK", [(header, item["raw"]), b")"]
|
||||
|
||||
def store(self, imap_id: bytes, _op: str, _flags: str):
|
||||
mailbox_state[imap_id]["seen"] = True
|
||||
return "OK", [b""]
|
||||
|
||||
def logout(self):
|
||||
return "BYE", [b""]
|
||||
|
||||
monkeypatch.setattr("nanobot.channels.email.imaplib.IMAP4_SSL", lambda _h, _p: FlakyIMAP())
|
||||
|
||||
channel = EmailChannel(_make_config(), MessageBus())
|
||||
items = channel._fetch_new_messages()
|
||||
|
||||
assert [item["subject"] for item in items] == ["First", "Second"]
|
||||
|
||||
|
||||
def test_fetch_new_messages_skips_missing_mailbox(monkeypatch) -> None:
|
||||
class MissingMailboxIMAP:
|
||||
def login(self, _user: str, _pw: str):
|
||||
return "OK", [b"logged in"]
|
||||
|
||||
def select(self, _mailbox: str):
|
||||
raise imaplib.IMAP4.error("Mailbox doesn't exist")
|
||||
|
||||
def logout(self):
|
||||
return "BYE", [b""]
|
||||
|
||||
monkeypatch.setattr(
|
||||
"nanobot.channels.email.imaplib.IMAP4_SSL",
|
||||
lambda _h, _p: MissingMailboxIMAP(),
|
||||
)
|
||||
|
||||
channel = EmailChannel(_make_config(), MessageBus())
|
||||
|
||||
assert channel._fetch_new_messages() == []
|
||||
|
||||
|
||||
def test_extract_text_body_falls_back_to_html() -> None:
|
||||
msg = EmailMessage()
|
||||
msg["From"] = "alice@example.com"
|
||||
msg["To"] = "bot@example.com"
|
||||
msg["Subject"] = "HTML only"
|
||||
msg.add_alternative("<p>Hello<br>world</p>", subtype="html")
|
||||
|
||||
text = EmailChannel._extract_text_body(msg)
|
||||
assert "Hello" in text
|
||||
assert "world" in text
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_returns_immediately_without_consent(monkeypatch) -> None:
|
||||
cfg = _make_config()
|
||||
cfg.consent_granted = False
|
||||
channel = EmailChannel(cfg, MessageBus())
|
||||
|
||||
called = {"fetch": False}
|
||||
|
||||
def _fake_fetch():
|
||||
called["fetch"] = True
|
||||
return []
|
||||
|
||||
monkeypatch.setattr(channel, "_fetch_new_messages", _fake_fetch)
|
||||
await channel.start()
|
||||
assert channel.is_running is False
|
||||
assert called["fetch"] is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_uses_smtp_and_reply_subject(monkeypatch) -> None:
|
||||
class FakeSMTP:
|
||||
def __init__(self, _host: str, _port: int, timeout: int = 30) -> None:
|
||||
self.timeout = timeout
|
||||
self.started_tls = False
|
||||
self.logged_in = False
|
||||
self.sent_messages: list[EmailMessage] = []
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
def starttls(self, context=None):
|
||||
self.started_tls = True
|
||||
|
||||
def login(self, _user: str, _pw: str):
|
||||
self.logged_in = True
|
||||
|
||||
def send_message(self, msg: EmailMessage):
|
||||
self.sent_messages.append(msg)
|
||||
|
||||
fake_instances: list[FakeSMTP] = []
|
||||
|
||||
def _smtp_factory(host: str, port: int, timeout: int = 30):
|
||||
instance = FakeSMTP(host, port, timeout=timeout)
|
||||
fake_instances.append(instance)
|
||||
return instance
|
||||
|
||||
monkeypatch.setattr("nanobot.channels.email.smtplib.SMTP", _smtp_factory)
|
||||
|
||||
channel = EmailChannel(_make_config(), MessageBus())
|
||||
channel._last_subject_by_chat["alice@example.com"] = "Invoice #42"
|
||||
channel._last_message_id_by_chat["alice@example.com"] = "<m1@example.com>"
|
||||
|
||||
await channel.send(
|
||||
OutboundMessage(
|
||||
channel="email",
|
||||
chat_id="alice@example.com",
|
||||
content="Acknowledged.",
|
||||
)
|
||||
)
|
||||
|
||||
assert len(fake_instances) == 1
|
||||
smtp = fake_instances[0]
|
||||
assert smtp.started_tls is True
|
||||
assert smtp.logged_in is True
|
||||
assert len(smtp.sent_messages) == 1
|
||||
sent = smtp.sent_messages[0]
|
||||
assert sent["Subject"] == "Re: Invoice #42"
|
||||
assert sent["To"] == "alice@example.com"
|
||||
assert sent["In-Reply-To"] == "<m1@example.com>"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_skips_reply_when_auto_reply_disabled(monkeypatch) -> None:
|
||||
"""When auto_reply_enabled=False, replies should be skipped but proactive sends allowed."""
|
||||
class FakeSMTP:
|
||||
def __init__(self, _host: str, _port: int, timeout: int = 30) -> None:
|
||||
self.sent_messages: list[EmailMessage] = []
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
def starttls(self, context=None):
|
||||
return None
|
||||
|
||||
def login(self, _user: str, _pw: str):
|
||||
return None
|
||||
|
||||
def send_message(self, msg: EmailMessage):
|
||||
self.sent_messages.append(msg)
|
||||
|
||||
fake_instances: list[FakeSMTP] = []
|
||||
|
||||
def _smtp_factory(host: str, port: int, timeout: int = 30):
|
||||
instance = FakeSMTP(host, port, timeout=timeout)
|
||||
fake_instances.append(instance)
|
||||
return instance
|
||||
|
||||
monkeypatch.setattr("nanobot.channels.email.smtplib.SMTP", _smtp_factory)
|
||||
|
||||
cfg = _make_config()
|
||||
cfg.auto_reply_enabled = False
|
||||
channel = EmailChannel(cfg, MessageBus())
|
||||
|
||||
# Mark alice as someone who sent us an email (making this a "reply")
|
||||
channel._last_subject_by_chat["alice@example.com"] = "Previous email"
|
||||
|
||||
# Reply should be skipped (auto_reply_enabled=False)
|
||||
await channel.send(
|
||||
OutboundMessage(
|
||||
channel="email",
|
||||
chat_id="alice@example.com",
|
||||
content="Should not send.",
|
||||
)
|
||||
)
|
||||
assert fake_instances == []
|
||||
|
||||
# Reply with force_send=True should be sent
|
||||
await channel.send(
|
||||
OutboundMessage(
|
||||
channel="email",
|
||||
chat_id="alice@example.com",
|
||||
content="Force send.",
|
||||
metadata={"force_send": True},
|
||||
)
|
||||
)
|
||||
assert len(fake_instances) == 1
|
||||
assert len(fake_instances[0].sent_messages) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_proactive_email_when_auto_reply_disabled(monkeypatch) -> None:
|
||||
"""Proactive emails (not replies) should be sent even when auto_reply_enabled=False."""
|
||||
class FakeSMTP:
|
||||
def __init__(self, _host: str, _port: int, timeout: int = 30) -> None:
|
||||
self.sent_messages: list[EmailMessage] = []
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
def starttls(self, context=None):
|
||||
return None
|
||||
|
||||
def login(self, _user: str, _pw: str):
|
||||
return None
|
||||
|
||||
def send_message(self, msg: EmailMessage):
|
||||
self.sent_messages.append(msg)
|
||||
|
||||
fake_instances: list[FakeSMTP] = []
|
||||
|
||||
def _smtp_factory(host: str, port: int, timeout: int = 30):
|
||||
instance = FakeSMTP(host, port, timeout=timeout)
|
||||
fake_instances.append(instance)
|
||||
return instance
|
||||
|
||||
monkeypatch.setattr("nanobot.channels.email.smtplib.SMTP", _smtp_factory)
|
||||
|
||||
cfg = _make_config()
|
||||
cfg.auto_reply_enabled = False
|
||||
channel = EmailChannel(cfg, MessageBus())
|
||||
|
||||
# bob@example.com has never sent us an email (proactive send)
|
||||
# This should be sent even with auto_reply_enabled=False
|
||||
await channel.send(
|
||||
OutboundMessage(
|
||||
channel="email",
|
||||
chat_id="bob@example.com",
|
||||
content="Hello, this is a proactive email.",
|
||||
)
|
||||
)
|
||||
assert len(fake_instances) == 1
|
||||
assert len(fake_instances[0].sent_messages) == 1
|
||||
sent = fake_instances[0].sent_messages[0]
|
||||
assert sent["To"] == "bob@example.com"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_skips_when_consent_not_granted(monkeypatch) -> None:
|
||||
class FakeSMTP:
|
||||
def __init__(self, _host: str, _port: int, timeout: int = 30) -> None:
|
||||
self.sent_messages: list[EmailMessage] = []
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
def starttls(self, context=None):
|
||||
return None
|
||||
|
||||
def login(self, _user: str, _pw: str):
|
||||
return None
|
||||
|
||||
def send_message(self, msg: EmailMessage):
|
||||
self.sent_messages.append(msg)
|
||||
|
||||
called = {"smtp": False}
|
||||
|
||||
def _smtp_factory(host: str, port: int, timeout: int = 30):
|
||||
called["smtp"] = True
|
||||
return FakeSMTP(host, port, timeout=timeout)
|
||||
|
||||
monkeypatch.setattr("nanobot.channels.email.smtplib.SMTP", _smtp_factory)
|
||||
|
||||
cfg = _make_config()
|
||||
cfg.consent_granted = False
|
||||
channel = EmailChannel(cfg, MessageBus())
|
||||
await channel.send(
|
||||
OutboundMessage(
|
||||
channel="email",
|
||||
chat_id="alice@example.com",
|
||||
content="Should not send.",
|
||||
metadata={"force_send": True},
|
||||
)
|
||||
)
|
||||
assert called["smtp"] is False
|
||||
|
||||
|
||||
def test_fetch_messages_between_dates_uses_imap_since_before_without_mark_seen(monkeypatch) -> None:
|
||||
raw = _make_raw_email(subject="Status", body="Yesterday update")
|
||||
|
||||
class FakeIMAP:
|
||||
def __init__(self) -> None:
|
||||
self.search_args = None
|
||||
self.store_calls: list[tuple[bytes, str, str]] = []
|
||||
|
||||
def login(self, _user: str, _pw: str):
|
||||
return "OK", [b"logged in"]
|
||||
|
||||
def select(self, _mailbox: str):
|
||||
return "OK", [b"1"]
|
||||
|
||||
def search(self, *_args):
|
||||
self.search_args = _args
|
||||
return "OK", [b"5"]
|
||||
|
||||
def fetch(self, _imap_id: bytes, _parts: str):
|
||||
return "OK", [(b"5 (UID 999 BODY[] {200})", raw), b")"]
|
||||
|
||||
def store(self, imap_id: bytes, op: str, flags: str):
|
||||
self.store_calls.append((imap_id, op, flags))
|
||||
return "OK", [b""]
|
||||
|
||||
def logout(self):
|
||||
return "BYE", [b""]
|
||||
|
||||
fake = FakeIMAP()
|
||||
monkeypatch.setattr("nanobot.channels.email.imaplib.IMAP4_SSL", lambda _h, _p: fake)
|
||||
|
||||
channel = EmailChannel(_make_config(), MessageBus())
|
||||
items = channel.fetch_messages_between_dates(
|
||||
start_date=date(2026, 2, 6),
|
||||
end_date=date(2026, 2, 7),
|
||||
limit=10,
|
||||
)
|
||||
|
||||
assert len(items) == 1
|
||||
assert items[0]["subject"] == "Status"
|
||||
# search(None, "SINCE", "06-Feb-2026", "BEFORE", "07-Feb-2026")
|
||||
assert fake.search_args is not None
|
||||
assert fake.search_args[1:] == ("SINCE", "06-Feb-2026", "BEFORE", "07-Feb-2026")
|
||||
assert fake.store_calls == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Security: Anti-spoofing tests for Authentication-Results verification
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_fake_imap(raw: bytes):
|
||||
"""Return a FakeIMAP class pre-loaded with the given raw email."""
|
||||
class FakeIMAP:
|
||||
def __init__(self) -> None:
|
||||
self.store_calls: list[tuple[bytes, str, str]] = []
|
||||
|
||||
def login(self, _user: str, _pw: str):
|
||||
return "OK", [b"logged in"]
|
||||
|
||||
def select(self, _mailbox: str):
|
||||
return "OK", [b"1"]
|
||||
|
||||
def search(self, *_args):
|
||||
return "OK", [b"1"]
|
||||
|
||||
def fetch(self, _imap_id: bytes, _parts: str):
|
||||
return "OK", [(b"1 (UID 500 BODY[] {200})", raw), b")"]
|
||||
|
||||
def store(self, imap_id: bytes, op: str, flags: str):
|
||||
self.store_calls.append((imap_id, op, flags))
|
||||
return "OK", [b""]
|
||||
|
||||
def logout(self):
|
||||
return "BYE", [b""]
|
||||
|
||||
return FakeIMAP()
|
||||
|
||||
|
||||
def test_spoofed_email_rejected_when_verify_enabled(monkeypatch) -> None:
|
||||
"""An email without Authentication-Results should be rejected when verify_dkim=True."""
|
||||
raw = _make_raw_email(subject="Spoofed", body="Malicious payload")
|
||||
fake = _make_fake_imap(raw)
|
||||
monkeypatch.setattr("nanobot.channels.email.imaplib.IMAP4_SSL", lambda _h, _p: fake)
|
||||
|
||||
cfg = _make_config(verify_dkim=True, verify_spf=True)
|
||||
channel = EmailChannel(cfg, MessageBus())
|
||||
items = channel._fetch_new_messages()
|
||||
|
||||
assert len(items) == 0, "Spoofed email without auth headers should be rejected"
|
||||
|
||||
|
||||
def test_email_with_valid_auth_results_accepted(monkeypatch) -> None:
|
||||
"""An email with spf=pass and dkim=pass should be accepted."""
|
||||
raw = _make_raw_email(
|
||||
subject="Legit",
|
||||
body="Hello from verified sender",
|
||||
auth_results="mx.example.com; spf=pass smtp.mailfrom=alice@example.com; dkim=pass header.d=example.com",
|
||||
)
|
||||
fake = _make_fake_imap(raw)
|
||||
monkeypatch.setattr("nanobot.channels.email.imaplib.IMAP4_SSL", lambda _h, _p: fake)
|
||||
|
||||
cfg = _make_config(verify_dkim=True, verify_spf=True)
|
||||
channel = EmailChannel(cfg, MessageBus())
|
||||
items = channel._fetch_new_messages()
|
||||
|
||||
assert len(items) == 1
|
||||
assert items[0]["sender"] == "alice@example.com"
|
||||
assert items[0]["subject"] == "Legit"
|
||||
|
||||
|
||||
def test_email_with_partial_auth_rejected(monkeypatch) -> None:
|
||||
"""An email with only spf=pass but no dkim=pass should be rejected when verify_dkim=True."""
|
||||
raw = _make_raw_email(
|
||||
subject="Partial",
|
||||
body="Only SPF passes",
|
||||
auth_results="mx.example.com; spf=pass smtp.mailfrom=alice@example.com; dkim=fail",
|
||||
)
|
||||
fake = _make_fake_imap(raw)
|
||||
monkeypatch.setattr("nanobot.channels.email.imaplib.IMAP4_SSL", lambda _h, _p: fake)
|
||||
|
||||
cfg = _make_config(verify_dkim=True, verify_spf=True)
|
||||
channel = EmailChannel(cfg, MessageBus())
|
||||
items = channel._fetch_new_messages()
|
||||
|
||||
assert len(items) == 0, "Email with dkim=fail should be rejected"
|
||||
|
||||
|
||||
def test_backward_compat_verify_disabled(monkeypatch) -> None:
|
||||
"""When verify_dkim=False and verify_spf=False, emails without auth headers are accepted."""
|
||||
raw = _make_raw_email(subject="NoAuth", body="No auth headers present")
|
||||
fake = _make_fake_imap(raw)
|
||||
monkeypatch.setattr("nanobot.channels.email.imaplib.IMAP4_SSL", lambda _h, _p: fake)
|
||||
|
||||
cfg = _make_config(verify_dkim=False, verify_spf=False)
|
||||
channel = EmailChannel(cfg, MessageBus())
|
||||
items = channel._fetch_new_messages()
|
||||
|
||||
assert len(items) == 1, "With verification disabled, emails should be accepted as before"
|
||||
|
||||
|
||||
def test_email_content_tagged_with_email_context(monkeypatch) -> None:
|
||||
"""Email content should be prefixed with [EMAIL-CONTEXT] for LLM isolation."""
|
||||
raw = _make_raw_email(subject="Tagged", body="Check the tag")
|
||||
fake = _make_fake_imap(raw)
|
||||
monkeypatch.setattr("nanobot.channels.email.imaplib.IMAP4_SSL", lambda _h, _p: fake)
|
||||
|
||||
cfg = _make_config(verify_dkim=False, verify_spf=False)
|
||||
channel = EmailChannel(cfg, MessageBus())
|
||||
items = channel._fetch_new_messages()
|
||||
|
||||
assert len(items) == 1
|
||||
assert items[0]["content"].startswith("[EMAIL-CONTEXT]"), (
|
||||
"Email content must be tagged with [EMAIL-CONTEXT]"
|
||||
)
|
||||
|
||||
|
||||
def test_check_authentication_results_method() -> None:
|
||||
"""Unit test for the _check_authentication_results static method."""
|
||||
from email.parser import BytesParser
|
||||
from email import policy
|
||||
|
||||
# No Authentication-Results header
|
||||
msg_no_auth = EmailMessage()
|
||||
msg_no_auth["From"] = "alice@example.com"
|
||||
msg_no_auth.set_content("test")
|
||||
parsed = BytesParser(policy=policy.default).parsebytes(msg_no_auth.as_bytes())
|
||||
spf, dkim = EmailChannel._check_authentication_results(parsed)
|
||||
assert spf is False
|
||||
assert dkim is False
|
||||
|
||||
# Both pass
|
||||
msg_both = EmailMessage()
|
||||
msg_both["From"] = "alice@example.com"
|
||||
msg_both["Authentication-Results"] = (
|
||||
"mx.google.com; spf=pass smtp.mailfrom=example.com; dkim=pass header.d=example.com"
|
||||
)
|
||||
msg_both.set_content("test")
|
||||
parsed = BytesParser(policy=policy.default).parsebytes(msg_both.as_bytes())
|
||||
spf, dkim = EmailChannel._check_authentication_results(parsed)
|
||||
assert spf is True
|
||||
assert dkim is True
|
||||
|
||||
# SPF pass, DKIM fail
|
||||
msg_spf_only = EmailMessage()
|
||||
msg_spf_only["From"] = "alice@example.com"
|
||||
msg_spf_only["Authentication-Results"] = (
|
||||
"mx.google.com; spf=pass smtp.mailfrom=example.com; dkim=fail"
|
||||
)
|
||||
msg_spf_only.set_content("test")
|
||||
parsed = BytesParser(policy=policy.default).parsebytes(msg_spf_only.as_bytes())
|
||||
spf, dkim = EmailChannel._check_authentication_results(parsed)
|
||||
assert spf is True
|
||||
assert dkim is False
|
||||
|
||||
# DKIM pass, SPF fail
|
||||
msg_dkim_only = EmailMessage()
|
||||
msg_dkim_only["From"] = "alice@example.com"
|
||||
msg_dkim_only["Authentication-Results"] = (
|
||||
"mx.google.com; spf=fail smtp.mailfrom=example.com; dkim=pass header.d=example.com"
|
||||
)
|
||||
msg_dkim_only.set_content("test")
|
||||
parsed = BytesParser(policy=policy.default).parsebytes(msg_dkim_only.as_bytes())
|
||||
spf, dkim = EmailChannel._check_authentication_results(parsed)
|
||||
assert spf is False
|
||||
assert dkim is True
|
||||
@@ -0,0 +1,68 @@
|
||||
# Check optional Feishu dependencies before running tests
|
||||
try:
|
||||
from nanobot.channels import feishu
|
||||
FEISHU_AVAILABLE = getattr(feishu, "FEISHU_AVAILABLE", False)
|
||||
except ImportError:
|
||||
FEISHU_AVAILABLE = False
|
||||
|
||||
if not FEISHU_AVAILABLE:
|
||||
import pytest
|
||||
pytest.skip("Feishu dependencies not installed (lark-oapi)", allow_module_level=True)
|
||||
|
||||
from nanobot.channels.feishu import FeishuChannel
|
||||
|
||||
|
||||
def test_parse_md_table_strips_markdown_formatting_in_headers_and_cells() -> None:
|
||||
table = FeishuChannel._parse_md_table(
|
||||
"""
|
||||
| **Name** | __Status__ | *Notes* | ~~State~~ |
|
||||
| --- | --- | --- | --- |
|
||||
| **Alice** | __Ready__ | *Fast* | ~~Old~~ |
|
||||
"""
|
||||
)
|
||||
|
||||
assert table is not None
|
||||
assert [col["display_name"] for col in table["columns"]] == [
|
||||
"Name",
|
||||
"Status",
|
||||
"Notes",
|
||||
"State",
|
||||
]
|
||||
assert table["rows"] == [
|
||||
{"c0": "Alice", "c1": "Ready", "c2": "Fast", "c3": "Old"}
|
||||
]
|
||||
|
||||
|
||||
def test_split_headings_strips_embedded_markdown_before_bolding() -> None:
|
||||
channel = FeishuChannel.__new__(FeishuChannel)
|
||||
|
||||
elements = channel._split_headings("# **Important** *status* ~~update~~")
|
||||
|
||||
assert elements == [
|
||||
{
|
||||
"tag": "div",
|
||||
"text": {
|
||||
"tag": "lark_md",
|
||||
"content": "**Important status update**",
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
def test_split_headings_keeps_markdown_body_and_code_blocks_intact() -> None:
|
||||
channel = FeishuChannel.__new__(FeishuChannel)
|
||||
|
||||
elements = channel._split_headings(
|
||||
"# **Heading**\n\nBody with **bold** text.\n\n```python\nprint('hi')\n```"
|
||||
)
|
||||
|
||||
assert elements[0] == {
|
||||
"tag": "div",
|
||||
"text": {
|
||||
"tag": "lark_md",
|
||||
"content": "**Heading**",
|
||||
},
|
||||
}
|
||||
assert elements[1]["tag"] == "markdown"
|
||||
assert "Body with **bold** text." in elements[1]["content"]
|
||||
assert "```python\nprint('hi')\n```" in elements[1]["content"]
|
||||
@@ -0,0 +1,76 @@
|
||||
# Check optional Feishu dependencies before running tests
|
||||
try:
|
||||
from nanobot.channels import feishu
|
||||
FEISHU_AVAILABLE = getattr(feishu, "FEISHU_AVAILABLE", False)
|
||||
except ImportError:
|
||||
FEISHU_AVAILABLE = False
|
||||
|
||||
if not FEISHU_AVAILABLE:
|
||||
import pytest
|
||||
pytest.skip("Feishu dependencies not installed (lark-oapi)", allow_module_level=True)
|
||||
|
||||
from nanobot.channels.feishu import FeishuChannel, _extract_post_content
|
||||
|
||||
|
||||
def test_extract_post_content_supports_post_wrapper_shape() -> None:
|
||||
payload = {
|
||||
"post": {
|
||||
"zh_cn": {
|
||||
"title": "日报",
|
||||
"content": [
|
||||
[
|
||||
{"tag": "text", "text": "完成"},
|
||||
{"tag": "img", "image_key": "img_1"},
|
||||
]
|
||||
],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
text, image_keys = _extract_post_content(payload)
|
||||
|
||||
assert text == "日报 完成"
|
||||
assert image_keys == ["img_1"]
|
||||
|
||||
|
||||
def test_extract_post_content_keeps_direct_shape_behavior() -> None:
|
||||
payload = {
|
||||
"title": "Daily",
|
||||
"content": [
|
||||
[
|
||||
{"tag": "text", "text": "report"},
|
||||
{"tag": "img", "image_key": "img_a"},
|
||||
{"tag": "img", "image_key": "img_b"},
|
||||
]
|
||||
],
|
||||
}
|
||||
|
||||
text, image_keys = _extract_post_content(payload)
|
||||
|
||||
assert text == "Daily report"
|
||||
assert image_keys == ["img_a", "img_b"]
|
||||
|
||||
|
||||
def test_register_optional_event_keeps_builder_when_method_missing() -> None:
|
||||
class Builder:
|
||||
pass
|
||||
|
||||
builder = Builder()
|
||||
same = FeishuChannel._register_optional_event(builder, "missing", object())
|
||||
assert same is builder
|
||||
|
||||
|
||||
def test_register_optional_event_calls_supported_method() -> None:
|
||||
called = []
|
||||
|
||||
class Builder:
|
||||
def register_event(self, handler):
|
||||
called.append(handler)
|
||||
return self
|
||||
|
||||
builder = Builder()
|
||||
handler = object()
|
||||
same = FeishuChannel._register_optional_event(builder, "register_event", handler)
|
||||
|
||||
assert same is builder
|
||||
assert called == [handler]
|
||||
@@ -0,0 +1,445 @@
|
||||
"""Tests for Feishu message reply (quote) feature."""
|
||||
import asyncio
|
||||
import json
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
# Check optional Feishu dependencies before running tests
|
||||
try:
|
||||
from nanobot.channels import feishu
|
||||
FEISHU_AVAILABLE = getattr(feishu, "FEISHU_AVAILABLE", False)
|
||||
except ImportError:
|
||||
FEISHU_AVAILABLE = False
|
||||
|
||||
if not FEISHU_AVAILABLE:
|
||||
pytest.skip("Feishu dependencies not installed (lark-oapi)", allow_module_level=True)
|
||||
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.channels.feishu import FeishuChannel, FeishuConfig
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_feishu_channel(reply_to_message: bool = False) -> FeishuChannel:
|
||||
config = FeishuConfig(
|
||||
enabled=True,
|
||||
app_id="cli_test",
|
||||
app_secret="secret",
|
||||
allow_from=["*"],
|
||||
reply_to_message=reply_to_message,
|
||||
)
|
||||
channel = FeishuChannel(config, MessageBus())
|
||||
channel._client = MagicMock()
|
||||
# _loop is only used by the WebSocket thread bridge; not needed for unit tests
|
||||
channel._loop = None
|
||||
return channel
|
||||
|
||||
|
||||
def _make_feishu_event(
|
||||
*,
|
||||
message_id: str = "om_001",
|
||||
chat_id: str = "oc_abc",
|
||||
chat_type: str = "p2p",
|
||||
msg_type: str = "text",
|
||||
content: str = '{"text": "hello"}',
|
||||
sender_open_id: str = "ou_alice",
|
||||
parent_id: str | None = None,
|
||||
root_id: str | None = None,
|
||||
):
|
||||
message = SimpleNamespace(
|
||||
message_id=message_id,
|
||||
chat_id=chat_id,
|
||||
chat_type=chat_type,
|
||||
message_type=msg_type,
|
||||
content=content,
|
||||
parent_id=parent_id,
|
||||
root_id=root_id,
|
||||
mentions=[],
|
||||
)
|
||||
sender = SimpleNamespace(
|
||||
sender_type="user",
|
||||
sender_id=SimpleNamespace(open_id=sender_open_id),
|
||||
)
|
||||
return SimpleNamespace(event=SimpleNamespace(message=message, sender=sender))
|
||||
|
||||
|
||||
def _make_get_message_response(text: str, msg_type: str = "text", success: bool = True):
|
||||
"""Build a fake im.v1.message.get response object."""
|
||||
body = SimpleNamespace(content=json.dumps({"text": text}))
|
||||
item = SimpleNamespace(msg_type=msg_type, body=body)
|
||||
data = SimpleNamespace(items=[item])
|
||||
resp = MagicMock()
|
||||
resp.success.return_value = success
|
||||
resp.data = data
|
||||
resp.code = 0
|
||||
resp.msg = "ok"
|
||||
return resp
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Config tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_feishu_config_reply_to_message_defaults_false() -> None:
|
||||
assert FeishuConfig().reply_to_message is False
|
||||
|
||||
|
||||
def test_feishu_config_reply_to_message_can_be_enabled() -> None:
|
||||
config = FeishuConfig(reply_to_message=True)
|
||||
assert config.reply_to_message is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _get_message_content_sync tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_get_message_content_sync_returns_reply_prefix() -> None:
|
||||
channel = _make_feishu_channel()
|
||||
channel._client.im.v1.message.get.return_value = _make_get_message_response("what time is it?")
|
||||
|
||||
result = channel._get_message_content_sync("om_parent")
|
||||
|
||||
assert result == "[Reply to: what time is it?]"
|
||||
|
||||
|
||||
def test_get_message_content_sync_truncates_long_text() -> None:
|
||||
channel = _make_feishu_channel()
|
||||
long_text = "x" * (FeishuChannel._REPLY_CONTEXT_MAX_LEN + 50)
|
||||
channel._client.im.v1.message.get.return_value = _make_get_message_response(long_text)
|
||||
|
||||
result = channel._get_message_content_sync("om_parent")
|
||||
|
||||
assert result is not None
|
||||
assert result.endswith("...]")
|
||||
inner = result[len("[Reply to: ") : -1]
|
||||
assert len(inner) == FeishuChannel._REPLY_CONTEXT_MAX_LEN + len("...")
|
||||
|
||||
|
||||
def test_get_message_content_sync_returns_none_on_api_failure() -> None:
|
||||
channel = _make_feishu_channel()
|
||||
resp = MagicMock()
|
||||
resp.success.return_value = False
|
||||
resp.code = 230002
|
||||
resp.msg = "bot not in group"
|
||||
channel._client.im.v1.message.get.return_value = resp
|
||||
|
||||
result = channel._get_message_content_sync("om_parent")
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_get_message_content_sync_returns_none_for_non_text_type() -> None:
|
||||
channel = _make_feishu_channel()
|
||||
body = SimpleNamespace(content=json.dumps({"image_key": "img_1"}))
|
||||
item = SimpleNamespace(msg_type="image", body=body)
|
||||
data = SimpleNamespace(items=[item])
|
||||
resp = MagicMock()
|
||||
resp.success.return_value = True
|
||||
resp.data = data
|
||||
channel._client.im.v1.message.get.return_value = resp
|
||||
|
||||
result = channel._get_message_content_sync("om_parent")
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_get_message_content_sync_returns_none_when_empty_text() -> None:
|
||||
channel = _make_feishu_channel()
|
||||
channel._client.im.v1.message.get.return_value = _make_get_message_response(" ")
|
||||
|
||||
result = channel._get_message_content_sync("om_parent")
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _reply_message_sync tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_reply_message_sync_returns_true_on_success() -> None:
|
||||
channel = _make_feishu_channel()
|
||||
resp = MagicMock()
|
||||
resp.success.return_value = True
|
||||
channel._client.im.v1.message.reply.return_value = resp
|
||||
|
||||
ok = channel._reply_message_sync("om_parent", "text", '{"text":"hi"}')
|
||||
|
||||
assert ok is True
|
||||
channel._client.im.v1.message.reply.assert_called_once()
|
||||
|
||||
|
||||
def test_reply_message_sync_returns_false_on_api_error() -> None:
|
||||
channel = _make_feishu_channel()
|
||||
resp = MagicMock()
|
||||
resp.success.return_value = False
|
||||
resp.code = 400
|
||||
resp.msg = "bad request"
|
||||
resp.get_log_id.return_value = "log_x"
|
||||
channel._client.im.v1.message.reply.return_value = resp
|
||||
|
||||
ok = channel._reply_message_sync("om_parent", "text", '{"text":"hi"}')
|
||||
|
||||
assert ok is False
|
||||
|
||||
|
||||
def test_reply_message_sync_returns_false_on_exception() -> None:
|
||||
channel = _make_feishu_channel()
|
||||
channel._client.im.v1.message.reply.side_effect = RuntimeError("network error")
|
||||
|
||||
ok = channel._reply_message_sync("om_parent", "text", '{"text":"hi"}')
|
||||
|
||||
assert ok is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
("filename", "expected_msg_type"),
|
||||
[
|
||||
("voice.opus", "audio"),
|
||||
("clip.mp4", "video"),
|
||||
("report.pdf", "file"),
|
||||
],
|
||||
)
|
||||
async def test_send_uses_expected_feishu_msg_type_for_uploaded_files(
|
||||
tmp_path: Path, filename: str, expected_msg_type: str
|
||||
) -> None:
|
||||
channel = _make_feishu_channel()
|
||||
file_path = tmp_path / filename
|
||||
file_path.write_bytes(b"demo")
|
||||
|
||||
send_calls: list[tuple[str, str, str, str]] = []
|
||||
|
||||
def _record_send(receive_id_type: str, receive_id: str, msg_type: str, content: str) -> None:
|
||||
send_calls.append((receive_id_type, receive_id, msg_type, content))
|
||||
|
||||
with patch.object(channel, "_upload_file_sync", return_value="file-key"), patch.object(
|
||||
channel, "_send_message_sync", side_effect=_record_send
|
||||
):
|
||||
await channel.send(
|
||||
OutboundMessage(
|
||||
channel="feishu",
|
||||
chat_id="oc_test",
|
||||
content="",
|
||||
media=[str(file_path)],
|
||||
metadata={},
|
||||
)
|
||||
)
|
||||
|
||||
assert len(send_calls) == 1
|
||||
receive_id_type, receive_id, msg_type, content = send_calls[0]
|
||||
assert receive_id_type == "chat_id"
|
||||
assert receive_id == "oc_test"
|
||||
assert msg_type == expected_msg_type
|
||||
assert json.loads(content) == {"file_key": "file-key"}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# send() — reply routing tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_uses_reply_api_when_configured() -> None:
|
||||
channel = _make_feishu_channel(reply_to_message=True)
|
||||
|
||||
reply_resp = MagicMock()
|
||||
reply_resp.success.return_value = True
|
||||
channel._client.im.v1.message.reply.return_value = reply_resp
|
||||
|
||||
await channel.send(OutboundMessage(
|
||||
channel="feishu",
|
||||
chat_id="oc_abc",
|
||||
content="hello",
|
||||
metadata={"message_id": "om_001"},
|
||||
))
|
||||
|
||||
channel._client.im.v1.message.reply.assert_called_once()
|
||||
channel._client.im.v1.message.create.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_uses_create_api_when_reply_disabled() -> None:
|
||||
channel = _make_feishu_channel(reply_to_message=False)
|
||||
|
||||
create_resp = MagicMock()
|
||||
create_resp.success.return_value = True
|
||||
channel._client.im.v1.message.create.return_value = create_resp
|
||||
|
||||
await channel.send(OutboundMessage(
|
||||
channel="feishu",
|
||||
chat_id="oc_abc",
|
||||
content="hello",
|
||||
metadata={"message_id": "om_001"},
|
||||
))
|
||||
|
||||
channel._client.im.v1.message.create.assert_called_once()
|
||||
channel._client.im.v1.message.reply.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_uses_create_api_when_no_message_id() -> None:
|
||||
channel = _make_feishu_channel(reply_to_message=True)
|
||||
|
||||
create_resp = MagicMock()
|
||||
create_resp.success.return_value = True
|
||||
channel._client.im.v1.message.create.return_value = create_resp
|
||||
|
||||
await channel.send(OutboundMessage(
|
||||
channel="feishu",
|
||||
chat_id="oc_abc",
|
||||
content="hello",
|
||||
metadata={},
|
||||
))
|
||||
|
||||
channel._client.im.v1.message.create.assert_called_once()
|
||||
channel._client.im.v1.message.reply.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_skips_reply_for_progress_messages() -> None:
|
||||
channel = _make_feishu_channel(reply_to_message=True)
|
||||
|
||||
create_resp = MagicMock()
|
||||
create_resp.success.return_value = True
|
||||
channel._client.im.v1.message.create.return_value = create_resp
|
||||
|
||||
await channel.send(OutboundMessage(
|
||||
channel="feishu",
|
||||
chat_id="oc_abc",
|
||||
content="thinking...",
|
||||
metadata={"message_id": "om_001", "_progress": True},
|
||||
))
|
||||
|
||||
channel._client.im.v1.message.create.assert_called_once()
|
||||
channel._client.im.v1.message.reply.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_fallback_to_create_when_reply_fails() -> None:
|
||||
channel = _make_feishu_channel(reply_to_message=True)
|
||||
|
||||
reply_resp = MagicMock()
|
||||
reply_resp.success.return_value = False
|
||||
reply_resp.code = 400
|
||||
reply_resp.msg = "error"
|
||||
reply_resp.get_log_id.return_value = "log_x"
|
||||
channel._client.im.v1.message.reply.return_value = reply_resp
|
||||
|
||||
create_resp = MagicMock()
|
||||
create_resp.success.return_value = True
|
||||
channel._client.im.v1.message.create.return_value = create_resp
|
||||
|
||||
await channel.send(OutboundMessage(
|
||||
channel="feishu",
|
||||
chat_id="oc_abc",
|
||||
content="hello",
|
||||
metadata={"message_id": "om_001"},
|
||||
))
|
||||
|
||||
# reply attempted first, then falls back to create
|
||||
channel._client.im.v1.message.reply.assert_called_once()
|
||||
channel._client.im.v1.message.create.assert_called_once()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _on_message — parent_id / root_id metadata tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_message_captures_parent_and_root_id_in_metadata() -> None:
|
||||
channel = _make_feishu_channel()
|
||||
channel._processed_message_ids.clear()
|
||||
channel._client.im.v1.message.react.return_value = MagicMock(success=lambda: True)
|
||||
|
||||
captured = []
|
||||
|
||||
async def _capture(**kwargs):
|
||||
captured.append(kwargs)
|
||||
|
||||
channel._handle_message = _capture
|
||||
|
||||
with patch.object(channel, "_add_reaction", return_value=None):
|
||||
await channel._on_message(
|
||||
_make_feishu_event(
|
||||
parent_id="om_parent",
|
||||
root_id="om_root",
|
||||
)
|
||||
)
|
||||
|
||||
assert len(captured) == 1
|
||||
meta = captured[0]["metadata"]
|
||||
assert meta["parent_id"] == "om_parent"
|
||||
assert meta["root_id"] == "om_root"
|
||||
assert meta["message_id"] == "om_001"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_message_parent_and_root_id_none_when_absent() -> None:
|
||||
channel = _make_feishu_channel()
|
||||
channel._processed_message_ids.clear()
|
||||
|
||||
captured = []
|
||||
|
||||
async def _capture(**kwargs):
|
||||
captured.append(kwargs)
|
||||
|
||||
channel._handle_message = _capture
|
||||
|
||||
with patch.object(channel, "_add_reaction", return_value=None):
|
||||
await channel._on_message(_make_feishu_event())
|
||||
|
||||
assert len(captured) == 1
|
||||
meta = captured[0]["metadata"]
|
||||
assert meta["parent_id"] is None
|
||||
assert meta["root_id"] is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_message_prepends_reply_context_when_parent_id_present() -> None:
|
||||
channel = _make_feishu_channel()
|
||||
channel._processed_message_ids.clear()
|
||||
channel._client.im.v1.message.get.return_value = _make_get_message_response("original question")
|
||||
|
||||
captured = []
|
||||
|
||||
async def _capture(**kwargs):
|
||||
captured.append(kwargs)
|
||||
|
||||
channel._handle_message = _capture
|
||||
|
||||
with patch.object(channel, "_add_reaction", return_value=None):
|
||||
await channel._on_message(
|
||||
_make_feishu_event(
|
||||
content='{"text": "my answer"}',
|
||||
parent_id="om_parent",
|
||||
)
|
||||
)
|
||||
|
||||
assert len(captured) == 1
|
||||
content = captured[0]["content"]
|
||||
assert content.startswith("[Reply to: original question]")
|
||||
assert "my answer" in content
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_message_no_extra_api_call_when_no_parent_id() -> None:
|
||||
channel = _make_feishu_channel()
|
||||
channel._processed_message_ids.clear()
|
||||
|
||||
captured = []
|
||||
|
||||
async def _capture(**kwargs):
|
||||
captured.append(kwargs)
|
||||
|
||||
channel._handle_message = _capture
|
||||
|
||||
with patch.object(channel, "_add_reaction", return_value=None):
|
||||
await channel._on_message(_make_feishu_event())
|
||||
|
||||
channel._client.im.v1.message.get.assert_not_called()
|
||||
assert len(captured) == 1
|
||||
@@ -0,0 +1,258 @@
|
||||
"""Tests for Feishu streaming (send_delta) via CardKit streaming API."""
|
||||
import time
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.channels.feishu import FeishuChannel, FeishuConfig, _FeishuStreamBuf
|
||||
|
||||
|
||||
def _make_channel(streaming: bool = True) -> FeishuChannel:
|
||||
config = FeishuConfig(
|
||||
enabled=True,
|
||||
app_id="cli_test",
|
||||
app_secret="secret",
|
||||
allow_from=["*"],
|
||||
streaming=streaming,
|
||||
)
|
||||
ch = FeishuChannel(config, MessageBus())
|
||||
ch._client = MagicMock()
|
||||
ch._loop = None
|
||||
return ch
|
||||
|
||||
|
||||
def _mock_create_card_response(card_id: str = "card_stream_001"):
|
||||
resp = MagicMock()
|
||||
resp.success.return_value = True
|
||||
resp.data = SimpleNamespace(card_id=card_id)
|
||||
return resp
|
||||
|
||||
|
||||
def _mock_send_response(message_id: str = "om_stream_001"):
|
||||
resp = MagicMock()
|
||||
resp.success.return_value = True
|
||||
resp.data = SimpleNamespace(message_id=message_id)
|
||||
return resp
|
||||
|
||||
|
||||
def _mock_content_response(success: bool = True):
|
||||
resp = MagicMock()
|
||||
resp.success.return_value = success
|
||||
resp.code = 0 if success else 99999
|
||||
resp.msg = "ok" if success else "error"
|
||||
return resp
|
||||
|
||||
|
||||
class TestFeishuStreamingConfig:
|
||||
def test_streaming_default_true(self):
|
||||
assert FeishuConfig().streaming is True
|
||||
|
||||
def test_supports_streaming_when_enabled(self):
|
||||
ch = _make_channel(streaming=True)
|
||||
assert ch.supports_streaming is True
|
||||
|
||||
def test_supports_streaming_disabled(self):
|
||||
ch = _make_channel(streaming=False)
|
||||
assert ch.supports_streaming is False
|
||||
|
||||
|
||||
class TestCreateStreamingCard:
|
||||
def test_returns_card_id_on_success(self):
|
||||
ch = _make_channel()
|
||||
ch._client.cardkit.v1.card.create.return_value = _mock_create_card_response("card_123")
|
||||
ch._client.im.v1.message.create.return_value = _mock_send_response()
|
||||
result = ch._create_streaming_card_sync("chat_id", "oc_chat1")
|
||||
assert result == "card_123"
|
||||
ch._client.cardkit.v1.card.create.assert_called_once()
|
||||
ch._client.im.v1.message.create.assert_called_once()
|
||||
|
||||
def test_returns_none_on_failure(self):
|
||||
ch = _make_channel()
|
||||
resp = MagicMock()
|
||||
resp.success.return_value = False
|
||||
resp.code = 99999
|
||||
resp.msg = "error"
|
||||
ch._client.cardkit.v1.card.create.return_value = resp
|
||||
assert ch._create_streaming_card_sync("chat_id", "oc_chat1") is None
|
||||
|
||||
def test_returns_none_on_exception(self):
|
||||
ch = _make_channel()
|
||||
ch._client.cardkit.v1.card.create.side_effect = RuntimeError("network")
|
||||
assert ch._create_streaming_card_sync("chat_id", "oc_chat1") is None
|
||||
|
||||
def test_returns_none_when_card_send_fails(self):
|
||||
ch = _make_channel()
|
||||
ch._client.cardkit.v1.card.create.return_value = _mock_create_card_response("card_123")
|
||||
resp = MagicMock()
|
||||
resp.success.return_value = False
|
||||
resp.code = 99999
|
||||
resp.msg = "error"
|
||||
resp.get_log_id.return_value = "log1"
|
||||
ch._client.im.v1.message.create.return_value = resp
|
||||
assert ch._create_streaming_card_sync("chat_id", "oc_chat1") is None
|
||||
|
||||
|
||||
class TestCloseStreamingMode:
|
||||
def test_returns_true_on_success(self):
|
||||
ch = _make_channel()
|
||||
ch._client.cardkit.v1.card.settings.return_value = _mock_content_response(True)
|
||||
assert ch._close_streaming_mode_sync("card_1", 10) is True
|
||||
|
||||
def test_returns_false_on_failure(self):
|
||||
ch = _make_channel()
|
||||
ch._client.cardkit.v1.card.settings.return_value = _mock_content_response(False)
|
||||
assert ch._close_streaming_mode_sync("card_1", 10) is False
|
||||
|
||||
def test_returns_false_on_exception(self):
|
||||
ch = _make_channel()
|
||||
ch._client.cardkit.v1.card.settings.side_effect = RuntimeError("err")
|
||||
assert ch._close_streaming_mode_sync("card_1", 10) is False
|
||||
|
||||
|
||||
class TestStreamUpdateText:
|
||||
def test_returns_true_on_success(self):
|
||||
ch = _make_channel()
|
||||
ch._client.cardkit.v1.card_element.content.return_value = _mock_content_response(True)
|
||||
assert ch._stream_update_text_sync("card_1", "hello", 1) is True
|
||||
|
||||
def test_returns_false_on_failure(self):
|
||||
ch = _make_channel()
|
||||
ch._client.cardkit.v1.card_element.content.return_value = _mock_content_response(False)
|
||||
assert ch._stream_update_text_sync("card_1", "hello", 1) is False
|
||||
|
||||
def test_returns_false_on_exception(self):
|
||||
ch = _make_channel()
|
||||
ch._client.cardkit.v1.card_element.content.side_effect = RuntimeError("err")
|
||||
assert ch._stream_update_text_sync("card_1", "hello", 1) is False
|
||||
|
||||
|
||||
class TestSendDelta:
|
||||
@pytest.mark.asyncio
|
||||
async def test_first_delta_creates_card_and_sends(self):
|
||||
ch = _make_channel()
|
||||
ch._client.cardkit.v1.card.create.return_value = _mock_create_card_response("card_new")
|
||||
ch._client.im.v1.message.create.return_value = _mock_send_response("om_new")
|
||||
ch._client.cardkit.v1.card_element.content.return_value = _mock_content_response()
|
||||
|
||||
await ch.send_delta("oc_chat1", "Hello ")
|
||||
|
||||
assert "oc_chat1" in ch._stream_bufs
|
||||
buf = ch._stream_bufs["oc_chat1"]
|
||||
assert buf.text == "Hello "
|
||||
assert buf.card_id == "card_new"
|
||||
assert buf.sequence == 1
|
||||
ch._client.cardkit.v1.card.create.assert_called_once()
|
||||
ch._client.im.v1.message.create.assert_called_once()
|
||||
ch._client.cardkit.v1.card_element.content.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_second_delta_within_interval_skips_update(self):
|
||||
ch = _make_channel()
|
||||
buf = _FeishuStreamBuf(text="Hello ", card_id="card_1", sequence=1, last_edit=time.monotonic())
|
||||
ch._stream_bufs["oc_chat1"] = buf
|
||||
|
||||
await ch.send_delta("oc_chat1", "world")
|
||||
|
||||
assert buf.text == "Hello world"
|
||||
ch._client.cardkit.v1.card_element.content.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delta_after_interval_updates_text(self):
|
||||
ch = _make_channel()
|
||||
buf = _FeishuStreamBuf(text="Hello ", card_id="card_1", sequence=1, last_edit=time.monotonic() - 1.0)
|
||||
ch._stream_bufs["oc_chat1"] = buf
|
||||
|
||||
ch._client.cardkit.v1.card_element.content.return_value = _mock_content_response()
|
||||
await ch.send_delta("oc_chat1", "world")
|
||||
|
||||
assert buf.text == "Hello world"
|
||||
assert buf.sequence == 2
|
||||
ch._client.cardkit.v1.card_element.content.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_end_sends_final_update(self):
|
||||
ch = _make_channel()
|
||||
ch._stream_bufs["oc_chat1"] = _FeishuStreamBuf(
|
||||
text="Final content", card_id="card_1", sequence=3, last_edit=0.0,
|
||||
)
|
||||
ch._client.cardkit.v1.card_element.content.return_value = _mock_content_response()
|
||||
ch._client.cardkit.v1.card.settings.return_value = _mock_content_response()
|
||||
|
||||
await ch.send_delta("oc_chat1", "", metadata={"_stream_end": True})
|
||||
|
||||
assert "oc_chat1" not in ch._stream_bufs
|
||||
ch._client.cardkit.v1.card_element.content.assert_called_once()
|
||||
ch._client.cardkit.v1.card.settings.assert_called_once()
|
||||
settings_call = ch._client.cardkit.v1.card.settings.call_args[0][0]
|
||||
assert settings_call.body.sequence == 5 # after final content seq 4
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_end_fallback_when_no_card_id(self):
|
||||
"""If card creation failed, stream_end falls back to a plain card message."""
|
||||
ch = _make_channel()
|
||||
ch._stream_bufs["oc_chat1"] = _FeishuStreamBuf(
|
||||
text="Fallback content", card_id=None, sequence=0, last_edit=0.0,
|
||||
)
|
||||
ch._client.im.v1.message.create.return_value = _mock_send_response("om_fb")
|
||||
|
||||
await ch.send_delta("oc_chat1", "", metadata={"_stream_end": True})
|
||||
|
||||
assert "oc_chat1" not in ch._stream_bufs
|
||||
ch._client.cardkit.v1.card_element.content.assert_not_called()
|
||||
ch._client.im.v1.message.create.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_end_without_buf_is_noop(self):
|
||||
ch = _make_channel()
|
||||
await ch.send_delta("oc_chat1", "", metadata={"_stream_end": True})
|
||||
ch._client.cardkit.v1.card_element.content.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_delta_skips_send(self):
|
||||
ch = _make_channel()
|
||||
await ch.send_delta("oc_chat1", " ")
|
||||
|
||||
assert "oc_chat1" in ch._stream_bufs
|
||||
ch._client.cardkit.v1.card.create.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_client_returns_early(self):
|
||||
ch = _make_channel()
|
||||
ch._client = None
|
||||
await ch.send_delta("oc_chat1", "text")
|
||||
assert "oc_chat1" not in ch._stream_bufs
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sequence_increments_correctly(self):
|
||||
ch = _make_channel()
|
||||
buf = _FeishuStreamBuf(text="a", card_id="card_1", sequence=5, last_edit=0.0)
|
||||
ch._stream_bufs["oc_chat1"] = buf
|
||||
|
||||
ch._client.cardkit.v1.card_element.content.return_value = _mock_content_response()
|
||||
await ch.send_delta("oc_chat1", "b")
|
||||
assert buf.sequence == 6
|
||||
|
||||
buf.last_edit = 0.0 # reset to bypass throttle
|
||||
await ch.send_delta("oc_chat1", "c")
|
||||
assert buf.sequence == 7
|
||||
|
||||
|
||||
class TestSendMessageReturnsId:
|
||||
def test_returns_message_id_on_success(self):
|
||||
ch = _make_channel()
|
||||
ch._client.im.v1.message.create.return_value = _mock_send_response("om_abc")
|
||||
result = ch._send_message_sync("chat_id", "oc_chat1", "text", '{"text":"hi"}')
|
||||
assert result == "om_abc"
|
||||
|
||||
def test_returns_none_on_failure(self):
|
||||
ch = _make_channel()
|
||||
resp = MagicMock()
|
||||
resp.success.return_value = False
|
||||
resp.code = 99999
|
||||
resp.msg = "error"
|
||||
resp.get_log_id.return_value = "log1"
|
||||
ch._client.im.v1.message.create.return_value = resp
|
||||
result = ch._send_message_sync("chat_id", "oc_chat1", "text", '{"text":"hi"}')
|
||||
assert result is None
|
||||
@@ -0,0 +1,115 @@
|
||||
"""Tests for FeishuChannel._split_elements_by_table_limit.
|
||||
|
||||
Feishu cards reject messages that contain more than one table element
|
||||
(API error 11310: card table number over limit). The helper splits a flat
|
||||
list of card elements into groups so that each group contains at most one
|
||||
table, allowing nanobot to send multiple cards instead of failing.
|
||||
"""
|
||||
|
||||
# Check optional Feishu dependencies before running tests
|
||||
try:
|
||||
from nanobot.channels import feishu
|
||||
FEISHU_AVAILABLE = getattr(feishu, "FEISHU_AVAILABLE", False)
|
||||
except ImportError:
|
||||
FEISHU_AVAILABLE = False
|
||||
|
||||
if not FEISHU_AVAILABLE:
|
||||
import pytest
|
||||
pytest.skip("Feishu dependencies not installed (lark-oapi)", allow_module_level=True)
|
||||
|
||||
from nanobot.channels.feishu import FeishuChannel
|
||||
|
||||
|
||||
def _md(text: str) -> dict:
|
||||
return {"tag": "markdown", "content": text}
|
||||
|
||||
|
||||
def _table() -> dict:
|
||||
return {
|
||||
"tag": "table",
|
||||
"columns": [{"tag": "column", "name": "c0", "display_name": "A", "width": "auto"}],
|
||||
"rows": [{"c0": "v"}],
|
||||
"page_size": 2,
|
||||
}
|
||||
|
||||
|
||||
split = FeishuChannel._split_elements_by_table_limit
|
||||
|
||||
|
||||
def test_empty_list_returns_single_empty_group() -> None:
|
||||
assert split([]) == [[]]
|
||||
|
||||
|
||||
def test_no_tables_returns_single_group() -> None:
|
||||
els = [_md("hello"), _md("world")]
|
||||
result = split(els)
|
||||
assert result == [els]
|
||||
|
||||
|
||||
def test_single_table_stays_in_one_group() -> None:
|
||||
els = [_md("intro"), _table(), _md("outro")]
|
||||
result = split(els)
|
||||
assert len(result) == 1
|
||||
assert result[0] == els
|
||||
|
||||
|
||||
def test_two_tables_split_into_two_groups() -> None:
|
||||
# Use different row values so the two tables are not equal
|
||||
t1 = {
|
||||
"tag": "table",
|
||||
"columns": [{"tag": "column", "name": "c0", "display_name": "A", "width": "auto"}],
|
||||
"rows": [{"c0": "table-one"}],
|
||||
"page_size": 2,
|
||||
}
|
||||
t2 = {
|
||||
"tag": "table",
|
||||
"columns": [{"tag": "column", "name": "c0", "display_name": "B", "width": "auto"}],
|
||||
"rows": [{"c0": "table-two"}],
|
||||
"page_size": 2,
|
||||
}
|
||||
els = [_md("before"), t1, _md("between"), t2, _md("after")]
|
||||
result = split(els)
|
||||
assert len(result) == 2
|
||||
# First group: text before table-1 + table-1
|
||||
assert t1 in result[0]
|
||||
assert t2 not in result[0]
|
||||
# Second group: text between tables + table-2 + text after
|
||||
assert t2 in result[1]
|
||||
assert t1 not in result[1]
|
||||
|
||||
|
||||
def test_three_tables_split_into_three_groups() -> None:
|
||||
tables = [
|
||||
{"tag": "table", "columns": [], "rows": [{"c0": f"t{i}"}], "page_size": 1}
|
||||
for i in range(3)
|
||||
]
|
||||
els = tables[:]
|
||||
result = split(els)
|
||||
assert len(result) == 3
|
||||
for i, group in enumerate(result):
|
||||
assert tables[i] in group
|
||||
|
||||
|
||||
def test_leading_markdown_stays_with_first_table() -> None:
|
||||
intro = _md("intro")
|
||||
t = _table()
|
||||
result = split([intro, t])
|
||||
assert len(result) == 1
|
||||
assert result[0] == [intro, t]
|
||||
|
||||
|
||||
def test_trailing_markdown_after_second_table() -> None:
|
||||
t1, t2 = _table(), _table()
|
||||
tail = _md("end")
|
||||
result = split([t1, t2, tail])
|
||||
assert len(result) == 2
|
||||
assert result[1] == [t2, tail]
|
||||
|
||||
|
||||
def test_non_table_elements_before_first_table_kept_in_first_group() -> None:
|
||||
head = _md("head")
|
||||
t1, t2 = _table(), _table()
|
||||
result = split([head, t1, t2])
|
||||
# head + t1 in group 0; t2 in group 1
|
||||
assert result[0] == [head, t1]
|
||||
assert result[1] == [t2]
|
||||
@@ -0,0 +1,148 @@
|
||||
"""Tests for FeishuChannel tool hint code block formatting."""
|
||||
|
||||
import json
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from pytest import mark
|
||||
|
||||
# Check optional Feishu dependencies before running tests
|
||||
try:
|
||||
from nanobot.channels import feishu
|
||||
FEISHU_AVAILABLE = getattr(feishu, "FEISHU_AVAILABLE", False)
|
||||
except ImportError:
|
||||
FEISHU_AVAILABLE = False
|
||||
|
||||
if not FEISHU_AVAILABLE:
|
||||
pytest.skip("Feishu dependencies not installed (lark-oapi)", allow_module_level=True)
|
||||
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.channels.feishu import FeishuChannel
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_feishu_channel():
|
||||
"""Create a FeishuChannel with mocked client."""
|
||||
config = MagicMock()
|
||||
config.app_id = "test_app_id"
|
||||
config.app_secret = "test_app_secret"
|
||||
config.encrypt_key = None
|
||||
config.verification_token = None
|
||||
bus = MagicMock()
|
||||
channel = FeishuChannel(config, bus)
|
||||
channel._client = MagicMock() # Simulate initialized client
|
||||
return channel
|
||||
|
||||
|
||||
@mark.asyncio
|
||||
async def test_tool_hint_sends_code_message(mock_feishu_channel):
|
||||
"""Tool hint messages should be sent as interactive cards with code blocks."""
|
||||
msg = OutboundMessage(
|
||||
channel="feishu",
|
||||
chat_id="oc_123456",
|
||||
content='web_search("test query")',
|
||||
metadata={"_tool_hint": True}
|
||||
)
|
||||
|
||||
with patch.object(mock_feishu_channel, '_send_message_sync') as mock_send:
|
||||
await mock_feishu_channel.send(msg)
|
||||
|
||||
# Verify interactive message with card was sent
|
||||
assert mock_send.call_count == 1
|
||||
call_args = mock_send.call_args[0]
|
||||
receive_id_type, receive_id, msg_type, content = call_args
|
||||
|
||||
assert receive_id_type == "chat_id"
|
||||
assert receive_id == "oc_123456"
|
||||
assert msg_type == "interactive"
|
||||
|
||||
# Parse content to verify card structure
|
||||
card = json.loads(content)
|
||||
assert card["config"]["wide_screen_mode"] is True
|
||||
assert len(card["elements"]) == 1
|
||||
assert card["elements"][0]["tag"] == "markdown"
|
||||
# Check that code block is properly formatted with language hint
|
||||
expected_md = "**Tool Calls**\n\n```text\nweb_search(\"test query\")\n```"
|
||||
assert card["elements"][0]["content"] == expected_md
|
||||
|
||||
|
||||
@mark.asyncio
|
||||
async def test_tool_hint_empty_content_does_not_send(mock_feishu_channel):
|
||||
"""Empty tool hint messages should not be sent."""
|
||||
msg = OutboundMessage(
|
||||
channel="feishu",
|
||||
chat_id="oc_123456",
|
||||
content=" ", # whitespace only
|
||||
metadata={"_tool_hint": True}
|
||||
)
|
||||
|
||||
with patch.object(mock_feishu_channel, '_send_message_sync') as mock_send:
|
||||
await mock_feishu_channel.send(msg)
|
||||
|
||||
# Should not send any message
|
||||
mock_send.assert_not_called()
|
||||
|
||||
|
||||
@mark.asyncio
|
||||
async def test_tool_hint_without_metadata_sends_as_normal(mock_feishu_channel):
|
||||
"""Regular messages without _tool_hint should use normal formatting."""
|
||||
msg = OutboundMessage(
|
||||
channel="feishu",
|
||||
chat_id="oc_123456",
|
||||
content="Hello, world!",
|
||||
metadata={}
|
||||
)
|
||||
|
||||
with patch.object(mock_feishu_channel, '_send_message_sync') as mock_send:
|
||||
await mock_feishu_channel.send(msg)
|
||||
|
||||
# Should send as text message (detected format)
|
||||
assert mock_send.call_count == 1
|
||||
call_args = mock_send.call_args[0]
|
||||
_, _, msg_type, content = call_args
|
||||
assert msg_type == "text"
|
||||
assert json.loads(content) == {"text": "Hello, world!"}
|
||||
|
||||
|
||||
@mark.asyncio
|
||||
async def test_tool_hint_multiple_tools_in_one_message(mock_feishu_channel):
|
||||
"""Multiple tool calls should be displayed each on its own line in a code block."""
|
||||
msg = OutboundMessage(
|
||||
channel="feishu",
|
||||
chat_id="oc_123456",
|
||||
content='web_search("query"), read_file("/path/to/file")',
|
||||
metadata={"_tool_hint": True}
|
||||
)
|
||||
|
||||
with patch.object(mock_feishu_channel, '_send_message_sync') as mock_send:
|
||||
await mock_feishu_channel.send(msg)
|
||||
|
||||
call_args = mock_send.call_args[0]
|
||||
msg_type = call_args[2]
|
||||
content = json.loads(call_args[3])
|
||||
assert msg_type == "interactive"
|
||||
# Each tool call should be on its own line
|
||||
expected_md = "**Tool Calls**\n\n```text\nweb_search(\"query\"),\nread_file(\"/path/to/file\")\n```"
|
||||
assert content["elements"][0]["content"] == expected_md
|
||||
|
||||
|
||||
@mark.asyncio
|
||||
async def test_tool_hint_keeps_commas_inside_arguments(mock_feishu_channel):
|
||||
"""Commas inside a single tool argument must not be split onto a new line."""
|
||||
msg = OutboundMessage(
|
||||
channel="feishu",
|
||||
chat_id="oc_123456",
|
||||
content='web_search("foo, bar"), read_file("/path/to/file")',
|
||||
metadata={"_tool_hint": True}
|
||||
)
|
||||
|
||||
with patch.object(mock_feishu_channel, '_send_message_sync') as mock_send:
|
||||
await mock_feishu_channel.send(msg)
|
||||
|
||||
content = json.loads(mock_send.call_args[0][3])
|
||||
expected_md = (
|
||||
"**Tool Calls**\n\n```text\n"
|
||||
"web_search(\"foo, bar\"),\n"
|
||||
"read_file(\"/path/to/file\")\n```"
|
||||
)
|
||||
assert content["elements"][0]["content"] == expected_md
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,172 @@
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
# Check optional QQ dependencies before running tests
|
||||
try:
|
||||
from nanobot.channels import qq
|
||||
QQ_AVAILABLE = getattr(qq, "QQ_AVAILABLE", False)
|
||||
except ImportError:
|
||||
QQ_AVAILABLE = False
|
||||
|
||||
if not QQ_AVAILABLE:
|
||||
pytest.skip("QQ dependencies not installed (qq-botpy)", allow_module_level=True)
|
||||
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.channels.qq import QQChannel, QQConfig
|
||||
|
||||
|
||||
class _FakeApi:
|
||||
def __init__(self) -> None:
|
||||
self.c2c_calls: list[dict] = []
|
||||
self.group_calls: list[dict] = []
|
||||
|
||||
async def post_c2c_message(self, **kwargs) -> None:
|
||||
self.c2c_calls.append(kwargs)
|
||||
|
||||
async def post_group_message(self, **kwargs) -> None:
|
||||
self.group_calls.append(kwargs)
|
||||
|
||||
|
||||
class _FakeClient:
|
||||
def __init__(self) -> None:
|
||||
self.api = _FakeApi()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_group_message_routes_to_group_chat_id() -> None:
|
||||
channel = QQChannel(QQConfig(app_id="app", secret="secret", allow_from=["user1"]), MessageBus())
|
||||
|
||||
data = SimpleNamespace(
|
||||
id="msg1",
|
||||
content="hello",
|
||||
group_openid="group123",
|
||||
author=SimpleNamespace(member_openid="user1"),
|
||||
attachments=[],
|
||||
)
|
||||
|
||||
await channel._on_message(data, is_group=True)
|
||||
|
||||
msg = await channel.bus.consume_inbound()
|
||||
assert msg.sender_id == "user1"
|
||||
assert msg.chat_id == "group123"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_group_message_uses_plain_text_group_api_with_msg_seq() -> None:
|
||||
channel = QQChannel(QQConfig(app_id="app", secret="secret", allow_from=["*"]), MessageBus())
|
||||
channel._client = _FakeClient()
|
||||
channel._chat_type_cache["group123"] = "group"
|
||||
|
||||
await channel.send(
|
||||
OutboundMessage(
|
||||
channel="qq",
|
||||
chat_id="group123",
|
||||
content="hello",
|
||||
metadata={"message_id": "msg1"},
|
||||
)
|
||||
)
|
||||
|
||||
assert len(channel._client.api.group_calls) == 1
|
||||
call = channel._client.api.group_calls[0]
|
||||
assert call == {
|
||||
"group_openid": "group123",
|
||||
"msg_type": 0,
|
||||
"content": "hello",
|
||||
"msg_id": "msg1",
|
||||
"msg_seq": 2,
|
||||
}
|
||||
assert not channel._client.api.c2c_calls
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_c2c_message_uses_plain_text_c2c_api_with_msg_seq() -> None:
|
||||
channel = QQChannel(QQConfig(app_id="app", secret="secret", allow_from=["*"]), MessageBus())
|
||||
channel._client = _FakeClient()
|
||||
|
||||
await channel.send(
|
||||
OutboundMessage(
|
||||
channel="qq",
|
||||
chat_id="user123",
|
||||
content="hello",
|
||||
metadata={"message_id": "msg1"},
|
||||
)
|
||||
)
|
||||
|
||||
assert len(channel._client.api.c2c_calls) == 1
|
||||
call = channel._client.api.c2c_calls[0]
|
||||
assert call == {
|
||||
"openid": "user123",
|
||||
"msg_type": 0,
|
||||
"content": "hello",
|
||||
"msg_id": "msg1",
|
||||
"msg_seq": 2,
|
||||
}
|
||||
assert not channel._client.api.group_calls
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_group_message_uses_markdown_when_configured() -> None:
|
||||
channel = QQChannel(
|
||||
QQConfig(app_id="app", secret="secret", allow_from=["*"], msg_format="markdown"),
|
||||
MessageBus(),
|
||||
)
|
||||
channel._client = _FakeClient()
|
||||
channel._chat_type_cache["group123"] = "group"
|
||||
|
||||
await channel.send(
|
||||
OutboundMessage(
|
||||
channel="qq",
|
||||
chat_id="group123",
|
||||
content="**hello**",
|
||||
metadata={"message_id": "msg1"},
|
||||
)
|
||||
)
|
||||
|
||||
assert len(channel._client.api.group_calls) == 1
|
||||
call = channel._client.api.group_calls[0]
|
||||
assert call == {
|
||||
"group_openid": "group123",
|
||||
"msg_type": 2,
|
||||
"markdown": {"content": "**hello**"},
|
||||
"msg_id": "msg1",
|
||||
"msg_seq": 2,
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_read_media_bytes_local_path() -> None:
|
||||
channel = QQChannel(QQConfig(app_id="app", secret="secret"), MessageBus())
|
||||
|
||||
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f:
|
||||
f.write(b"\x89PNG\r\n")
|
||||
tmp_path = f.name
|
||||
|
||||
data, filename = await channel._read_media_bytes(tmp_path)
|
||||
assert data == b"\x89PNG\r\n"
|
||||
assert filename == Path(tmp_path).name
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_read_media_bytes_file_uri() -> None:
|
||||
channel = QQChannel(QQConfig(app_id="app", secret="secret"), MessageBus())
|
||||
|
||||
with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as f:
|
||||
f.write(b"JFIF")
|
||||
tmp_path = f.name
|
||||
|
||||
data, filename = await channel._read_media_bytes(f"file://{tmp_path}")
|
||||
assert data == b"JFIF"
|
||||
assert filename == Path(tmp_path).name
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_read_media_bytes_missing_file() -> None:
|
||||
channel = QQChannel(QQConfig(app_id="app", secret="secret"), MessageBus())
|
||||
|
||||
data, filename = await channel._read_media_bytes("/nonexistent/path/image.png")
|
||||
assert data is None
|
||||
assert filename is None
|
||||
@@ -0,0 +1,153 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
# Check optional Slack dependencies before running tests
|
||||
try:
|
||||
import slack_sdk # noqa: F401
|
||||
except ImportError:
|
||||
pytest.skip("Slack dependencies not installed (slack-sdk)", allow_module_level=True)
|
||||
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.channels.slack import SlackChannel
|
||||
from nanobot.channels.slack import SlackConfig
|
||||
|
||||
|
||||
class _FakeAsyncWebClient:
|
||||
def __init__(self) -> None:
|
||||
self.chat_post_calls: list[dict[str, object | None]] = []
|
||||
self.file_upload_calls: list[dict[str, object | None]] = []
|
||||
self.reactions_add_calls: list[dict[str, object | None]] = []
|
||||
self.reactions_remove_calls: list[dict[str, object | None]] = []
|
||||
|
||||
async def chat_postMessage(
|
||||
self,
|
||||
*,
|
||||
channel: str,
|
||||
text: str,
|
||||
thread_ts: str | None = None,
|
||||
) -> None:
|
||||
self.chat_post_calls.append(
|
||||
{
|
||||
"channel": channel,
|
||||
"text": text,
|
||||
"thread_ts": thread_ts,
|
||||
}
|
||||
)
|
||||
|
||||
async def files_upload_v2(
|
||||
self,
|
||||
*,
|
||||
channel: str,
|
||||
file: str,
|
||||
thread_ts: str | None = None,
|
||||
) -> None:
|
||||
self.file_upload_calls.append(
|
||||
{
|
||||
"channel": channel,
|
||||
"file": file,
|
||||
"thread_ts": thread_ts,
|
||||
}
|
||||
)
|
||||
|
||||
async def reactions_add(
|
||||
self,
|
||||
*,
|
||||
channel: str,
|
||||
name: str,
|
||||
timestamp: str,
|
||||
) -> None:
|
||||
self.reactions_add_calls.append(
|
||||
{
|
||||
"channel": channel,
|
||||
"name": name,
|
||||
"timestamp": timestamp,
|
||||
}
|
||||
)
|
||||
|
||||
async def reactions_remove(
|
||||
self,
|
||||
*,
|
||||
channel: str,
|
||||
name: str,
|
||||
timestamp: str,
|
||||
) -> None:
|
||||
self.reactions_remove_calls.append(
|
||||
{
|
||||
"channel": channel,
|
||||
"name": name,
|
||||
"timestamp": timestamp,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_uses_thread_for_channel_messages() -> None:
|
||||
channel = SlackChannel(SlackConfig(enabled=True), MessageBus())
|
||||
fake_web = _FakeAsyncWebClient()
|
||||
channel._web_client = fake_web
|
||||
|
||||
await channel.send(
|
||||
OutboundMessage(
|
||||
channel="slack",
|
||||
chat_id="C123",
|
||||
content="hello",
|
||||
media=["/tmp/demo.txt"],
|
||||
metadata={"slack": {"thread_ts": "1700000000.000100", "channel_type": "channel"}},
|
||||
)
|
||||
)
|
||||
|
||||
assert len(fake_web.chat_post_calls) == 1
|
||||
assert fake_web.chat_post_calls[0]["text"] == "hello\n"
|
||||
assert fake_web.chat_post_calls[0]["thread_ts"] == "1700000000.000100"
|
||||
assert len(fake_web.file_upload_calls) == 1
|
||||
assert fake_web.file_upload_calls[0]["thread_ts"] == "1700000000.000100"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_omits_thread_for_dm_messages() -> None:
|
||||
channel = SlackChannel(SlackConfig(enabled=True), MessageBus())
|
||||
fake_web = _FakeAsyncWebClient()
|
||||
channel._web_client = fake_web
|
||||
|
||||
await channel.send(
|
||||
OutboundMessage(
|
||||
channel="slack",
|
||||
chat_id="D123",
|
||||
content="hello",
|
||||
media=["/tmp/demo.txt"],
|
||||
metadata={"slack": {"thread_ts": "1700000000.000100", "channel_type": "im"}},
|
||||
)
|
||||
)
|
||||
|
||||
assert len(fake_web.chat_post_calls) == 1
|
||||
assert fake_web.chat_post_calls[0]["text"] == "hello\n"
|
||||
assert fake_web.chat_post_calls[0]["thread_ts"] is None
|
||||
assert len(fake_web.file_upload_calls) == 1
|
||||
assert fake_web.file_upload_calls[0]["thread_ts"] is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_updates_reaction_when_final_response_sent() -> None:
|
||||
channel = SlackChannel(SlackConfig(enabled=True, react_emoji="eyes"), MessageBus())
|
||||
fake_web = _FakeAsyncWebClient()
|
||||
channel._web_client = fake_web
|
||||
|
||||
await channel.send(
|
||||
OutboundMessage(
|
||||
channel="slack",
|
||||
chat_id="C123",
|
||||
content="done",
|
||||
metadata={
|
||||
"slack": {"event": {"ts": "1700000000.000100"}, "channel_type": "channel"},
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
assert fake_web.reactions_remove_calls == [
|
||||
{"channel": "C123", "name": "eyes", "timestamp": "1700000000.000100"}
|
||||
]
|
||||
assert fake_web.reactions_add_calls == [
|
||||
{"channel": "C123", "name": "white_check_mark", "timestamp": "1700000000.000100"}
|
||||
]
|
||||
@@ -0,0 +1,966 @@
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
# Check optional Telegram dependencies before running tests
|
||||
try:
|
||||
import telegram # noqa: F401
|
||||
except ImportError:
|
||||
pytest.skip("Telegram dependencies not installed (python-telegram-bot)", allow_module_level=True)
|
||||
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.channels.telegram import TELEGRAM_REPLY_CONTEXT_MAX_LEN, TelegramChannel, _StreamBuf
|
||||
from nanobot.channels.telegram import TelegramConfig
|
||||
|
||||
|
||||
class _FakeHTTPXRequest:
|
||||
instances: list["_FakeHTTPXRequest"] = []
|
||||
|
||||
def __init__(self, **kwargs) -> None:
|
||||
self.kwargs = kwargs
|
||||
self.__class__.instances.append(self)
|
||||
|
||||
@classmethod
|
||||
def clear(cls) -> None:
|
||||
cls.instances.clear()
|
||||
|
||||
|
||||
class _FakeUpdater:
|
||||
def __init__(self, on_start_polling) -> None:
|
||||
self._on_start_polling = on_start_polling
|
||||
|
||||
async def start_polling(self, **kwargs) -> None:
|
||||
self._on_start_polling()
|
||||
|
||||
|
||||
class _FakeBot:
|
||||
def __init__(self) -> None:
|
||||
self.sent_messages: list[dict] = []
|
||||
self.sent_media: list[dict] = []
|
||||
self.get_me_calls = 0
|
||||
|
||||
async def get_me(self):
|
||||
self.get_me_calls += 1
|
||||
return SimpleNamespace(id=999, username="nanobot_test")
|
||||
|
||||
async def set_my_commands(self, commands) -> None:
|
||||
self.commands = commands
|
||||
|
||||
async def send_message(self, **kwargs):
|
||||
self.sent_messages.append(kwargs)
|
||||
return SimpleNamespace(message_id=len(self.sent_messages))
|
||||
|
||||
async def send_photo(self, **kwargs) -> None:
|
||||
self.sent_media.append({"kind": "photo", **kwargs})
|
||||
|
||||
async def send_voice(self, **kwargs) -> None:
|
||||
self.sent_media.append({"kind": "voice", **kwargs})
|
||||
|
||||
async def send_audio(self, **kwargs) -> None:
|
||||
self.sent_media.append({"kind": "audio", **kwargs})
|
||||
|
||||
async def send_document(self, **kwargs) -> None:
|
||||
self.sent_media.append({"kind": "document", **kwargs})
|
||||
|
||||
async def send_chat_action(self, **kwargs) -> None:
|
||||
pass
|
||||
|
||||
async def get_file(self, file_id: str):
|
||||
"""Return a fake file that 'downloads' to a path (for reply-to-media tests)."""
|
||||
async def _fake_download(path) -> None:
|
||||
pass
|
||||
return SimpleNamespace(download_to_drive=_fake_download)
|
||||
|
||||
|
||||
class _FakeApp:
|
||||
def __init__(self, on_start_polling) -> None:
|
||||
self.bot = _FakeBot()
|
||||
self.updater = _FakeUpdater(on_start_polling)
|
||||
self.handlers = []
|
||||
self.error_handlers = []
|
||||
|
||||
def add_error_handler(self, handler) -> None:
|
||||
self.error_handlers.append(handler)
|
||||
|
||||
def add_handler(self, handler) -> None:
|
||||
self.handlers.append(handler)
|
||||
|
||||
async def initialize(self) -> None:
|
||||
pass
|
||||
|
||||
async def start(self) -> None:
|
||||
pass
|
||||
|
||||
|
||||
class _FakeBuilder:
|
||||
def __init__(self, app: _FakeApp) -> None:
|
||||
self.app = app
|
||||
self.token_value = None
|
||||
self.request_value = None
|
||||
self.get_updates_request_value = None
|
||||
|
||||
def token(self, token: str):
|
||||
self.token_value = token
|
||||
return self
|
||||
|
||||
def request(self, request):
|
||||
self.request_value = request
|
||||
return self
|
||||
|
||||
def get_updates_request(self, request):
|
||||
self.get_updates_request_value = request
|
||||
return self
|
||||
|
||||
def proxy(self, _proxy):
|
||||
raise AssertionError("builder.proxy should not be called when request is set")
|
||||
|
||||
def get_updates_proxy(self, _proxy):
|
||||
raise AssertionError("builder.get_updates_proxy should not be called when request is set")
|
||||
|
||||
def build(self):
|
||||
return self.app
|
||||
|
||||
|
||||
def _make_telegram_update(
|
||||
*,
|
||||
chat_type: str = "group",
|
||||
text: str | None = None,
|
||||
caption: str | None = None,
|
||||
entities=None,
|
||||
caption_entities=None,
|
||||
reply_to_message=None,
|
||||
):
|
||||
user = SimpleNamespace(id=12345, username="alice", first_name="Alice")
|
||||
message = SimpleNamespace(
|
||||
chat=SimpleNamespace(type=chat_type, is_forum=False),
|
||||
chat_id=-100123,
|
||||
text=text,
|
||||
caption=caption,
|
||||
entities=entities or [],
|
||||
caption_entities=caption_entities or [],
|
||||
reply_to_message=reply_to_message,
|
||||
photo=None,
|
||||
voice=None,
|
||||
audio=None,
|
||||
document=None,
|
||||
media_group_id=None,
|
||||
message_thread_id=None,
|
||||
message_id=1,
|
||||
)
|
||||
return SimpleNamespace(message=message, effective_user=user)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_creates_separate_pools_with_proxy(monkeypatch) -> None:
|
||||
_FakeHTTPXRequest.clear()
|
||||
config = TelegramConfig(
|
||||
enabled=True,
|
||||
token="123:abc",
|
||||
allow_from=["*"],
|
||||
proxy="http://127.0.0.1:7890",
|
||||
)
|
||||
bus = MessageBus()
|
||||
channel = TelegramChannel(config, bus)
|
||||
app = _FakeApp(lambda: setattr(channel, "_running", False))
|
||||
builder = _FakeBuilder(app)
|
||||
|
||||
monkeypatch.setattr("nanobot.channels.telegram.HTTPXRequest", _FakeHTTPXRequest)
|
||||
monkeypatch.setattr(
|
||||
"nanobot.channels.telegram.Application",
|
||||
SimpleNamespace(builder=lambda: builder),
|
||||
)
|
||||
|
||||
await channel.start()
|
||||
|
||||
assert len(_FakeHTTPXRequest.instances) == 2
|
||||
api_req, poll_req = _FakeHTTPXRequest.instances
|
||||
assert api_req.kwargs["proxy"] == config.proxy
|
||||
assert poll_req.kwargs["proxy"] == config.proxy
|
||||
assert api_req.kwargs["connection_pool_size"] == 32
|
||||
assert poll_req.kwargs["connection_pool_size"] == 4
|
||||
assert builder.request_value is api_req
|
||||
assert builder.get_updates_request_value is poll_req
|
||||
assert any(cmd.command == "status" for cmd in app.bot.commands)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_respects_custom_pool_config(monkeypatch) -> None:
|
||||
_FakeHTTPXRequest.clear()
|
||||
config = TelegramConfig(
|
||||
enabled=True,
|
||||
token="123:abc",
|
||||
allow_from=["*"],
|
||||
connection_pool_size=32,
|
||||
pool_timeout=10.0,
|
||||
)
|
||||
bus = MessageBus()
|
||||
channel = TelegramChannel(config, bus)
|
||||
app = _FakeApp(lambda: setattr(channel, "_running", False))
|
||||
builder = _FakeBuilder(app)
|
||||
|
||||
monkeypatch.setattr("nanobot.channels.telegram.HTTPXRequest", _FakeHTTPXRequest)
|
||||
monkeypatch.setattr(
|
||||
"nanobot.channels.telegram.Application",
|
||||
SimpleNamespace(builder=lambda: builder),
|
||||
)
|
||||
|
||||
await channel.start()
|
||||
|
||||
api_req = _FakeHTTPXRequest.instances[0]
|
||||
poll_req = _FakeHTTPXRequest.instances[1]
|
||||
assert api_req.kwargs["connection_pool_size"] == 32
|
||||
assert api_req.kwargs["pool_timeout"] == 10.0
|
||||
assert poll_req.kwargs["pool_timeout"] == 10.0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_text_retries_on_timeout() -> None:
|
||||
"""_send_text retries on TimedOut before succeeding."""
|
||||
from telegram.error import TimedOut
|
||||
|
||||
channel = TelegramChannel(
|
||||
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"]),
|
||||
MessageBus(),
|
||||
)
|
||||
channel._app = _FakeApp(lambda: None)
|
||||
|
||||
call_count = 0
|
||||
original_send = channel._app.bot.send_message
|
||||
|
||||
async def flaky_send(**kwargs):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count <= 2:
|
||||
raise TimedOut()
|
||||
return await original_send(**kwargs)
|
||||
|
||||
channel._app.bot.send_message = flaky_send
|
||||
|
||||
import nanobot.channels.telegram as tg_mod
|
||||
orig_delay = tg_mod._SEND_RETRY_BASE_DELAY
|
||||
tg_mod._SEND_RETRY_BASE_DELAY = 0.01
|
||||
try:
|
||||
await channel._send_text(123, "hello", None, {})
|
||||
finally:
|
||||
tg_mod._SEND_RETRY_BASE_DELAY = orig_delay
|
||||
|
||||
assert call_count == 3
|
||||
assert len(channel._app.bot.sent_messages) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_text_gives_up_after_max_retries() -> None:
|
||||
"""_send_text raises TimedOut after exhausting all retries."""
|
||||
from telegram.error import TimedOut
|
||||
|
||||
channel = TelegramChannel(
|
||||
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"]),
|
||||
MessageBus(),
|
||||
)
|
||||
channel._app = _FakeApp(lambda: None)
|
||||
|
||||
async def always_timeout(**kwargs):
|
||||
raise TimedOut()
|
||||
|
||||
channel._app.bot.send_message = always_timeout
|
||||
|
||||
import nanobot.channels.telegram as tg_mod
|
||||
orig_delay = tg_mod._SEND_RETRY_BASE_DELAY
|
||||
tg_mod._SEND_RETRY_BASE_DELAY = 0.01
|
||||
try:
|
||||
with pytest.raises(TimedOut):
|
||||
await channel._send_text(123, "hello", None, {})
|
||||
finally:
|
||||
tg_mod._SEND_RETRY_BASE_DELAY = orig_delay
|
||||
|
||||
assert channel._app.bot.sent_messages == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_error_logs_network_issues_as_warning(monkeypatch) -> None:
|
||||
from telegram.error import NetworkError
|
||||
|
||||
channel = TelegramChannel(
|
||||
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"]),
|
||||
MessageBus(),
|
||||
)
|
||||
recorded: list[tuple[str, str]] = []
|
||||
|
||||
monkeypatch.setattr(
|
||||
"nanobot.channels.telegram.logger.warning",
|
||||
lambda message, error: recorded.append(("warning", message.format(error))),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"nanobot.channels.telegram.logger.error",
|
||||
lambda message, error: recorded.append(("error", message.format(error))),
|
||||
)
|
||||
|
||||
await channel._on_error(object(), SimpleNamespace(error=NetworkError("proxy disconnected")))
|
||||
|
||||
assert recorded == [("warning", "Telegram network issue: proxy disconnected")]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_error_keeps_non_network_exceptions_as_error(monkeypatch) -> None:
|
||||
channel = TelegramChannel(
|
||||
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"]),
|
||||
MessageBus(),
|
||||
)
|
||||
recorded: list[tuple[str, str]] = []
|
||||
|
||||
monkeypatch.setattr(
|
||||
"nanobot.channels.telegram.logger.warning",
|
||||
lambda message, error: recorded.append(("warning", message.format(error))),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"nanobot.channels.telegram.logger.error",
|
||||
lambda message, error: recorded.append(("error", message.format(error))),
|
||||
)
|
||||
|
||||
await channel._on_error(object(), SimpleNamespace(error=RuntimeError("boom")))
|
||||
|
||||
assert recorded == [("error", "Telegram error: boom")]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_delta_stream_end_raises_and_keeps_buffer_on_failure() -> None:
|
||||
channel = TelegramChannel(
|
||||
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"]),
|
||||
MessageBus(),
|
||||
)
|
||||
channel._app = _FakeApp(lambda: None)
|
||||
channel._app.bot.edit_message_text = AsyncMock(side_effect=RuntimeError("boom"))
|
||||
channel._stream_bufs["123"] = _StreamBuf(text="hello", message_id=7, last_edit=0.0)
|
||||
|
||||
with pytest.raises(RuntimeError, match="boom"):
|
||||
await channel.send_delta("123", "", {"_stream_end": True})
|
||||
|
||||
assert "123" in channel._stream_bufs
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_delta_stream_end_treats_not_modified_as_success() -> None:
|
||||
from telegram.error import BadRequest
|
||||
|
||||
channel = TelegramChannel(
|
||||
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"]),
|
||||
MessageBus(),
|
||||
)
|
||||
channel._app = _FakeApp(lambda: None)
|
||||
channel._app.bot.edit_message_text = AsyncMock(side_effect=BadRequest("Message is not modified"))
|
||||
channel._stream_bufs["123"] = _StreamBuf(text="hello", message_id=7, last_edit=0.0, stream_id="s:0")
|
||||
|
||||
await channel.send_delta("123", "", {"_stream_end": True, "_stream_id": "s:0"})
|
||||
|
||||
assert "123" not in channel._stream_bufs
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_delta_new_stream_id_replaces_stale_buffer() -> None:
|
||||
channel = TelegramChannel(
|
||||
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"]),
|
||||
MessageBus(),
|
||||
)
|
||||
channel._app = _FakeApp(lambda: None)
|
||||
channel._stream_bufs["123"] = _StreamBuf(
|
||||
text="hello",
|
||||
message_id=7,
|
||||
last_edit=0.0,
|
||||
stream_id="old:0",
|
||||
)
|
||||
|
||||
await channel.send_delta("123", "world", {"_stream_delta": True, "_stream_id": "new:0"})
|
||||
|
||||
buf = channel._stream_bufs["123"]
|
||||
assert buf.text == "world"
|
||||
assert buf.stream_id == "new:0"
|
||||
assert buf.message_id == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_delta_incremental_edit_treats_not_modified_as_success() -> None:
|
||||
from telegram.error import BadRequest
|
||||
|
||||
channel = TelegramChannel(
|
||||
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"]),
|
||||
MessageBus(),
|
||||
)
|
||||
channel._app = _FakeApp(lambda: None)
|
||||
channel._stream_bufs["123"] = _StreamBuf(text="hello", message_id=7, last_edit=0.0, stream_id="s:0")
|
||||
channel._app.bot.edit_message_text = AsyncMock(side_effect=BadRequest("Message is not modified"))
|
||||
|
||||
await channel.send_delta("123", "", {"_stream_delta": True, "_stream_id": "s:0"})
|
||||
|
||||
assert channel._stream_bufs["123"].last_edit > 0.0
|
||||
|
||||
|
||||
def test_derive_topic_session_key_uses_thread_id() -> None:
|
||||
message = SimpleNamespace(
|
||||
chat=SimpleNamespace(type="supergroup"),
|
||||
chat_id=-100123,
|
||||
message_thread_id=42,
|
||||
)
|
||||
|
||||
assert TelegramChannel._derive_topic_session_key(message) == "telegram:-100123:topic:42"
|
||||
|
||||
|
||||
def test_get_extension_falls_back_to_original_filename() -> None:
|
||||
channel = TelegramChannel(TelegramConfig(), MessageBus())
|
||||
|
||||
assert channel._get_extension("file", None, "report.pdf") == ".pdf"
|
||||
assert channel._get_extension("file", None, "archive.tar.gz") == ".tar.gz"
|
||||
|
||||
|
||||
def test_telegram_group_policy_defaults_to_mention() -> None:
|
||||
assert TelegramConfig().group_policy == "mention"
|
||||
|
||||
|
||||
def test_is_allowed_accepts_legacy_telegram_id_username_formats() -> None:
|
||||
channel = TelegramChannel(TelegramConfig(allow_from=["12345", "alice", "67890|bob"]), MessageBus())
|
||||
|
||||
assert channel.is_allowed("12345|carol") is True
|
||||
assert channel.is_allowed("99999|alice") is True
|
||||
assert channel.is_allowed("67890|bob") is True
|
||||
|
||||
|
||||
def test_is_allowed_rejects_invalid_legacy_telegram_sender_shapes() -> None:
|
||||
channel = TelegramChannel(TelegramConfig(allow_from=["alice"]), MessageBus())
|
||||
|
||||
assert channel.is_allowed("attacker|alice|extra") is False
|
||||
assert channel.is_allowed("not-a-number|alice") is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_progress_keeps_message_in_topic() -> None:
|
||||
config = TelegramConfig(enabled=True, token="123:abc", allow_from=["*"])
|
||||
channel = TelegramChannel(config, MessageBus())
|
||||
channel._app = _FakeApp(lambda: None)
|
||||
|
||||
await channel.send(
|
||||
OutboundMessage(
|
||||
channel="telegram",
|
||||
chat_id="123",
|
||||
content="hello",
|
||||
metadata={"_progress": True, "message_thread_id": 42},
|
||||
)
|
||||
)
|
||||
|
||||
assert channel._app.bot.sent_messages[0]["message_thread_id"] == 42
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_reply_infers_topic_from_message_id_cache() -> None:
|
||||
config = TelegramConfig(enabled=True, token="123:abc", allow_from=["*"], reply_to_message=True)
|
||||
channel = TelegramChannel(config, MessageBus())
|
||||
channel._app = _FakeApp(lambda: None)
|
||||
channel._message_threads[("123", 10)] = 42
|
||||
|
||||
await channel.send(
|
||||
OutboundMessage(
|
||||
channel="telegram",
|
||||
chat_id="123",
|
||||
content="hello",
|
||||
metadata={"message_id": 10},
|
||||
)
|
||||
)
|
||||
|
||||
assert channel._app.bot.sent_messages[0]["message_thread_id"] == 42
|
||||
assert channel._app.bot.sent_messages[0]["reply_parameters"].message_id == 10
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_remote_media_url_after_security_validation(monkeypatch) -> None:
|
||||
channel = TelegramChannel(
|
||||
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"]),
|
||||
MessageBus(),
|
||||
)
|
||||
channel._app = _FakeApp(lambda: None)
|
||||
monkeypatch.setattr("nanobot.channels.telegram.validate_url_target", lambda url: (True, ""))
|
||||
|
||||
await channel.send(
|
||||
OutboundMessage(
|
||||
channel="telegram",
|
||||
chat_id="123",
|
||||
content="",
|
||||
media=["https://example.com/cat.jpg"],
|
||||
)
|
||||
)
|
||||
|
||||
assert channel._app.bot.sent_media == [
|
||||
{
|
||||
"kind": "photo",
|
||||
"chat_id": 123,
|
||||
"photo": "https://example.com/cat.jpg",
|
||||
"reply_parameters": None,
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_blocks_unsafe_remote_media_url(monkeypatch) -> None:
|
||||
channel = TelegramChannel(
|
||||
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"]),
|
||||
MessageBus(),
|
||||
)
|
||||
channel._app = _FakeApp(lambda: None)
|
||||
monkeypatch.setattr(
|
||||
"nanobot.channels.telegram.validate_url_target",
|
||||
lambda url: (False, "Blocked: example.com resolves to private/internal address 127.0.0.1"),
|
||||
)
|
||||
|
||||
await channel.send(
|
||||
OutboundMessage(
|
||||
channel="telegram",
|
||||
chat_id="123",
|
||||
content="",
|
||||
media=["http://example.com/internal.jpg"],
|
||||
)
|
||||
)
|
||||
|
||||
assert channel._app.bot.sent_media == []
|
||||
assert channel._app.bot.sent_messages == [
|
||||
{
|
||||
"chat_id": 123,
|
||||
"text": "[Failed to send: internal.jpg]",
|
||||
"reply_parameters": None,
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_group_policy_mention_ignores_unmentioned_group_message() -> None:
|
||||
channel = TelegramChannel(
|
||||
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"], group_policy="mention"),
|
||||
MessageBus(),
|
||||
)
|
||||
channel._app = _FakeApp(lambda: None)
|
||||
|
||||
handled = []
|
||||
|
||||
async def capture_handle(**kwargs) -> None:
|
||||
handled.append(kwargs)
|
||||
|
||||
channel._handle_message = capture_handle
|
||||
channel._start_typing = lambda _chat_id: None
|
||||
|
||||
await channel._on_message(_make_telegram_update(text="hello everyone"), None)
|
||||
|
||||
assert handled == []
|
||||
assert channel._app.bot.get_me_calls == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_group_policy_mention_accepts_text_mention_and_caches_bot_identity() -> None:
|
||||
channel = TelegramChannel(
|
||||
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"], group_policy="mention"),
|
||||
MessageBus(),
|
||||
)
|
||||
channel._app = _FakeApp(lambda: None)
|
||||
|
||||
handled = []
|
||||
|
||||
async def capture_handle(**kwargs) -> None:
|
||||
handled.append(kwargs)
|
||||
|
||||
channel._handle_message = capture_handle
|
||||
channel._start_typing = lambda _chat_id: None
|
||||
|
||||
mention = SimpleNamespace(type="mention", offset=0, length=13)
|
||||
await channel._on_message(_make_telegram_update(text="@nanobot_test hi", entities=[mention]), None)
|
||||
await channel._on_message(_make_telegram_update(text="@nanobot_test again", entities=[mention]), None)
|
||||
|
||||
assert len(handled) == 2
|
||||
assert channel._app.bot.get_me_calls == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_group_policy_mention_accepts_caption_mention() -> None:
|
||||
channel = TelegramChannel(
|
||||
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"], group_policy="mention"),
|
||||
MessageBus(),
|
||||
)
|
||||
channel._app = _FakeApp(lambda: None)
|
||||
|
||||
handled = []
|
||||
|
||||
async def capture_handle(**kwargs) -> None:
|
||||
handled.append(kwargs)
|
||||
|
||||
channel._handle_message = capture_handle
|
||||
channel._start_typing = lambda _chat_id: None
|
||||
|
||||
mention = SimpleNamespace(type="mention", offset=0, length=13)
|
||||
await channel._on_message(
|
||||
_make_telegram_update(caption="@nanobot_test photo", caption_entities=[mention]),
|
||||
None,
|
||||
)
|
||||
|
||||
assert len(handled) == 1
|
||||
assert handled[0]["content"] == "@nanobot_test photo"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_group_policy_mention_accepts_reply_to_bot() -> None:
|
||||
channel = TelegramChannel(
|
||||
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"], group_policy="mention"),
|
||||
MessageBus(),
|
||||
)
|
||||
channel._app = _FakeApp(lambda: None)
|
||||
|
||||
handled = []
|
||||
|
||||
async def capture_handle(**kwargs) -> None:
|
||||
handled.append(kwargs)
|
||||
|
||||
channel._handle_message = capture_handle
|
||||
channel._start_typing = lambda _chat_id: None
|
||||
|
||||
reply = SimpleNamespace(from_user=SimpleNamespace(id=999))
|
||||
await channel._on_message(_make_telegram_update(text="reply", reply_to_message=reply), None)
|
||||
|
||||
assert len(handled) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_group_policy_open_accepts_plain_group_message() -> None:
|
||||
channel = TelegramChannel(
|
||||
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"], group_policy="open"),
|
||||
MessageBus(),
|
||||
)
|
||||
channel._app = _FakeApp(lambda: None)
|
||||
|
||||
handled = []
|
||||
|
||||
async def capture_handle(**kwargs) -> None:
|
||||
handled.append(kwargs)
|
||||
|
||||
channel._handle_message = capture_handle
|
||||
channel._start_typing = lambda _chat_id: None
|
||||
|
||||
await channel._on_message(_make_telegram_update(text="hello group"), None)
|
||||
|
||||
assert len(handled) == 1
|
||||
assert channel._app.bot.get_me_calls == 0
|
||||
|
||||
|
||||
def test_extract_reply_context_no_reply() -> None:
|
||||
"""When there is no reply_to_message, _extract_reply_context returns None."""
|
||||
message = SimpleNamespace(reply_to_message=None)
|
||||
assert TelegramChannel._extract_reply_context(message) is None
|
||||
|
||||
|
||||
def test_extract_reply_context_with_text() -> None:
|
||||
"""When reply has text, return prefixed string."""
|
||||
reply = SimpleNamespace(text="Hello world", caption=None)
|
||||
message = SimpleNamespace(reply_to_message=reply)
|
||||
assert TelegramChannel._extract_reply_context(message) == "[Reply to: Hello world]"
|
||||
|
||||
|
||||
def test_extract_reply_context_with_caption_only() -> None:
|
||||
"""When reply has only caption (no text), caption is used."""
|
||||
reply = SimpleNamespace(text=None, caption="Photo caption")
|
||||
message = SimpleNamespace(reply_to_message=reply)
|
||||
assert TelegramChannel._extract_reply_context(message) == "[Reply to: Photo caption]"
|
||||
|
||||
|
||||
def test_extract_reply_context_truncation() -> None:
|
||||
"""Reply text is truncated at TELEGRAM_REPLY_CONTEXT_MAX_LEN."""
|
||||
long_text = "x" * (TELEGRAM_REPLY_CONTEXT_MAX_LEN + 100)
|
||||
reply = SimpleNamespace(text=long_text, caption=None)
|
||||
message = SimpleNamespace(reply_to_message=reply)
|
||||
result = TelegramChannel._extract_reply_context(message)
|
||||
assert result is not None
|
||||
assert result.startswith("[Reply to: ")
|
||||
assert result.endswith("...]")
|
||||
assert len(result) == len("[Reply to: ]") + TELEGRAM_REPLY_CONTEXT_MAX_LEN + len("...")
|
||||
|
||||
|
||||
def test_extract_reply_context_no_text_returns_none() -> None:
|
||||
"""When reply has no text/caption, _extract_reply_context returns None (media handled separately)."""
|
||||
reply = SimpleNamespace(text=None, caption=None)
|
||||
message = SimpleNamespace(reply_to_message=reply)
|
||||
assert TelegramChannel._extract_reply_context(message) is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_message_includes_reply_context() -> None:
|
||||
"""When user replies to a message, content passed to bus starts with reply context."""
|
||||
channel = TelegramChannel(
|
||||
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"], group_policy="open"),
|
||||
MessageBus(),
|
||||
)
|
||||
channel._app = _FakeApp(lambda: None)
|
||||
handled = []
|
||||
async def capture_handle(**kwargs) -> None:
|
||||
handled.append(kwargs)
|
||||
channel._handle_message = capture_handle
|
||||
channel._start_typing = lambda _chat_id: None
|
||||
|
||||
reply = SimpleNamespace(text="Hello", message_id=2, from_user=SimpleNamespace(id=1))
|
||||
update = _make_telegram_update(text="translate this", reply_to_message=reply)
|
||||
await channel._on_message(update, None)
|
||||
|
||||
assert len(handled) == 1
|
||||
assert handled[0]["content"].startswith("[Reply to: Hello]")
|
||||
assert "translate this" in handled[0]["content"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_download_message_media_returns_path_when_download_succeeds(
|
||||
monkeypatch, tmp_path
|
||||
) -> None:
|
||||
"""_download_message_media returns (paths, content_parts) when bot.get_file and download succeed."""
|
||||
media_dir = tmp_path / "media" / "telegram"
|
||||
media_dir.mkdir(parents=True)
|
||||
monkeypatch.setattr(
|
||||
"nanobot.channels.telegram.get_media_dir",
|
||||
lambda channel=None: media_dir if channel else tmp_path / "media",
|
||||
)
|
||||
|
||||
channel = TelegramChannel(
|
||||
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"]),
|
||||
MessageBus(),
|
||||
)
|
||||
channel._app = _FakeApp(lambda: None)
|
||||
channel._app.bot.get_file = AsyncMock(
|
||||
return_value=SimpleNamespace(download_to_drive=AsyncMock(return_value=None))
|
||||
)
|
||||
|
||||
msg = SimpleNamespace(
|
||||
photo=[SimpleNamespace(file_id="fid123", mime_type="image/jpeg")],
|
||||
voice=None,
|
||||
audio=None,
|
||||
document=None,
|
||||
video=None,
|
||||
video_note=None,
|
||||
animation=None,
|
||||
)
|
||||
paths, parts = await channel._download_message_media(msg)
|
||||
assert len(paths) == 1
|
||||
assert len(parts) == 1
|
||||
assert "fid123" in paths[0]
|
||||
assert "[image:" in parts[0]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_download_message_media_uses_file_unique_id_when_available(
|
||||
monkeypatch, tmp_path
|
||||
) -> None:
|
||||
media_dir = tmp_path / "media" / "telegram"
|
||||
media_dir.mkdir(parents=True)
|
||||
monkeypatch.setattr(
|
||||
"nanobot.channels.telegram.get_media_dir",
|
||||
lambda channel=None: media_dir if channel else tmp_path / "media",
|
||||
)
|
||||
|
||||
downloaded: dict[str, str] = {}
|
||||
|
||||
async def _download_to_drive(path: str) -> None:
|
||||
downloaded["path"] = path
|
||||
|
||||
channel = TelegramChannel(
|
||||
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"]),
|
||||
MessageBus(),
|
||||
)
|
||||
app = _FakeApp(lambda: None)
|
||||
app.bot.get_file = AsyncMock(
|
||||
return_value=SimpleNamespace(download_to_drive=_download_to_drive)
|
||||
)
|
||||
channel._app = app
|
||||
|
||||
msg = SimpleNamespace(
|
||||
photo=[
|
||||
SimpleNamespace(
|
||||
file_id="file-id-that-should-not-be-used",
|
||||
file_unique_id="stable-unique-id",
|
||||
mime_type="image/jpeg",
|
||||
file_name=None,
|
||||
)
|
||||
],
|
||||
voice=None,
|
||||
audio=None,
|
||||
document=None,
|
||||
video=None,
|
||||
video_note=None,
|
||||
animation=None,
|
||||
)
|
||||
|
||||
paths, parts = await channel._download_message_media(msg)
|
||||
|
||||
assert downloaded["path"].endswith("stable-unique-id.jpg")
|
||||
assert paths == [str(media_dir / "stable-unique-id.jpg")]
|
||||
assert parts == [f"[image: {media_dir / 'stable-unique-id.jpg'}]"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_message_attaches_reply_to_media_when_available(monkeypatch, tmp_path) -> None:
|
||||
"""When user replies to a message with media, that media is downloaded and attached to the turn."""
|
||||
media_dir = tmp_path / "media" / "telegram"
|
||||
media_dir.mkdir(parents=True)
|
||||
monkeypatch.setattr(
|
||||
"nanobot.channels.telegram.get_media_dir",
|
||||
lambda channel=None: media_dir if channel else tmp_path / "media",
|
||||
)
|
||||
|
||||
channel = TelegramChannel(
|
||||
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"], group_policy="open"),
|
||||
MessageBus(),
|
||||
)
|
||||
app = _FakeApp(lambda: None)
|
||||
app.bot.get_file = AsyncMock(
|
||||
return_value=SimpleNamespace(download_to_drive=AsyncMock(return_value=None))
|
||||
)
|
||||
channel._app = app
|
||||
handled = []
|
||||
async def capture_handle(**kwargs) -> None:
|
||||
handled.append(kwargs)
|
||||
channel._handle_message = capture_handle
|
||||
channel._start_typing = lambda _chat_id: None
|
||||
|
||||
reply_with_photo = SimpleNamespace(
|
||||
text=None,
|
||||
caption=None,
|
||||
photo=[SimpleNamespace(file_id="reply_photo_fid", mime_type="image/jpeg")],
|
||||
document=None,
|
||||
voice=None,
|
||||
audio=None,
|
||||
video=None,
|
||||
video_note=None,
|
||||
animation=None,
|
||||
)
|
||||
update = _make_telegram_update(
|
||||
text="what is the image?",
|
||||
reply_to_message=reply_with_photo,
|
||||
)
|
||||
await channel._on_message(update, None)
|
||||
|
||||
assert len(handled) == 1
|
||||
assert handled[0]["content"].startswith("[Reply to: [image:")
|
||||
assert "what is the image?" in handled[0]["content"]
|
||||
assert len(handled[0]["media"]) == 1
|
||||
assert "reply_photo_fid" in handled[0]["media"][0]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_message_reply_to_media_fallback_when_download_fails() -> None:
|
||||
"""When reply has media but download fails, no media attached and no reply tag."""
|
||||
channel = TelegramChannel(
|
||||
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"], group_policy="open"),
|
||||
MessageBus(),
|
||||
)
|
||||
channel._app = _FakeApp(lambda: None)
|
||||
channel._app.bot.get_file = None
|
||||
handled = []
|
||||
async def capture_handle(**kwargs) -> None:
|
||||
handled.append(kwargs)
|
||||
channel._handle_message = capture_handle
|
||||
channel._start_typing = lambda _chat_id: None
|
||||
|
||||
reply_with_photo = SimpleNamespace(
|
||||
text=None,
|
||||
caption=None,
|
||||
photo=[SimpleNamespace(file_id="x", mime_type="image/jpeg")],
|
||||
document=None,
|
||||
voice=None,
|
||||
audio=None,
|
||||
video=None,
|
||||
video_note=None,
|
||||
animation=None,
|
||||
)
|
||||
update = _make_telegram_update(text="what is this?", reply_to_message=reply_with_photo)
|
||||
await channel._on_message(update, None)
|
||||
|
||||
assert len(handled) == 1
|
||||
assert "what is this?" in handled[0]["content"]
|
||||
assert handled[0]["media"] == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_message_reply_to_caption_and_media(monkeypatch, tmp_path) -> None:
|
||||
"""When replying to a message with caption + photo, both text context and media are included."""
|
||||
media_dir = tmp_path / "media" / "telegram"
|
||||
media_dir.mkdir(parents=True)
|
||||
monkeypatch.setattr(
|
||||
"nanobot.channels.telegram.get_media_dir",
|
||||
lambda channel=None: media_dir if channel else tmp_path / "media",
|
||||
)
|
||||
|
||||
channel = TelegramChannel(
|
||||
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"], group_policy="open"),
|
||||
MessageBus(),
|
||||
)
|
||||
app = _FakeApp(lambda: None)
|
||||
app.bot.get_file = AsyncMock(
|
||||
return_value=SimpleNamespace(download_to_drive=AsyncMock(return_value=None))
|
||||
)
|
||||
channel._app = app
|
||||
handled = []
|
||||
async def capture_handle(**kwargs) -> None:
|
||||
handled.append(kwargs)
|
||||
channel._handle_message = capture_handle
|
||||
channel._start_typing = lambda _chat_id: None
|
||||
|
||||
reply_with_caption_and_photo = SimpleNamespace(
|
||||
text=None,
|
||||
caption="A cute cat",
|
||||
photo=[SimpleNamespace(file_id="cat_fid", mime_type="image/jpeg")],
|
||||
document=None,
|
||||
voice=None,
|
||||
audio=None,
|
||||
video=None,
|
||||
video_note=None,
|
||||
animation=None,
|
||||
)
|
||||
update = _make_telegram_update(
|
||||
text="what breed is this?",
|
||||
reply_to_message=reply_with_caption_and_photo,
|
||||
)
|
||||
await channel._on_message(update, None)
|
||||
|
||||
assert len(handled) == 1
|
||||
assert "[Reply to: A cute cat]" in handled[0]["content"]
|
||||
assert "what breed is this?" in handled[0]["content"]
|
||||
assert len(handled[0]["media"]) == 1
|
||||
assert "cat_fid" in handled[0]["media"][0]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_forward_command_does_not_inject_reply_context() -> None:
|
||||
"""Slash commands forwarded via _forward_command must not include reply context."""
|
||||
channel = TelegramChannel(
|
||||
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"], group_policy="open"),
|
||||
MessageBus(),
|
||||
)
|
||||
channel._app = _FakeApp(lambda: None)
|
||||
handled = []
|
||||
async def capture_handle(**kwargs) -> None:
|
||||
handled.append(kwargs)
|
||||
channel._handle_message = capture_handle
|
||||
|
||||
reply = SimpleNamespace(text="some old message", message_id=2, from_user=SimpleNamespace(id=1))
|
||||
update = _make_telegram_update(text="/new", reply_to_message=reply)
|
||||
await channel._forward_command(update, None)
|
||||
|
||||
assert len(handled) == 1
|
||||
assert handled[0]["content"] == "/new"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_help_includes_restart_command() -> None:
|
||||
channel = TelegramChannel(
|
||||
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"], group_policy="open"),
|
||||
MessageBus(),
|
||||
)
|
||||
update = _make_telegram_update(text="/help", chat_type="private")
|
||||
update.message.reply_text = AsyncMock()
|
||||
|
||||
await channel._on_help(update, None)
|
||||
|
||||
update.message.reply_text.assert_awaited_once()
|
||||
help_text = update.message.reply_text.await_args.args[0]
|
||||
assert "/restart" in help_text
|
||||
assert "/status" in help_text
|
||||
@@ -0,0 +1,280 @@
|
||||
import asyncio
|
||||
import json
|
||||
import tempfile
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.channels.weixin import (
|
||||
ITEM_IMAGE,
|
||||
ITEM_TEXT,
|
||||
MESSAGE_TYPE_BOT,
|
||||
WEIXIN_CHANNEL_VERSION,
|
||||
WeixinChannel,
|
||||
WeixinConfig,
|
||||
)
|
||||
|
||||
|
||||
def _make_channel() -> tuple[WeixinChannel, MessageBus]:
|
||||
bus = MessageBus()
|
||||
channel = WeixinChannel(
|
||||
WeixinConfig(
|
||||
enabled=True,
|
||||
allow_from=["*"],
|
||||
state_dir=tempfile.mkdtemp(prefix="nanobot-weixin-test-"),
|
||||
),
|
||||
bus,
|
||||
)
|
||||
return channel, bus
|
||||
|
||||
|
||||
def test_make_headers_includes_route_tag_when_configured() -> None:
|
||||
bus = MessageBus()
|
||||
channel = WeixinChannel(
|
||||
WeixinConfig(enabled=True, allow_from=["*"], route_tag=123),
|
||||
bus,
|
||||
)
|
||||
channel._token = "token"
|
||||
|
||||
headers = channel._make_headers()
|
||||
|
||||
assert headers["Authorization"] == "Bearer token"
|
||||
assert headers["SKRouteTag"] == "123"
|
||||
|
||||
|
||||
def test_channel_version_matches_reference_plugin_version() -> None:
|
||||
assert WEIXIN_CHANNEL_VERSION == "1.0.3"
|
||||
|
||||
|
||||
def test_save_and_load_state_persists_context_tokens(tmp_path) -> None:
|
||||
bus = MessageBus()
|
||||
channel = WeixinChannel(
|
||||
WeixinConfig(enabled=True, allow_from=["*"], state_dir=str(tmp_path)),
|
||||
bus,
|
||||
)
|
||||
channel._token = "token"
|
||||
channel._get_updates_buf = "cursor"
|
||||
channel._context_tokens = {"wx-user": "ctx-1"}
|
||||
|
||||
channel._save_state()
|
||||
|
||||
saved = json.loads((tmp_path / "account.json").read_text())
|
||||
assert saved["context_tokens"] == {"wx-user": "ctx-1"}
|
||||
|
||||
restored = WeixinChannel(
|
||||
WeixinConfig(enabled=True, allow_from=["*"], state_dir=str(tmp_path)),
|
||||
bus,
|
||||
)
|
||||
|
||||
assert restored._load_state() is True
|
||||
assert restored._context_tokens == {"wx-user": "ctx-1"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_message_deduplicates_inbound_ids() -> None:
|
||||
channel, bus = _make_channel()
|
||||
msg = {
|
||||
"message_type": 1,
|
||||
"message_id": "m1",
|
||||
"from_user_id": "wx-user",
|
||||
"context_token": "ctx-1",
|
||||
"item_list": [
|
||||
{"type": ITEM_TEXT, "text_item": {"text": "hello"}},
|
||||
],
|
||||
}
|
||||
|
||||
await channel._process_message(msg)
|
||||
first = await asyncio.wait_for(bus.consume_inbound(), timeout=1.0)
|
||||
await channel._process_message(msg)
|
||||
|
||||
assert first.sender_id == "wx-user"
|
||||
assert first.chat_id == "wx-user"
|
||||
assert first.content == "hello"
|
||||
assert bus.inbound_size == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_message_caches_context_token_and_send_uses_it() -> None:
|
||||
channel, _bus = _make_channel()
|
||||
channel._client = object()
|
||||
channel._token = "token"
|
||||
channel._send_text = AsyncMock()
|
||||
|
||||
await channel._process_message(
|
||||
{
|
||||
"message_type": 1,
|
||||
"message_id": "m2",
|
||||
"from_user_id": "wx-user",
|
||||
"context_token": "ctx-2",
|
||||
"item_list": [
|
||||
{"type": ITEM_TEXT, "text_item": {"text": "ping"}},
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
await channel.send(
|
||||
type("Msg", (), {"chat_id": "wx-user", "content": "pong", "media": [], "metadata": {}})()
|
||||
)
|
||||
|
||||
channel._send_text.assert_awaited_once_with("wx-user", "pong", "ctx-2")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_message_persists_context_token_to_state_file(tmp_path) -> None:
|
||||
bus = MessageBus()
|
||||
channel = WeixinChannel(
|
||||
WeixinConfig(enabled=True, allow_from=["*"], state_dir=str(tmp_path)),
|
||||
bus,
|
||||
)
|
||||
|
||||
await channel._process_message(
|
||||
{
|
||||
"message_type": 1,
|
||||
"message_id": "m2b",
|
||||
"from_user_id": "wx-user",
|
||||
"context_token": "ctx-2b",
|
||||
"item_list": [
|
||||
{"type": ITEM_TEXT, "text_item": {"text": "ping"}},
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
saved = json.loads((tmp_path / "account.json").read_text())
|
||||
assert saved["context_tokens"] == {"wx-user": "ctx-2b"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_message_extracts_media_and_preserves_paths() -> None:
|
||||
channel, bus = _make_channel()
|
||||
channel._download_media_item = AsyncMock(return_value="/tmp/test.jpg")
|
||||
|
||||
await channel._process_message(
|
||||
{
|
||||
"message_type": 1,
|
||||
"message_id": "m3",
|
||||
"from_user_id": "wx-user",
|
||||
"context_token": "ctx-3",
|
||||
"item_list": [
|
||||
{"type": ITEM_IMAGE, "image_item": {"media": {"encrypt_query_param": "x"}}},
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
inbound = await asyncio.wait_for(bus.consume_inbound(), timeout=1.0)
|
||||
|
||||
assert "[image]" in inbound.content
|
||||
assert "/tmp/test.jpg" in inbound.content
|
||||
assert inbound.media == ["/tmp/test.jpg"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_without_context_token_does_not_send_text() -> None:
|
||||
channel, _bus = _make_channel()
|
||||
channel._client = object()
|
||||
channel._token = "token"
|
||||
channel._send_text = AsyncMock()
|
||||
|
||||
await channel.send(
|
||||
type("Msg", (), {"chat_id": "unknown-user", "content": "pong", "media": [], "metadata": {}})()
|
||||
)
|
||||
|
||||
channel._send_text.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_does_not_send_when_session_is_paused() -> None:
|
||||
channel, _bus = _make_channel()
|
||||
channel._client = object()
|
||||
channel._token = "token"
|
||||
channel._context_tokens["wx-user"] = "ctx-2"
|
||||
channel._pause_session(60)
|
||||
channel._send_text = AsyncMock()
|
||||
|
||||
await channel.send(
|
||||
type("Msg", (), {"chat_id": "wx-user", "content": "pong", "media": [], "metadata": {}})()
|
||||
)
|
||||
|
||||
channel._send_text.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_poll_once_pauses_session_on_expired_errcode() -> None:
|
||||
channel, _bus = _make_channel()
|
||||
channel._client = SimpleNamespace(timeout=None)
|
||||
channel._token = "token"
|
||||
channel._api_post = AsyncMock(return_value={"ret": 0, "errcode": -14, "errmsg": "expired"})
|
||||
|
||||
await channel._poll_once()
|
||||
|
||||
assert channel._session_pause_remaining_s() > 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_qr_login_refreshes_expired_qr_and_then_succeeds() -> None:
|
||||
channel, _bus = _make_channel()
|
||||
channel._running = True
|
||||
channel._save_state = lambda: None
|
||||
channel._print_qr_code = lambda url: None
|
||||
channel._api_get = AsyncMock(
|
||||
side_effect=[
|
||||
{"qrcode": "qr-1", "qrcode_img_content": "url-1"},
|
||||
{"status": "expired"},
|
||||
{"qrcode": "qr-2", "qrcode_img_content": "url-2"},
|
||||
{
|
||||
"status": "confirmed",
|
||||
"bot_token": "token-2",
|
||||
"ilink_bot_id": "bot-2",
|
||||
"baseurl": "https://example.test",
|
||||
"ilink_user_id": "wx-user",
|
||||
},
|
||||
]
|
||||
)
|
||||
|
||||
ok = await channel._qr_login()
|
||||
|
||||
assert ok is True
|
||||
assert channel._token == "token-2"
|
||||
assert channel.config.base_url == "https://example.test"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_qr_login_returns_false_after_too_many_expired_qr_codes() -> None:
|
||||
channel, _bus = _make_channel()
|
||||
channel._running = True
|
||||
channel._print_qr_code = lambda url: None
|
||||
channel._api_get = AsyncMock(
|
||||
side_effect=[
|
||||
{"qrcode": "qr-1", "qrcode_img_content": "url-1"},
|
||||
{"status": "expired"},
|
||||
{"qrcode": "qr-2", "qrcode_img_content": "url-2"},
|
||||
{"status": "expired"},
|
||||
{"qrcode": "qr-3", "qrcode_img_content": "url-3"},
|
||||
{"status": "expired"},
|
||||
{"qrcode": "qr-4", "qrcode_img_content": "url-4"},
|
||||
{"status": "expired"},
|
||||
]
|
||||
)
|
||||
|
||||
ok = await channel._qr_login()
|
||||
|
||||
assert ok is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_message_skips_bot_messages() -> None:
|
||||
channel, bus = _make_channel()
|
||||
|
||||
await channel._process_message(
|
||||
{
|
||||
"message_type": MESSAGE_TYPE_BOT,
|
||||
"message_id": "m4",
|
||||
"from_user_id": "wx-user",
|
||||
"item_list": [
|
||||
{"type": ITEM_TEXT, "text_item": {"text": "hello"}},
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
assert bus.inbound_size == 0
|
||||
@@ -0,0 +1,157 @@
|
||||
"""Tests for WhatsApp channel outbound media support."""
|
||||
|
||||
import json
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.channels.whatsapp import WhatsAppChannel
|
||||
|
||||
|
||||
def _make_channel() -> WhatsAppChannel:
|
||||
bus = MagicMock()
|
||||
ch = WhatsAppChannel({"enabled": True}, bus)
|
||||
ch._ws = AsyncMock()
|
||||
ch._connected = True
|
||||
return ch
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_text_only():
|
||||
ch = _make_channel()
|
||||
msg = OutboundMessage(channel="whatsapp", chat_id="123@s.whatsapp.net", content="hello")
|
||||
|
||||
await ch.send(msg)
|
||||
|
||||
ch._ws.send.assert_called_once()
|
||||
payload = json.loads(ch._ws.send.call_args[0][0])
|
||||
assert payload["type"] == "send"
|
||||
assert payload["text"] == "hello"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_media_dispatches_send_media_command():
|
||||
ch = _make_channel()
|
||||
msg = OutboundMessage(
|
||||
channel="whatsapp",
|
||||
chat_id="123@s.whatsapp.net",
|
||||
content="check this out",
|
||||
media=["/tmp/photo.jpg"],
|
||||
)
|
||||
|
||||
await ch.send(msg)
|
||||
|
||||
assert ch._ws.send.call_count == 2
|
||||
text_payload = json.loads(ch._ws.send.call_args_list[0][0][0])
|
||||
media_payload = json.loads(ch._ws.send.call_args_list[1][0][0])
|
||||
|
||||
assert text_payload["type"] == "send"
|
||||
assert text_payload["text"] == "check this out"
|
||||
|
||||
assert media_payload["type"] == "send_media"
|
||||
assert media_payload["filePath"] == "/tmp/photo.jpg"
|
||||
assert media_payload["mimetype"] == "image/jpeg"
|
||||
assert media_payload["fileName"] == "photo.jpg"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_media_only_no_text():
|
||||
ch = _make_channel()
|
||||
msg = OutboundMessage(
|
||||
channel="whatsapp",
|
||||
chat_id="123@s.whatsapp.net",
|
||||
content="",
|
||||
media=["/tmp/doc.pdf"],
|
||||
)
|
||||
|
||||
await ch.send(msg)
|
||||
|
||||
ch._ws.send.assert_called_once()
|
||||
payload = json.loads(ch._ws.send.call_args[0][0])
|
||||
assert payload["type"] == "send_media"
|
||||
assert payload["mimetype"] == "application/pdf"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_multiple_media():
|
||||
ch = _make_channel()
|
||||
msg = OutboundMessage(
|
||||
channel="whatsapp",
|
||||
chat_id="123@s.whatsapp.net",
|
||||
content="",
|
||||
media=["/tmp/a.png", "/tmp/b.mp4"],
|
||||
)
|
||||
|
||||
await ch.send(msg)
|
||||
|
||||
assert ch._ws.send.call_count == 2
|
||||
p1 = json.loads(ch._ws.send.call_args_list[0][0][0])
|
||||
p2 = json.loads(ch._ws.send.call_args_list[1][0][0])
|
||||
assert p1["mimetype"] == "image/png"
|
||||
assert p2["mimetype"] == "video/mp4"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_when_disconnected_is_noop():
|
||||
ch = _make_channel()
|
||||
ch._connected = False
|
||||
|
||||
msg = OutboundMessage(
|
||||
channel="whatsapp",
|
||||
chat_id="123@s.whatsapp.net",
|
||||
content="hello",
|
||||
media=["/tmp/x.jpg"],
|
||||
)
|
||||
await ch.send(msg)
|
||||
|
||||
ch._ws.send.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_group_policy_mention_skips_unmentioned_group_message():
|
||||
ch = WhatsAppChannel({"enabled": True, "groupPolicy": "mention"}, MagicMock())
|
||||
ch._handle_message = AsyncMock()
|
||||
|
||||
await ch._handle_bridge_message(
|
||||
json.dumps(
|
||||
{
|
||||
"type": "message",
|
||||
"id": "m1",
|
||||
"sender": "12345@g.us",
|
||||
"pn": "user@s.whatsapp.net",
|
||||
"content": "hello group",
|
||||
"timestamp": 1,
|
||||
"isGroup": True,
|
||||
"wasMentioned": False,
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
ch._handle_message.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_group_policy_mention_accepts_mentioned_group_message():
|
||||
ch = WhatsAppChannel({"enabled": True, "groupPolicy": "mention"}, MagicMock())
|
||||
ch._handle_message = AsyncMock()
|
||||
|
||||
await ch._handle_bridge_message(
|
||||
json.dumps(
|
||||
{
|
||||
"type": "message",
|
||||
"id": "m1",
|
||||
"sender": "12345@g.us",
|
||||
"pn": "user@s.whatsapp.net",
|
||||
"content": "hello @bot",
|
||||
"timestamp": 1,
|
||||
"isGroup": True,
|
||||
"wasMentioned": True,
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
ch._handle_message.assert_awaited_once()
|
||||
kwargs = ch._handle_message.await_args.kwargs
|
||||
assert kwargs["chat_id"] == "12345@g.us"
|
||||
assert kwargs["sender_id"] == "user"
|
||||
@@ -0,0 +1,147 @@
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, MagicMock, call, patch
|
||||
|
||||
import pytest
|
||||
from prompt_toolkit.formatted_text import HTML
|
||||
|
||||
from nanobot.cli import commands
|
||||
from nanobot.cli import stream as stream_mod
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_prompt_session():
|
||||
"""Mock the global prompt session."""
|
||||
mock_session = MagicMock()
|
||||
mock_session.prompt_async = AsyncMock()
|
||||
with patch("nanobot.cli.commands._PROMPT_SESSION", mock_session), \
|
||||
patch("nanobot.cli.commands.patch_stdout"):
|
||||
yield mock_session
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_read_interactive_input_async_returns_input(mock_prompt_session):
|
||||
"""Test that _read_interactive_input_async returns the user input from prompt_session."""
|
||||
mock_prompt_session.prompt_async.return_value = "hello world"
|
||||
|
||||
result = await commands._read_interactive_input_async()
|
||||
|
||||
assert result == "hello world"
|
||||
mock_prompt_session.prompt_async.assert_called_once()
|
||||
args, _ = mock_prompt_session.prompt_async.call_args
|
||||
assert isinstance(args[0], HTML) # Verify HTML prompt is used
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_read_interactive_input_async_handles_eof(mock_prompt_session):
|
||||
"""Test that EOFError converts to KeyboardInterrupt."""
|
||||
mock_prompt_session.prompt_async.side_effect = EOFError()
|
||||
|
||||
with pytest.raises(KeyboardInterrupt):
|
||||
await commands._read_interactive_input_async()
|
||||
|
||||
|
||||
def test_init_prompt_session_creates_session():
|
||||
"""Test that _init_prompt_session initializes the global session."""
|
||||
# Ensure global is None before test
|
||||
commands._PROMPT_SESSION = None
|
||||
|
||||
with patch("nanobot.cli.commands.PromptSession") as MockSession, \
|
||||
patch("nanobot.cli.commands.FileHistory") as MockHistory, \
|
||||
patch("pathlib.Path.home") as mock_home:
|
||||
|
||||
mock_home.return_value = MagicMock()
|
||||
|
||||
commands._init_prompt_session()
|
||||
|
||||
assert commands._PROMPT_SESSION is not None
|
||||
MockSession.assert_called_once()
|
||||
_, kwargs = MockSession.call_args
|
||||
assert kwargs["multiline"] is False
|
||||
assert kwargs["enable_open_in_editor"] is False
|
||||
|
||||
|
||||
def test_thinking_spinner_pause_stops_and_restarts():
|
||||
"""Pause should stop the active spinner and restart it afterward."""
|
||||
spinner = MagicMock()
|
||||
mock_console = MagicMock()
|
||||
mock_console.status.return_value = spinner
|
||||
|
||||
thinking = stream_mod.ThinkingSpinner(console=mock_console)
|
||||
with thinking:
|
||||
with thinking.pause():
|
||||
pass
|
||||
|
||||
assert spinner.method_calls == [
|
||||
call.start(),
|
||||
call.stop(),
|
||||
call.start(),
|
||||
call.stop(),
|
||||
]
|
||||
|
||||
|
||||
def test_print_cli_progress_line_pauses_spinner_before_printing():
|
||||
"""CLI progress output should pause spinner to avoid garbled lines."""
|
||||
order: list[str] = []
|
||||
spinner = MagicMock()
|
||||
spinner.start.side_effect = lambda: order.append("start")
|
||||
spinner.stop.side_effect = lambda: order.append("stop")
|
||||
mock_console = MagicMock()
|
||||
mock_console.status.return_value = spinner
|
||||
|
||||
with patch.object(commands.console, "print", side_effect=lambda *_args, **_kwargs: order.append("print")):
|
||||
thinking = stream_mod.ThinkingSpinner(console=mock_console)
|
||||
with thinking:
|
||||
commands._print_cli_progress_line("tool running", thinking)
|
||||
|
||||
assert order == ["start", "stop", "print", "start", "stop"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_print_interactive_progress_line_pauses_spinner_before_printing():
|
||||
"""Interactive progress output should also pause spinner cleanly."""
|
||||
order: list[str] = []
|
||||
spinner = MagicMock()
|
||||
spinner.start.side_effect = lambda: order.append("start")
|
||||
spinner.stop.side_effect = lambda: order.append("stop")
|
||||
mock_console = MagicMock()
|
||||
mock_console.status.return_value = spinner
|
||||
|
||||
async def fake_print(_text: str) -> None:
|
||||
order.append("print")
|
||||
|
||||
with patch("nanobot.cli.commands._print_interactive_line", side_effect=fake_print):
|
||||
thinking = stream_mod.ThinkingSpinner(console=mock_console)
|
||||
with thinking:
|
||||
await commands._print_interactive_progress_line("tool running", thinking)
|
||||
|
||||
assert order == ["start", "stop", "print", "start", "stop"]
|
||||
|
||||
|
||||
def test_response_renderable_uses_text_for_explicit_plain_rendering():
|
||||
status = (
|
||||
"🐈 nanobot v0.1.4.post5\n"
|
||||
"🧠 Model: MiniMax-M2.7\n"
|
||||
"📊 Tokens: 20639 in / 29 out"
|
||||
)
|
||||
|
||||
renderable = commands._response_renderable(
|
||||
status,
|
||||
render_markdown=True,
|
||||
metadata={"render_as": "text"},
|
||||
)
|
||||
|
||||
assert renderable.__class__.__name__ == "Text"
|
||||
|
||||
|
||||
def test_response_renderable_preserves_normal_markdown_rendering():
|
||||
renderable = commands._response_renderable("**bold**", render_markdown=True)
|
||||
|
||||
assert renderable.__class__.__name__ == "Markdown"
|
||||
|
||||
|
||||
def test_response_renderable_without_metadata_keeps_markdown_path():
|
||||
help_text = "🐈 nanobot commands:\n/status — Show bot status\n/help — Show available commands"
|
||||
|
||||
renderable = commands._response_renderable(help_text, render_markdown=True)
|
||||
|
||||
assert renderable.__class__.__name__ == "Markdown"
|
||||
@@ -0,0 +1,905 @@
|
||||
import json
|
||||
import re
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from typer.testing import CliRunner
|
||||
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.cli.commands import _make_provider, app
|
||||
from nanobot.config.schema import Config
|
||||
from nanobot.providers.openai_codex_provider import _strip_model_prefix
|
||||
from nanobot.providers.registry import find_by_name
|
||||
|
||||
runner = CliRunner()
|
||||
|
||||
|
||||
class _StopGatewayError(RuntimeError):
|
||||
pass
|
||||
|
||||
|
||||
import shutil
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_paths():
|
||||
"""Mock config/workspace paths for test isolation."""
|
||||
with patch("nanobot.config.loader.get_config_path") as mock_cp, \
|
||||
patch("nanobot.config.loader.save_config") as mock_sc, \
|
||||
patch("nanobot.config.loader.load_config") as mock_lc, \
|
||||
patch("nanobot.cli.commands.get_workspace_path") as mock_ws:
|
||||
|
||||
base_dir = Path("./test_onboard_data")
|
||||
if base_dir.exists():
|
||||
shutil.rmtree(base_dir)
|
||||
base_dir.mkdir()
|
||||
|
||||
config_file = base_dir / "config.json"
|
||||
workspace_dir = base_dir / "workspace"
|
||||
|
||||
mock_cp.return_value = config_file
|
||||
mock_ws.return_value = workspace_dir
|
||||
mock_lc.side_effect = lambda _config_path=None: Config()
|
||||
|
||||
def _save_config(config: Config, config_path: Path | None = None):
|
||||
target = config_path or config_file
|
||||
target.parent.mkdir(parents=True, exist_ok=True)
|
||||
target.write_text(json.dumps(config.model_dump(by_alias=True)), encoding="utf-8")
|
||||
|
||||
mock_sc.side_effect = _save_config
|
||||
|
||||
yield config_file, workspace_dir, mock_ws
|
||||
|
||||
if base_dir.exists():
|
||||
shutil.rmtree(base_dir)
|
||||
|
||||
|
||||
def test_onboard_fresh_install(mock_paths):
|
||||
"""No existing config — should create from scratch."""
|
||||
config_file, workspace_dir, mock_ws = mock_paths
|
||||
|
||||
result = runner.invoke(app, ["onboard"])
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert "Created config" in result.stdout
|
||||
assert "Created workspace" in result.stdout
|
||||
assert "nanobot is ready" in result.stdout
|
||||
assert config_file.exists()
|
||||
assert (workspace_dir / "AGENTS.md").exists()
|
||||
assert (workspace_dir / "memory" / "MEMORY.md").exists()
|
||||
expected_workspace = Config().workspace_path
|
||||
assert mock_ws.call_args.args == (expected_workspace,)
|
||||
|
||||
|
||||
def test_onboard_existing_config_refresh(mock_paths):
|
||||
"""Config exists, user declines overwrite — should refresh (load-merge-save)."""
|
||||
config_file, workspace_dir, _ = mock_paths
|
||||
config_file.write_text('{"existing": true}')
|
||||
|
||||
result = runner.invoke(app, ["onboard"], input="n\n")
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert "Config already exists" in result.stdout
|
||||
assert "existing values preserved" in result.stdout
|
||||
assert workspace_dir.exists()
|
||||
assert (workspace_dir / "AGENTS.md").exists()
|
||||
|
||||
|
||||
def test_onboard_existing_config_overwrite(mock_paths):
|
||||
"""Config exists, user confirms overwrite — should reset to defaults."""
|
||||
config_file, workspace_dir, _ = mock_paths
|
||||
config_file.write_text('{"existing": true}')
|
||||
|
||||
result = runner.invoke(app, ["onboard"], input="y\n")
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert "Config already exists" in result.stdout
|
||||
assert "Config reset to defaults" in result.stdout
|
||||
assert workspace_dir.exists()
|
||||
|
||||
|
||||
def test_onboard_existing_workspace_safe_create(mock_paths):
|
||||
"""Workspace exists — should not recreate, but still add missing templates."""
|
||||
config_file, workspace_dir, _ = mock_paths
|
||||
workspace_dir.mkdir(parents=True)
|
||||
config_file.write_text("{}")
|
||||
|
||||
result = runner.invoke(app, ["onboard"], input="n\n")
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert "Created workspace" not in result.stdout
|
||||
assert "Created AGENTS.md" in result.stdout
|
||||
assert (workspace_dir / "AGENTS.md").exists()
|
||||
|
||||
|
||||
def _strip_ansi(text):
|
||||
"""Remove ANSI escape codes from text."""
|
||||
ansi_escape = re.compile(r'\x1b\[[0-9;]*m')
|
||||
return ansi_escape.sub('', text)
|
||||
|
||||
|
||||
def test_onboard_help_shows_workspace_and_config_options():
|
||||
result = runner.invoke(app, ["onboard", "--help"])
|
||||
|
||||
assert result.exit_code == 0
|
||||
stripped_output = _strip_ansi(result.stdout)
|
||||
assert "--workspace" in stripped_output
|
||||
assert "-w" in stripped_output
|
||||
assert "--config" in stripped_output
|
||||
assert "-c" in stripped_output
|
||||
assert "--wizard" in stripped_output
|
||||
assert "--dir" not in stripped_output
|
||||
|
||||
|
||||
def test_onboard_interactive_discard_does_not_save_or_create_workspace(mock_paths, monkeypatch):
|
||||
config_file, workspace_dir, _ = mock_paths
|
||||
|
||||
from nanobot.cli.onboard import OnboardResult
|
||||
|
||||
monkeypatch.setattr(
|
||||
"nanobot.cli.onboard.run_onboard",
|
||||
lambda initial_config: OnboardResult(config=initial_config, should_save=False),
|
||||
)
|
||||
|
||||
result = runner.invoke(app, ["onboard", "--wizard"])
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert "No changes were saved" in result.stdout
|
||||
assert not config_file.exists()
|
||||
assert not workspace_dir.exists()
|
||||
|
||||
|
||||
def test_onboard_uses_explicit_config_and_workspace_paths(tmp_path, monkeypatch):
|
||||
config_path = tmp_path / "instance" / "config.json"
|
||||
workspace_path = tmp_path / "workspace"
|
||||
|
||||
monkeypatch.setattr("nanobot.channels.registry.discover_all", lambda: {})
|
||||
|
||||
result = runner.invoke(
|
||||
app,
|
||||
["onboard", "--config", str(config_path), "--workspace", str(workspace_path)],
|
||||
)
|
||||
|
||||
assert result.exit_code == 0
|
||||
saved = Config.model_validate(json.loads(config_path.read_text(encoding="utf-8")))
|
||||
assert saved.workspace_path == workspace_path
|
||||
assert (workspace_path / "AGENTS.md").exists()
|
||||
stripped_output = _strip_ansi(result.stdout)
|
||||
compact_output = stripped_output.replace("\n", "")
|
||||
resolved_config = str(config_path.resolve())
|
||||
assert resolved_config in compact_output
|
||||
assert f"--config {resolved_config}" in compact_output
|
||||
|
||||
|
||||
def test_onboard_wizard_preserves_explicit_config_in_next_steps(tmp_path, monkeypatch):
|
||||
config_path = tmp_path / "instance" / "config.json"
|
||||
workspace_path = tmp_path / "workspace"
|
||||
|
||||
from nanobot.cli.onboard import OnboardResult
|
||||
|
||||
monkeypatch.setattr(
|
||||
"nanobot.cli.onboard.run_onboard",
|
||||
lambda initial_config: OnboardResult(config=initial_config, should_save=True),
|
||||
)
|
||||
monkeypatch.setattr("nanobot.channels.registry.discover_all", lambda: {})
|
||||
|
||||
result = runner.invoke(
|
||||
app,
|
||||
["onboard", "--wizard", "--config", str(config_path), "--workspace", str(workspace_path)],
|
||||
)
|
||||
|
||||
assert result.exit_code == 0
|
||||
stripped_output = _strip_ansi(result.stdout)
|
||||
compact_output = stripped_output.replace("\n", "")
|
||||
resolved_config = str(config_path.resolve())
|
||||
assert f'nanobot agent -m "Hello!" --config {resolved_config}' in compact_output
|
||||
assert f"nanobot gateway --config {resolved_config}" in compact_output
|
||||
|
||||
|
||||
def test_config_matches_github_copilot_codex_with_hyphen_prefix():
|
||||
config = Config()
|
||||
config.agents.defaults.model = "github-copilot/gpt-5.3-codex"
|
||||
|
||||
assert config.get_provider_name() == "github_copilot"
|
||||
|
||||
|
||||
def test_config_matches_openai_codex_with_hyphen_prefix():
|
||||
config = Config()
|
||||
config.agents.defaults.model = "openai-codex/gpt-5.1-codex"
|
||||
|
||||
assert config.get_provider_name() == "openai_codex"
|
||||
|
||||
|
||||
def test_config_dump_excludes_oauth_provider_blocks():
|
||||
config = Config()
|
||||
|
||||
providers = config.model_dump(by_alias=True)["providers"]
|
||||
|
||||
assert "openaiCodex" not in providers
|
||||
assert "githubCopilot" not in providers
|
||||
|
||||
|
||||
def test_config_matches_explicit_ollama_prefix_without_api_key():
|
||||
config = Config()
|
||||
config.agents.defaults.model = "ollama/llama3.2"
|
||||
|
||||
assert config.get_provider_name() == "ollama"
|
||||
assert config.get_api_base() == "http://localhost:11434/v1"
|
||||
|
||||
|
||||
def test_config_explicit_ollama_provider_uses_default_localhost_api_base():
|
||||
config = Config()
|
||||
config.agents.defaults.provider = "ollama"
|
||||
config.agents.defaults.model = "llama3.2"
|
||||
|
||||
assert config.get_provider_name() == "ollama"
|
||||
assert config.get_api_base() == "http://localhost:11434/v1"
|
||||
|
||||
|
||||
def test_config_accepts_camel_case_explicit_provider_name_for_coding_plan():
|
||||
config = Config.model_validate(
|
||||
{
|
||||
"agents": {
|
||||
"defaults": {
|
||||
"provider": "volcengineCodingPlan",
|
||||
"model": "doubao-1-5-pro",
|
||||
}
|
||||
},
|
||||
"providers": {
|
||||
"volcengineCodingPlan": {
|
||||
"apiKey": "test-key",
|
||||
}
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
assert config.get_provider_name() == "volcengine_coding_plan"
|
||||
assert config.get_api_base() == "https://ark.cn-beijing.volces.com/api/coding/v3"
|
||||
|
||||
|
||||
def test_find_by_name_accepts_camel_case_and_hyphen_aliases():
|
||||
assert find_by_name("volcengineCodingPlan") is not None
|
||||
assert find_by_name("volcengineCodingPlan").name == "volcengine_coding_plan"
|
||||
assert find_by_name("github-copilot") is not None
|
||||
assert find_by_name("github-copilot").name == "github_copilot"
|
||||
|
||||
|
||||
def test_config_auto_detects_ollama_from_local_api_base():
|
||||
config = Config.model_validate(
|
||||
{
|
||||
"agents": {"defaults": {"provider": "auto", "model": "llama3.2"}},
|
||||
"providers": {"ollama": {"apiBase": "http://localhost:11434/v1"}},
|
||||
}
|
||||
)
|
||||
|
||||
assert config.get_provider_name() == "ollama"
|
||||
assert config.get_api_base() == "http://localhost:11434/v1"
|
||||
|
||||
|
||||
def test_config_prefers_ollama_over_vllm_when_both_local_providers_configured():
|
||||
config = Config.model_validate(
|
||||
{
|
||||
"agents": {"defaults": {"provider": "auto", "model": "llama3.2"}},
|
||||
"providers": {
|
||||
"vllm": {"apiBase": "http://localhost:8000"},
|
||||
"ollama": {"apiBase": "http://localhost:11434/v1"},
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
assert config.get_provider_name() == "ollama"
|
||||
assert config.get_api_base() == "http://localhost:11434/v1"
|
||||
|
||||
|
||||
def test_config_falls_back_to_vllm_when_ollama_not_configured():
|
||||
config = Config.model_validate(
|
||||
{
|
||||
"agents": {"defaults": {"provider": "auto", "model": "llama3.2"}},
|
||||
"providers": {
|
||||
"vllm": {"apiBase": "http://localhost:8000"},
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
assert config.get_provider_name() == "vllm"
|
||||
assert config.get_api_base() == "http://localhost:8000"
|
||||
|
||||
|
||||
def test_openai_compat_provider_passes_model_through():
|
||||
from nanobot.providers.openai_compat_provider import OpenAICompatProvider
|
||||
|
||||
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
|
||||
provider = OpenAICompatProvider(default_model="github-copilot/gpt-5.3-codex")
|
||||
|
||||
assert provider.get_default_model() == "github-copilot/gpt-5.3-codex"
|
||||
|
||||
|
||||
def test_openai_codex_strip_prefix_supports_hyphen_and_underscore():
|
||||
assert _strip_model_prefix("openai-codex/gpt-5.1-codex") == "gpt-5.1-codex"
|
||||
assert _strip_model_prefix("openai_codex/gpt-5.1-codex") == "gpt-5.1-codex"
|
||||
|
||||
|
||||
def test_make_provider_passes_extra_headers_to_custom_provider():
|
||||
config = Config.model_validate(
|
||||
{
|
||||
"agents": {"defaults": {"provider": "custom", "model": "gpt-4o-mini"}},
|
||||
"providers": {
|
||||
"custom": {
|
||||
"apiKey": "test-key",
|
||||
"apiBase": "https://example.com/v1",
|
||||
"extraHeaders": {
|
||||
"APP-Code": "demo-app",
|
||||
"x-session-affinity": "sticky-session",
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI") as mock_async_openai:
|
||||
_make_provider(config)
|
||||
|
||||
kwargs = mock_async_openai.call_args.kwargs
|
||||
assert kwargs["api_key"] == "test-key"
|
||||
assert kwargs["base_url"] == "https://example.com/v1"
|
||||
assert kwargs["default_headers"]["APP-Code"] == "demo-app"
|
||||
assert kwargs["default_headers"]["x-session-affinity"] == "sticky-session"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_agent_runtime(tmp_path):
|
||||
"""Mock agent command dependencies for focused CLI tests."""
|
||||
config = Config()
|
||||
config.agents.defaults.workspace = str(tmp_path / "default-workspace")
|
||||
|
||||
with patch("nanobot.config.loader.load_config", return_value=config) as mock_load_config, \
|
||||
patch("nanobot.cli.commands.sync_workspace_templates") as mock_sync_templates, \
|
||||
patch("nanobot.cli.commands._make_provider", return_value=object()), \
|
||||
patch("nanobot.cli.commands._print_agent_response") as mock_print_response, \
|
||||
patch("nanobot.bus.queue.MessageBus"), \
|
||||
patch("nanobot.cron.service.CronService"), \
|
||||
patch("nanobot.agent.loop.AgentLoop") as mock_agent_loop_cls:
|
||||
|
||||
agent_loop = MagicMock()
|
||||
agent_loop.channels_config = None
|
||||
agent_loop.process_direct = AsyncMock(
|
||||
return_value=OutboundMessage(channel="cli", chat_id="direct", content="mock-response"),
|
||||
)
|
||||
agent_loop.close_mcp = AsyncMock(return_value=None)
|
||||
mock_agent_loop_cls.return_value = agent_loop
|
||||
|
||||
yield {
|
||||
"config": config,
|
||||
"load_config": mock_load_config,
|
||||
"sync_templates": mock_sync_templates,
|
||||
"agent_loop_cls": mock_agent_loop_cls,
|
||||
"agent_loop": agent_loop,
|
||||
"print_response": mock_print_response,
|
||||
}
|
||||
|
||||
|
||||
def test_agent_help_shows_workspace_and_config_options():
|
||||
result = runner.invoke(app, ["agent", "--help"])
|
||||
|
||||
assert result.exit_code == 0
|
||||
stripped_output = _strip_ansi(result.stdout)
|
||||
assert "--workspace" in stripped_output
|
||||
assert "-w" in stripped_output
|
||||
assert "--config" in stripped_output
|
||||
assert "-c" in stripped_output
|
||||
|
||||
|
||||
def test_agent_uses_default_config_when_no_workspace_or_config_flags(mock_agent_runtime):
|
||||
result = runner.invoke(app, ["agent", "-m", "hello"])
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert mock_agent_runtime["load_config"].call_args.args == (None,)
|
||||
assert mock_agent_runtime["sync_templates"].call_args.args == (
|
||||
mock_agent_runtime["config"].workspace_path,
|
||||
)
|
||||
assert mock_agent_runtime["agent_loop_cls"].call_args.kwargs["workspace"] == (
|
||||
mock_agent_runtime["config"].workspace_path
|
||||
)
|
||||
mock_agent_runtime["agent_loop"].process_direct.assert_awaited_once()
|
||||
mock_agent_runtime["print_response"].assert_called_once_with(
|
||||
"mock-response", render_markdown=True, metadata={},
|
||||
)
|
||||
|
||||
|
||||
def test_agent_uses_explicit_config_path(mock_agent_runtime, tmp_path: Path):
|
||||
config_path = tmp_path / "agent-config.json"
|
||||
config_path.write_text("{}")
|
||||
|
||||
result = runner.invoke(app, ["agent", "-m", "hello", "-c", str(config_path)])
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert mock_agent_runtime["load_config"].call_args.args == (config_path.resolve(),)
|
||||
|
||||
|
||||
def test_agent_config_sets_active_path(monkeypatch, tmp_path: Path) -> None:
|
||||
config_file = tmp_path / "instance" / "config.json"
|
||||
config_file.parent.mkdir(parents=True)
|
||||
config_file.write_text("{}")
|
||||
|
||||
config = Config()
|
||||
seen: dict[str, Path] = {}
|
||||
|
||||
monkeypatch.setattr(
|
||||
"nanobot.config.loader.set_config_path",
|
||||
lambda path: seen.__setitem__("config_path", path),
|
||||
)
|
||||
monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
|
||||
monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None)
|
||||
monkeypatch.setattr("nanobot.cli.commands._make_provider", lambda _config: object())
|
||||
monkeypatch.setattr("nanobot.bus.queue.MessageBus", lambda: object())
|
||||
monkeypatch.setattr("nanobot.cron.service.CronService", lambda _store: object())
|
||||
|
||||
class _FakeAgentLoop:
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
pass
|
||||
|
||||
async def process_direct(self, *_args, **_kwargs):
|
||||
return OutboundMessage(channel="cli", chat_id="direct", content="ok")
|
||||
|
||||
async def close_mcp(self) -> None:
|
||||
return None
|
||||
|
||||
monkeypatch.setattr("nanobot.agent.loop.AgentLoop", _FakeAgentLoop)
|
||||
monkeypatch.setattr("nanobot.cli.commands._print_agent_response", lambda *_args, **_kwargs: None)
|
||||
|
||||
result = runner.invoke(app, ["agent", "-m", "hello", "-c", str(config_file)])
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert seen["config_path"] == config_file.resolve()
|
||||
|
||||
|
||||
def test_agent_uses_workspace_directory_for_cron_store(monkeypatch, tmp_path: Path) -> None:
|
||||
config_file = tmp_path / "instance" / "config.json"
|
||||
config_file.parent.mkdir(parents=True)
|
||||
config_file.write_text("{}")
|
||||
|
||||
config = Config()
|
||||
config.agents.defaults.workspace = str(tmp_path / "agent-workspace")
|
||||
seen: dict[str, Path] = {}
|
||||
|
||||
monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None)
|
||||
monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
|
||||
monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None)
|
||||
monkeypatch.setattr("nanobot.cli.commands._make_provider", lambda _config: object())
|
||||
monkeypatch.setattr("nanobot.bus.queue.MessageBus", lambda: object())
|
||||
|
||||
class _FakeCron:
|
||||
def __init__(self, store_path: Path) -> None:
|
||||
seen["cron_store"] = store_path
|
||||
|
||||
class _FakeAgentLoop:
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
pass
|
||||
|
||||
async def process_direct(self, *_args, **_kwargs):
|
||||
return OutboundMessage(channel="cli", chat_id="direct", content="ok")
|
||||
|
||||
async def close_mcp(self) -> None:
|
||||
return None
|
||||
|
||||
monkeypatch.setattr("nanobot.cron.service.CronService", _FakeCron)
|
||||
monkeypatch.setattr("nanobot.agent.loop.AgentLoop", _FakeAgentLoop)
|
||||
monkeypatch.setattr("nanobot.cli.commands._print_agent_response", lambda *_args, **_kwargs: None)
|
||||
|
||||
result = runner.invoke(app, ["agent", "-m", "hello", "-c", str(config_file)])
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert seen["cron_store"] == config.workspace_path / "cron" / "jobs.json"
|
||||
|
||||
|
||||
def test_agent_workspace_override_does_not_migrate_legacy_cron(
|
||||
monkeypatch, tmp_path: Path
|
||||
) -> None:
|
||||
config_file = tmp_path / "instance" / "config.json"
|
||||
config_file.parent.mkdir(parents=True)
|
||||
config_file.write_text("{}")
|
||||
|
||||
legacy_dir = tmp_path / "global" / "cron"
|
||||
legacy_dir.mkdir(parents=True)
|
||||
legacy_file = legacy_dir / "jobs.json"
|
||||
legacy_file.write_text('{"jobs": []}')
|
||||
|
||||
override = tmp_path / "override-workspace"
|
||||
config = Config()
|
||||
seen: dict[str, Path] = {}
|
||||
|
||||
monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None)
|
||||
monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
|
||||
monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None)
|
||||
monkeypatch.setattr("nanobot.cli.commands._make_provider", lambda _config: object())
|
||||
monkeypatch.setattr("nanobot.bus.queue.MessageBus", lambda: object())
|
||||
monkeypatch.setattr("nanobot.config.paths.get_cron_dir", lambda: legacy_dir)
|
||||
|
||||
class _FakeCron:
|
||||
def __init__(self, store_path: Path) -> None:
|
||||
seen["cron_store"] = store_path
|
||||
|
||||
class _FakeAgentLoop:
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
pass
|
||||
|
||||
async def process_direct(self, *_args, **_kwargs):
|
||||
return OutboundMessage(channel="cli", chat_id="direct", content="ok")
|
||||
|
||||
async def close_mcp(self) -> None:
|
||||
return None
|
||||
|
||||
monkeypatch.setattr("nanobot.cron.service.CronService", _FakeCron)
|
||||
monkeypatch.setattr("nanobot.agent.loop.AgentLoop", _FakeAgentLoop)
|
||||
monkeypatch.setattr("nanobot.cli.commands._print_agent_response", lambda *_args, **_kwargs: None)
|
||||
|
||||
result = runner.invoke(
|
||||
app,
|
||||
["agent", "-m", "hello", "-c", str(config_file), "-w", str(override)],
|
||||
)
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert seen["cron_store"] == override / "cron" / "jobs.json"
|
||||
assert legacy_file.exists()
|
||||
assert not (override / "cron" / "jobs.json").exists()
|
||||
|
||||
|
||||
def test_agent_custom_config_workspace_does_not_migrate_legacy_cron(
|
||||
monkeypatch, tmp_path: Path
|
||||
) -> None:
|
||||
config_file = tmp_path / "instance" / "config.json"
|
||||
config_file.parent.mkdir(parents=True)
|
||||
config_file.write_text("{}")
|
||||
|
||||
legacy_dir = tmp_path / "global" / "cron"
|
||||
legacy_dir.mkdir(parents=True)
|
||||
legacy_file = legacy_dir / "jobs.json"
|
||||
legacy_file.write_text('{"jobs": []}')
|
||||
|
||||
custom_workspace = tmp_path / "custom-workspace"
|
||||
config = Config()
|
||||
config.agents.defaults.workspace = str(custom_workspace)
|
||||
seen: dict[str, Path] = {}
|
||||
|
||||
monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None)
|
||||
monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
|
||||
monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None)
|
||||
monkeypatch.setattr("nanobot.cli.commands._make_provider", lambda _config: object())
|
||||
monkeypatch.setattr("nanobot.bus.queue.MessageBus", lambda: object())
|
||||
monkeypatch.setattr("nanobot.config.paths.get_cron_dir", lambda: legacy_dir)
|
||||
|
||||
class _FakeCron:
|
||||
def __init__(self, store_path: Path) -> None:
|
||||
seen["cron_store"] = store_path
|
||||
|
||||
class _FakeAgentLoop:
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
pass
|
||||
|
||||
async def process_direct(self, *_args, **_kwargs):
|
||||
return OutboundMessage(channel="cli", chat_id="direct", content="ok")
|
||||
|
||||
async def close_mcp(self) -> None:
|
||||
return None
|
||||
|
||||
monkeypatch.setattr("nanobot.cron.service.CronService", _FakeCron)
|
||||
monkeypatch.setattr("nanobot.agent.loop.AgentLoop", _FakeAgentLoop)
|
||||
monkeypatch.setattr("nanobot.cli.commands._print_agent_response", lambda *_args, **_kwargs: None)
|
||||
|
||||
result = runner.invoke(app, ["agent", "-m", "hello", "-c", str(config_file)])
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert seen["cron_store"] == custom_workspace / "cron" / "jobs.json"
|
||||
assert legacy_file.exists()
|
||||
assert not (custom_workspace / "cron" / "jobs.json").exists()
|
||||
|
||||
|
||||
def test_agent_overrides_workspace_path(mock_agent_runtime):
|
||||
workspace_path = Path("/tmp/agent-workspace")
|
||||
|
||||
result = runner.invoke(app, ["agent", "-m", "hello", "-w", str(workspace_path)])
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert mock_agent_runtime["config"].agents.defaults.workspace == str(workspace_path)
|
||||
assert mock_agent_runtime["sync_templates"].call_args.args == (workspace_path,)
|
||||
assert mock_agent_runtime["agent_loop_cls"].call_args.kwargs["workspace"] == workspace_path
|
||||
|
||||
|
||||
def test_agent_workspace_override_wins_over_config_workspace(mock_agent_runtime, tmp_path: Path):
|
||||
config_path = tmp_path / "agent-config.json"
|
||||
config_path.write_text("{}")
|
||||
workspace_path = Path("/tmp/agent-workspace")
|
||||
|
||||
result = runner.invoke(
|
||||
app,
|
||||
["agent", "-m", "hello", "-c", str(config_path), "-w", str(workspace_path)],
|
||||
)
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert mock_agent_runtime["load_config"].call_args.args == (config_path.resolve(),)
|
||||
assert mock_agent_runtime["config"].agents.defaults.workspace == str(workspace_path)
|
||||
assert mock_agent_runtime["sync_templates"].call_args.args == (workspace_path,)
|
||||
assert mock_agent_runtime["agent_loop_cls"].call_args.kwargs["workspace"] == workspace_path
|
||||
|
||||
|
||||
def test_agent_hints_about_deprecated_memory_window(mock_agent_runtime, tmp_path):
|
||||
config_file = tmp_path / "config.json"
|
||||
config_file.write_text(json.dumps({"agents": {"defaults": {"memoryWindow": 42}}}))
|
||||
|
||||
result = runner.invoke(app, ["agent", "-m", "hello", "-c", str(config_file)])
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert "memoryWindow" in result.stdout
|
||||
assert "no longer used" in result.stdout
|
||||
|
||||
|
||||
def test_heartbeat_retains_recent_messages_by_default():
|
||||
config = Config()
|
||||
|
||||
assert config.gateway.heartbeat.keep_recent_messages == 8
|
||||
|
||||
|
||||
def test_gateway_uses_workspace_from_config_by_default(monkeypatch, tmp_path: Path) -> None:
|
||||
config_file = tmp_path / "instance" / "config.json"
|
||||
config_file.parent.mkdir(parents=True)
|
||||
config_file.write_text("{}")
|
||||
|
||||
config = Config()
|
||||
config.agents.defaults.workspace = str(tmp_path / "config-workspace")
|
||||
seen: dict[str, Path] = {}
|
||||
|
||||
monkeypatch.setattr(
|
||||
"nanobot.config.loader.set_config_path",
|
||||
lambda path: seen.__setitem__("config_path", path),
|
||||
)
|
||||
monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
|
||||
monkeypatch.setattr(
|
||||
"nanobot.cli.commands.sync_workspace_templates",
|
||||
lambda path: seen.__setitem__("workspace", path),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"nanobot.cli.commands._make_provider",
|
||||
lambda _config: (_ for _ in ()).throw(_StopGatewayError("stop")),
|
||||
)
|
||||
|
||||
result = runner.invoke(app, ["gateway", "--config", str(config_file)])
|
||||
|
||||
assert isinstance(result.exception, _StopGatewayError)
|
||||
assert seen["config_path"] == config_file.resolve()
|
||||
assert seen["workspace"] == Path(config.agents.defaults.workspace)
|
||||
|
||||
|
||||
def test_gateway_workspace_option_overrides_config(monkeypatch, tmp_path: Path) -> None:
|
||||
config_file = tmp_path / "instance" / "config.json"
|
||||
config_file.parent.mkdir(parents=True)
|
||||
config_file.write_text("{}")
|
||||
|
||||
config = Config()
|
||||
config.agents.defaults.workspace = str(tmp_path / "config-workspace")
|
||||
override = tmp_path / "override-workspace"
|
||||
seen: dict[str, Path] = {}
|
||||
|
||||
monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None)
|
||||
monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
|
||||
monkeypatch.setattr(
|
||||
"nanobot.cli.commands.sync_workspace_templates",
|
||||
lambda path: seen.__setitem__("workspace", path),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"nanobot.cli.commands._make_provider",
|
||||
lambda _config: (_ for _ in ()).throw(_StopGatewayError("stop")),
|
||||
)
|
||||
|
||||
result = runner.invoke(
|
||||
app,
|
||||
["gateway", "--config", str(config_file), "--workspace", str(override)],
|
||||
)
|
||||
|
||||
assert isinstance(result.exception, _StopGatewayError)
|
||||
assert seen["workspace"] == override
|
||||
assert config.workspace_path == override
|
||||
|
||||
|
||||
def test_gateway_uses_workspace_directory_for_cron_store(monkeypatch, tmp_path: Path) -> None:
|
||||
config_file = tmp_path / "instance" / "config.json"
|
||||
config_file.parent.mkdir(parents=True)
|
||||
config_file.write_text("{}")
|
||||
|
||||
config = Config()
|
||||
config.agents.defaults.workspace = str(tmp_path / "config-workspace")
|
||||
seen: dict[str, Path] = {}
|
||||
|
||||
monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None)
|
||||
monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
|
||||
monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None)
|
||||
monkeypatch.setattr("nanobot.cli.commands._make_provider", lambda _config: object())
|
||||
monkeypatch.setattr("nanobot.bus.queue.MessageBus", lambda: object())
|
||||
monkeypatch.setattr("nanobot.session.manager.SessionManager", lambda _workspace: object())
|
||||
|
||||
class _StopCron:
|
||||
def __init__(self, store_path: Path) -> None:
|
||||
seen["cron_store"] = store_path
|
||||
raise _StopGatewayError("stop")
|
||||
|
||||
monkeypatch.setattr("nanobot.cron.service.CronService", _StopCron)
|
||||
|
||||
result = runner.invoke(app, ["gateway", "--config", str(config_file)])
|
||||
|
||||
assert isinstance(result.exception, _StopGatewayError)
|
||||
assert seen["cron_store"] == config.workspace_path / "cron" / "jobs.json"
|
||||
|
||||
|
||||
def test_gateway_workspace_override_does_not_migrate_legacy_cron(
|
||||
monkeypatch, tmp_path: Path
|
||||
) -> None:
|
||||
config_file = tmp_path / "instance" / "config.json"
|
||||
config_file.parent.mkdir(parents=True)
|
||||
config_file.write_text("{}")
|
||||
|
||||
legacy_dir = tmp_path / "global" / "cron"
|
||||
legacy_dir.mkdir(parents=True)
|
||||
legacy_file = legacy_dir / "jobs.json"
|
||||
legacy_file.write_text('{"jobs": []}')
|
||||
|
||||
override = tmp_path / "override-workspace"
|
||||
config = Config()
|
||||
seen: dict[str, Path] = {}
|
||||
|
||||
monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None)
|
||||
monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
|
||||
monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None)
|
||||
monkeypatch.setattr("nanobot.cli.commands._make_provider", lambda _config: object())
|
||||
monkeypatch.setattr("nanobot.bus.queue.MessageBus", lambda: object())
|
||||
monkeypatch.setattr("nanobot.session.manager.SessionManager", lambda _workspace: object())
|
||||
monkeypatch.setattr("nanobot.config.paths.get_cron_dir", lambda: legacy_dir)
|
||||
|
||||
class _StopCron:
|
||||
def __init__(self, store_path: Path) -> None:
|
||||
seen["cron_store"] = store_path
|
||||
raise _StopGatewayError("stop")
|
||||
|
||||
monkeypatch.setattr("nanobot.cron.service.CronService", _StopCron)
|
||||
|
||||
result = runner.invoke(
|
||||
app,
|
||||
["gateway", "--config", str(config_file), "--workspace", str(override)],
|
||||
)
|
||||
|
||||
assert isinstance(result.exception, _StopGatewayError)
|
||||
assert seen["cron_store"] == override / "cron" / "jobs.json"
|
||||
assert legacy_file.exists()
|
||||
assert not (override / "cron" / "jobs.json").exists()
|
||||
|
||||
|
||||
def test_gateway_custom_config_workspace_does_not_migrate_legacy_cron(
|
||||
monkeypatch, tmp_path: Path
|
||||
) -> None:
|
||||
config_file = tmp_path / "instance" / "config.json"
|
||||
config_file.parent.mkdir(parents=True)
|
||||
config_file.write_text("{}")
|
||||
|
||||
legacy_dir = tmp_path / "global" / "cron"
|
||||
legacy_dir.mkdir(parents=True)
|
||||
legacy_file = legacy_dir / "jobs.json"
|
||||
legacy_file.write_text('{"jobs": []}')
|
||||
|
||||
custom_workspace = tmp_path / "custom-workspace"
|
||||
config = Config()
|
||||
config.agents.defaults.workspace = str(custom_workspace)
|
||||
seen: dict[str, Path] = {}
|
||||
|
||||
monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None)
|
||||
monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
|
||||
monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None)
|
||||
monkeypatch.setattr("nanobot.cli.commands._make_provider", lambda _config: object())
|
||||
monkeypatch.setattr("nanobot.bus.queue.MessageBus", lambda: object())
|
||||
monkeypatch.setattr("nanobot.session.manager.SessionManager", lambda _workspace: object())
|
||||
monkeypatch.setattr("nanobot.config.paths.get_cron_dir", lambda: legacy_dir)
|
||||
|
||||
class _StopCron:
|
||||
def __init__(self, store_path: Path) -> None:
|
||||
seen["cron_store"] = store_path
|
||||
raise _StopGatewayError("stop")
|
||||
|
||||
monkeypatch.setattr("nanobot.cron.service.CronService", _StopCron)
|
||||
|
||||
result = runner.invoke(app, ["gateway", "--config", str(config_file)])
|
||||
|
||||
assert isinstance(result.exception, _StopGatewayError)
|
||||
assert seen["cron_store"] == custom_workspace / "cron" / "jobs.json"
|
||||
assert legacy_file.exists()
|
||||
assert not (custom_workspace / "cron" / "jobs.json").exists()
|
||||
|
||||
|
||||
def test_migrate_cron_store_moves_legacy_file(tmp_path: Path) -> None:
|
||||
"""Legacy global jobs.json is moved into the workspace on first run."""
|
||||
from nanobot.cli.commands import _migrate_cron_store
|
||||
|
||||
legacy_dir = tmp_path / "global" / "cron"
|
||||
legacy_dir.mkdir(parents=True)
|
||||
legacy_file = legacy_dir / "jobs.json"
|
||||
legacy_file.write_text('{"jobs": []}')
|
||||
|
||||
config = Config()
|
||||
config.agents.defaults.workspace = str(tmp_path / "workspace")
|
||||
workspace_cron = config.workspace_path / "cron" / "jobs.json"
|
||||
|
||||
with patch("nanobot.config.paths.get_cron_dir", return_value=legacy_dir):
|
||||
_migrate_cron_store(config)
|
||||
|
||||
assert workspace_cron.exists()
|
||||
assert workspace_cron.read_text() == '{"jobs": []}'
|
||||
assert not legacy_file.exists()
|
||||
|
||||
|
||||
def test_migrate_cron_store_skips_when_workspace_file_exists(tmp_path: Path) -> None:
|
||||
"""Migration does not overwrite an existing workspace cron store."""
|
||||
from nanobot.cli.commands import _migrate_cron_store
|
||||
|
||||
legacy_dir = tmp_path / "global" / "cron"
|
||||
legacy_dir.mkdir(parents=True)
|
||||
(legacy_dir / "jobs.json").write_text('{"old": true}')
|
||||
|
||||
config = Config()
|
||||
config.agents.defaults.workspace = str(tmp_path / "workspace")
|
||||
workspace_cron = config.workspace_path / "cron" / "jobs.json"
|
||||
workspace_cron.parent.mkdir(parents=True)
|
||||
workspace_cron.write_text('{"new": true}')
|
||||
|
||||
with patch("nanobot.config.paths.get_cron_dir", return_value=legacy_dir):
|
||||
_migrate_cron_store(config)
|
||||
|
||||
assert workspace_cron.read_text() == '{"new": true}'
|
||||
|
||||
|
||||
def test_gateway_uses_configured_port_when_cli_flag_is_missing(monkeypatch, tmp_path: Path) -> None:
|
||||
config_file = tmp_path / "instance" / "config.json"
|
||||
config_file.parent.mkdir(parents=True)
|
||||
config_file.write_text("{}")
|
||||
|
||||
config = Config()
|
||||
config.gateway.port = 18791
|
||||
|
||||
monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None)
|
||||
monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
|
||||
monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None)
|
||||
monkeypatch.setattr(
|
||||
"nanobot.cli.commands._make_provider",
|
||||
lambda _config: (_ for _ in ()).throw(_StopGatewayError("stop")),
|
||||
)
|
||||
|
||||
result = runner.invoke(app, ["gateway", "--config", str(config_file)])
|
||||
|
||||
assert isinstance(result.exception, _StopGatewayError)
|
||||
assert "port 18791" in result.stdout
|
||||
|
||||
|
||||
def test_gateway_cli_port_overrides_configured_port(monkeypatch, tmp_path: Path) -> None:
|
||||
config_file = tmp_path / "instance" / "config.json"
|
||||
config_file.parent.mkdir(parents=True)
|
||||
config_file.write_text("{}")
|
||||
|
||||
config = Config()
|
||||
config.gateway.port = 18791
|
||||
|
||||
monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None)
|
||||
monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
|
||||
monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None)
|
||||
monkeypatch.setattr(
|
||||
"nanobot.cli.commands._make_provider",
|
||||
lambda _config: (_ for _ in ()).throw(_StopGatewayError("stop")),
|
||||
)
|
||||
|
||||
result = runner.invoke(app, ["gateway", "--config", str(config_file), "--port", "18792"])
|
||||
|
||||
assert isinstance(result.exception, _StopGatewayError)
|
||||
assert "port 18792" in result.stdout
|
||||
|
||||
|
||||
def test_channels_login_requires_channel_name() -> None:
|
||||
result = runner.invoke(app, ["channels", "login"])
|
||||
|
||||
assert result.exit_code == 2
|
||||
@@ -0,0 +1,190 @@
|
||||
"""Tests for /restart slash command."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.bus.events import InboundMessage, OutboundMessage
|
||||
from nanobot.providers.base import LLMResponse
|
||||
|
||||
|
||||
def _make_loop():
|
||||
"""Create a minimal AgentLoop with mocked dependencies."""
|
||||
from nanobot.agent.loop import AgentLoop
|
||||
from nanobot.bus.queue import MessageBus
|
||||
|
||||
bus = MessageBus()
|
||||
provider = MagicMock()
|
||||
provider.get_default_model.return_value = "test-model"
|
||||
workspace = MagicMock()
|
||||
workspace.__truediv__ = MagicMock(return_value=MagicMock())
|
||||
|
||||
with patch("nanobot.agent.loop.ContextBuilder"), \
|
||||
patch("nanobot.agent.loop.SessionManager"), \
|
||||
patch("nanobot.agent.loop.SubagentManager"):
|
||||
loop = AgentLoop(bus=bus, provider=provider, workspace=workspace)
|
||||
return loop, bus
|
||||
|
||||
|
||||
class TestRestartCommand:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_restart_sends_message_and_calls_execv(self):
|
||||
from nanobot.command.builtin import cmd_restart
|
||||
from nanobot.command.router import CommandContext
|
||||
|
||||
loop, bus = _make_loop()
|
||||
msg = InboundMessage(channel="cli", sender_id="user", chat_id="direct", content="/restart")
|
||||
ctx = CommandContext(msg=msg, session=None, key=msg.session_key, raw="/restart", loop=loop)
|
||||
|
||||
with patch("nanobot.command.builtin.os.execv") as mock_execv:
|
||||
out = await cmd_restart(ctx)
|
||||
assert "Restarting" in out.content
|
||||
|
||||
await asyncio.sleep(1.5)
|
||||
mock_execv.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_restart_intercepted_in_run_loop(self):
|
||||
"""Verify /restart is handled at the run-loop level, not inside _dispatch."""
|
||||
loop, bus = _make_loop()
|
||||
msg = InboundMessage(channel="telegram", sender_id="u1", chat_id="c1", content="/restart")
|
||||
|
||||
with patch.object(loop, "_dispatch", new_callable=AsyncMock) as mock_dispatch, \
|
||||
patch("nanobot.command.builtin.os.execv"):
|
||||
await bus.publish_inbound(msg)
|
||||
|
||||
loop._running = True
|
||||
run_task = asyncio.create_task(loop.run())
|
||||
await asyncio.sleep(0.1)
|
||||
loop._running = False
|
||||
run_task.cancel()
|
||||
try:
|
||||
await run_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
mock_dispatch.assert_not_called()
|
||||
out = await asyncio.wait_for(bus.consume_outbound(), timeout=1.0)
|
||||
assert "Restarting" in out.content
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_status_intercepted_in_run_loop(self):
|
||||
"""Verify /status is handled at the run-loop level for immediate replies."""
|
||||
loop, bus = _make_loop()
|
||||
msg = InboundMessage(channel="telegram", sender_id="u1", chat_id="c1", content="/status")
|
||||
|
||||
with patch.object(loop, "_dispatch", new_callable=AsyncMock) as mock_dispatch:
|
||||
await bus.publish_inbound(msg)
|
||||
|
||||
loop._running = True
|
||||
run_task = asyncio.create_task(loop.run())
|
||||
await asyncio.sleep(0.1)
|
||||
loop._running = False
|
||||
run_task.cancel()
|
||||
try:
|
||||
await run_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
mock_dispatch.assert_not_called()
|
||||
out = await asyncio.wait_for(bus.consume_outbound(), timeout=1.0)
|
||||
assert "nanobot" in out.content.lower() or "Model" in out.content
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_propagates_external_cancellation(self):
|
||||
"""External task cancellation should not be swallowed by the inbound wait loop."""
|
||||
loop, _bus = _make_loop()
|
||||
|
||||
run_task = asyncio.create_task(loop.run())
|
||||
await asyncio.sleep(0.1)
|
||||
run_task.cancel()
|
||||
|
||||
with pytest.raises(asyncio.CancelledError):
|
||||
await asyncio.wait_for(run_task, timeout=1.0)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_help_includes_restart(self):
|
||||
loop, bus = _make_loop()
|
||||
msg = InboundMessage(channel="telegram", sender_id="u1", chat_id="c1", content="/help")
|
||||
|
||||
response = await loop._process_message(msg)
|
||||
|
||||
assert response is not None
|
||||
assert "/restart" in response.content
|
||||
assert "/status" in response.content
|
||||
assert response.metadata == {"render_as": "text"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_status_reports_runtime_info(self):
|
||||
loop, _bus = _make_loop()
|
||||
session = MagicMock()
|
||||
session.get_history.return_value = [{"role": "user"}] * 3
|
||||
loop.sessions.get_or_create.return_value = session
|
||||
loop._start_time = time.time() - 125
|
||||
loop._last_usage = {"prompt_tokens": 0, "completion_tokens": 0}
|
||||
loop.memory_consolidator.estimate_session_prompt_tokens = MagicMock(
|
||||
return_value=(20500, "tiktoken")
|
||||
)
|
||||
|
||||
msg = InboundMessage(channel="telegram", sender_id="u1", chat_id="c1", content="/status")
|
||||
|
||||
response = await loop._process_message(msg)
|
||||
|
||||
assert response is not None
|
||||
assert "Model: test-model" in response.content
|
||||
assert "Tokens: 0 in / 0 out" in response.content
|
||||
assert "Context: 20k/64k (31%)" in response.content
|
||||
assert "Session: 3 messages" in response.content
|
||||
assert "Uptime: 2m 5s" in response.content
|
||||
assert response.metadata == {"render_as": "text"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_agent_loop_resets_usage_when_provider_omits_it(self):
|
||||
loop, _bus = _make_loop()
|
||||
loop.provider.chat_with_retry = AsyncMock(side_effect=[
|
||||
LLMResponse(content="first", usage={"prompt_tokens": 9, "completion_tokens": 4}),
|
||||
LLMResponse(content="second", usage={}),
|
||||
])
|
||||
|
||||
await loop._run_agent_loop([])
|
||||
assert loop._last_usage == {"prompt_tokens": 9, "completion_tokens": 4}
|
||||
|
||||
await loop._run_agent_loop([])
|
||||
assert loop._last_usage == {"prompt_tokens": 0, "completion_tokens": 0}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_status_falls_back_to_last_usage_when_context_estimate_missing(self):
|
||||
loop, _bus = _make_loop()
|
||||
session = MagicMock()
|
||||
session.get_history.return_value = [{"role": "user"}]
|
||||
loop.sessions.get_or_create.return_value = session
|
||||
loop._last_usage = {"prompt_tokens": 1200, "completion_tokens": 34}
|
||||
loop.memory_consolidator.estimate_session_prompt_tokens = MagicMock(
|
||||
return_value=(0, "none")
|
||||
)
|
||||
|
||||
response = await loop._process_message(
|
||||
InboundMessage(channel="telegram", sender_id="u1", chat_id="c1", content="/status")
|
||||
)
|
||||
|
||||
assert response is not None
|
||||
assert "Tokens: 1200 in / 34 out" in response.content
|
||||
assert "Context: 1k/64k (1%)" in response.content
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_direct_preserves_render_metadata(self):
|
||||
loop, _bus = _make_loop()
|
||||
session = MagicMock()
|
||||
session.get_history.return_value = []
|
||||
loop.sessions.get_or_create.return_value = session
|
||||
loop.subagents.get_running_count.return_value = 0
|
||||
|
||||
response = await loop.process_direct("/status", session_key="cli:test")
|
||||
|
||||
assert response is not None
|
||||
assert response.metadata == {"render_as": "text"}
|
||||
@@ -0,0 +1,128 @@
|
||||
import json
|
||||
|
||||
from nanobot.config.loader import load_config, save_config
|
||||
|
||||
|
||||
def test_load_config_keeps_max_tokens_and_ignores_legacy_memory_window(tmp_path) -> None:
|
||||
config_path = tmp_path / "config.json"
|
||||
config_path.write_text(
|
||||
json.dumps(
|
||||
{
|
||||
"agents": {
|
||||
"defaults": {
|
||||
"maxTokens": 1234,
|
||||
"memoryWindow": 42,
|
||||
}
|
||||
}
|
||||
}
|
||||
),
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
config = load_config(config_path)
|
||||
|
||||
assert config.agents.defaults.max_tokens == 1234
|
||||
assert config.agents.defaults.context_window_tokens == 65_536
|
||||
assert not hasattr(config.agents.defaults, "memory_window")
|
||||
|
||||
|
||||
def test_save_config_writes_context_window_tokens_but_not_memory_window(tmp_path) -> None:
|
||||
config_path = tmp_path / "config.json"
|
||||
config_path.write_text(
|
||||
json.dumps(
|
||||
{
|
||||
"agents": {
|
||||
"defaults": {
|
||||
"maxTokens": 2222,
|
||||
"memoryWindow": 30,
|
||||
}
|
||||
}
|
||||
}
|
||||
),
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
config = load_config(config_path)
|
||||
save_config(config, config_path)
|
||||
saved = json.loads(config_path.read_text(encoding="utf-8"))
|
||||
defaults = saved["agents"]["defaults"]
|
||||
|
||||
assert defaults["maxTokens"] == 2222
|
||||
assert defaults["contextWindowTokens"] == 65_536
|
||||
assert "memoryWindow" not in defaults
|
||||
|
||||
|
||||
def test_onboard_does_not_crash_with_legacy_memory_window(tmp_path, monkeypatch) -> None:
|
||||
config_path = tmp_path / "config.json"
|
||||
workspace = tmp_path / "workspace"
|
||||
config_path.write_text(
|
||||
json.dumps(
|
||||
{
|
||||
"agents": {
|
||||
"defaults": {
|
||||
"maxTokens": 3333,
|
||||
"memoryWindow": 50,
|
||||
}
|
||||
}
|
||||
}
|
||||
),
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
monkeypatch.setattr("nanobot.config.loader.get_config_path", lambda: config_path)
|
||||
monkeypatch.setattr("nanobot.cli.commands.get_workspace_path", lambda _workspace=None: workspace)
|
||||
|
||||
from typer.testing import CliRunner
|
||||
from nanobot.cli.commands import app
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(app, ["onboard"], input="n\n")
|
||||
|
||||
assert result.exit_code == 0
|
||||
|
||||
|
||||
def test_onboard_refresh_backfills_missing_channel_fields(tmp_path, monkeypatch) -> None:
|
||||
from types import SimpleNamespace
|
||||
|
||||
config_path = tmp_path / "config.json"
|
||||
workspace = tmp_path / "workspace"
|
||||
config_path.write_text(
|
||||
json.dumps(
|
||||
{
|
||||
"channels": {
|
||||
"qq": {
|
||||
"enabled": False,
|
||||
"appId": "",
|
||||
"secret": "",
|
||||
"allowFrom": [],
|
||||
}
|
||||
}
|
||||
}
|
||||
),
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
monkeypatch.setattr("nanobot.config.loader.get_config_path", lambda: config_path)
|
||||
monkeypatch.setattr("nanobot.cli.commands.get_workspace_path", lambda _workspace=None: workspace)
|
||||
monkeypatch.setattr(
|
||||
"nanobot.channels.registry.discover_all",
|
||||
lambda: {
|
||||
"qq": SimpleNamespace(
|
||||
default_config=lambda: {
|
||||
"enabled": False,
|
||||
"appId": "",
|
||||
"secret": "",
|
||||
"allowFrom": [],
|
||||
"msgFormat": "plain",
|
||||
}
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
from typer.testing import CliRunner
|
||||
from nanobot.cli.commands import app
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(app, ["onboard"], input="n\n")
|
||||
|
||||
assert result.exit_code == 0
|
||||
saved = json.loads(config_path.read_text(encoding="utf-8"))
|
||||
assert saved["channels"]["qq"]["msgFormat"] == "plain"
|
||||
@@ -0,0 +1,49 @@
|
||||
from pathlib import Path
|
||||
|
||||
from nanobot.config.paths import (
|
||||
get_bridge_install_dir,
|
||||
get_cli_history_path,
|
||||
get_cron_dir,
|
||||
get_data_dir,
|
||||
get_legacy_sessions_dir,
|
||||
get_logs_dir,
|
||||
get_media_dir,
|
||||
get_runtime_subdir,
|
||||
get_workspace_path,
|
||||
is_default_workspace,
|
||||
)
|
||||
|
||||
|
||||
def test_runtime_dirs_follow_config_path(monkeypatch, tmp_path: Path) -> None:
|
||||
config_file = tmp_path / "instance-a" / "config.json"
|
||||
monkeypatch.setattr("nanobot.config.paths.get_config_path", lambda: config_file)
|
||||
|
||||
assert get_data_dir() == config_file.parent
|
||||
assert get_runtime_subdir("cron") == config_file.parent / "cron"
|
||||
assert get_cron_dir() == config_file.parent / "cron"
|
||||
assert get_logs_dir() == config_file.parent / "logs"
|
||||
|
||||
|
||||
def test_media_dir_supports_channel_namespace(monkeypatch, tmp_path: Path) -> None:
|
||||
config_file = tmp_path / "instance-b" / "config.json"
|
||||
monkeypatch.setattr("nanobot.config.paths.get_config_path", lambda: config_file)
|
||||
|
||||
assert get_media_dir() == config_file.parent / "media"
|
||||
assert get_media_dir("telegram") == config_file.parent / "media" / "telegram"
|
||||
|
||||
|
||||
def test_shared_and_legacy_paths_remain_global() -> None:
|
||||
assert get_cli_history_path() == Path.home() / ".nanobot" / "history" / "cli_history"
|
||||
assert get_bridge_install_dir() == Path.home() / ".nanobot" / "bridge"
|
||||
assert get_legacy_sessions_dir() == Path.home() / ".nanobot" / "sessions"
|
||||
|
||||
|
||||
def test_workspace_path_is_explicitly_resolved() -> None:
|
||||
assert get_workspace_path() == Path.home() / ".nanobot" / "workspace"
|
||||
assert get_workspace_path("~/custom-workspace") == Path.home() / "custom-workspace"
|
||||
|
||||
|
||||
def test_is_default_workspace_distinguishes_default_and_custom_paths() -> None:
|
||||
assert is_default_workspace(None) is True
|
||||
assert is_default_workspace(Path.home() / ".nanobot" / "workspace") is True
|
||||
assert is_default_workspace("~/custom-workspace") is False
|
||||
@@ -0,0 +1,143 @@
|
||||
import asyncio
|
||||
import json
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.cron.service import CronService
|
||||
from nanobot.cron.types import CronSchedule
|
||||
|
||||
|
||||
def test_add_job_rejects_unknown_timezone(tmp_path) -> None:
|
||||
service = CronService(tmp_path / "cron" / "jobs.json")
|
||||
|
||||
with pytest.raises(ValueError, match="unknown timezone 'America/Vancovuer'"):
|
||||
service.add_job(
|
||||
name="tz typo",
|
||||
schedule=CronSchedule(kind="cron", expr="0 9 * * *", tz="America/Vancovuer"),
|
||||
message="hello",
|
||||
)
|
||||
|
||||
assert service.list_jobs(include_disabled=True) == []
|
||||
|
||||
|
||||
def test_add_job_accepts_valid_timezone(tmp_path) -> None:
|
||||
service = CronService(tmp_path / "cron" / "jobs.json")
|
||||
|
||||
job = service.add_job(
|
||||
name="tz ok",
|
||||
schedule=CronSchedule(kind="cron", expr="0 9 * * *", tz="America/Vancouver"),
|
||||
message="hello",
|
||||
)
|
||||
|
||||
assert job.schedule.tz == "America/Vancouver"
|
||||
assert job.state.next_run_at_ms is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_job_records_run_history(tmp_path) -> None:
|
||||
store_path = tmp_path / "cron" / "jobs.json"
|
||||
service = CronService(store_path, on_job=lambda _: asyncio.sleep(0))
|
||||
job = service.add_job(
|
||||
name="hist",
|
||||
schedule=CronSchedule(kind="every", every_ms=60_000),
|
||||
message="hello",
|
||||
)
|
||||
await service.run_job(job.id)
|
||||
|
||||
loaded = service.get_job(job.id)
|
||||
assert loaded is not None
|
||||
assert len(loaded.state.run_history) == 1
|
||||
rec = loaded.state.run_history[0]
|
||||
assert rec.status == "ok"
|
||||
assert rec.duration_ms >= 0
|
||||
assert rec.error is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_history_records_errors(tmp_path) -> None:
|
||||
store_path = tmp_path / "cron" / "jobs.json"
|
||||
|
||||
async def fail(_):
|
||||
raise RuntimeError("boom")
|
||||
|
||||
service = CronService(store_path, on_job=fail)
|
||||
job = service.add_job(
|
||||
name="fail",
|
||||
schedule=CronSchedule(kind="every", every_ms=60_000),
|
||||
message="hello",
|
||||
)
|
||||
await service.run_job(job.id)
|
||||
|
||||
loaded = service.get_job(job.id)
|
||||
assert len(loaded.state.run_history) == 1
|
||||
assert loaded.state.run_history[0].status == "error"
|
||||
assert loaded.state.run_history[0].error == "boom"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_history_trimmed_to_max(tmp_path) -> None:
|
||||
store_path = tmp_path / "cron" / "jobs.json"
|
||||
service = CronService(store_path, on_job=lambda _: asyncio.sleep(0))
|
||||
job = service.add_job(
|
||||
name="trim",
|
||||
schedule=CronSchedule(kind="every", every_ms=60_000),
|
||||
message="hello",
|
||||
)
|
||||
for _ in range(25):
|
||||
await service.run_job(job.id)
|
||||
|
||||
loaded = service.get_job(job.id)
|
||||
assert len(loaded.state.run_history) == CronService._MAX_RUN_HISTORY
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_history_persisted_to_disk(tmp_path) -> None:
|
||||
store_path = tmp_path / "cron" / "jobs.json"
|
||||
service = CronService(store_path, on_job=lambda _: asyncio.sleep(0))
|
||||
job = service.add_job(
|
||||
name="persist",
|
||||
schedule=CronSchedule(kind="every", every_ms=60_000),
|
||||
message="hello",
|
||||
)
|
||||
await service.run_job(job.id)
|
||||
|
||||
raw = json.loads(store_path.read_text())
|
||||
history = raw["jobs"][0]["state"]["runHistory"]
|
||||
assert len(history) == 1
|
||||
assert history[0]["status"] == "ok"
|
||||
assert "runAtMs" in history[0]
|
||||
assert "durationMs" in history[0]
|
||||
|
||||
fresh = CronService(store_path)
|
||||
loaded = fresh.get_job(job.id)
|
||||
assert len(loaded.state.run_history) == 1
|
||||
assert loaded.state.run_history[0].status == "ok"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_running_service_honors_external_disable(tmp_path) -> None:
|
||||
store_path = tmp_path / "cron" / "jobs.json"
|
||||
called: list[str] = []
|
||||
|
||||
async def on_job(job) -> None:
|
||||
called.append(job.id)
|
||||
|
||||
service = CronService(store_path, on_job=on_job)
|
||||
job = service.add_job(
|
||||
name="external-disable",
|
||||
schedule=CronSchedule(kind="every", every_ms=200),
|
||||
message="hello",
|
||||
)
|
||||
await service.start()
|
||||
try:
|
||||
# Wait slightly to ensure file mtime is definitively different
|
||||
await asyncio.sleep(0.05)
|
||||
external = CronService(store_path)
|
||||
updated = external.enable_job(job.id, enabled=False)
|
||||
assert updated is not None
|
||||
assert updated.enabled is False
|
||||
|
||||
await asyncio.sleep(0.35)
|
||||
assert called == []
|
||||
finally:
|
||||
service.stop()
|
||||
@@ -0,0 +1,299 @@
|
||||
"""Tests for CronTool._list_jobs() output formatting."""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from nanobot.agent.tools.cron import CronTool
|
||||
from nanobot.cron.service import CronService
|
||||
from nanobot.cron.types import CronJobState, CronSchedule
|
||||
|
||||
|
||||
def _make_tool(tmp_path) -> CronTool:
|
||||
service = CronService(tmp_path / "cron" / "jobs.json")
|
||||
return CronTool(service)
|
||||
|
||||
|
||||
def _make_tool_with_tz(tmp_path, tz: str) -> CronTool:
|
||||
service = CronService(tmp_path / "cron" / "jobs.json")
|
||||
return CronTool(service, default_timezone=tz)
|
||||
|
||||
|
||||
# -- _format_timing tests --
|
||||
|
||||
|
||||
def test_format_timing_cron_with_tz(tmp_path) -> None:
|
||||
tool = _make_tool(tmp_path)
|
||||
s = CronSchedule(kind="cron", expr="0 9 * * 1-5", tz="America/Denver")
|
||||
assert tool._format_timing(s) == "cron: 0 9 * * 1-5 (America/Denver)"
|
||||
|
||||
|
||||
def test_format_timing_cron_without_tz(tmp_path) -> None:
|
||||
tool = _make_tool(tmp_path)
|
||||
s = CronSchedule(kind="cron", expr="*/5 * * * *")
|
||||
assert tool._format_timing(s) == "cron: */5 * * * *"
|
||||
|
||||
|
||||
def test_format_timing_every_hours(tmp_path) -> None:
|
||||
tool = _make_tool(tmp_path)
|
||||
s = CronSchedule(kind="every", every_ms=7_200_000)
|
||||
assert tool._format_timing(s) == "every 2h"
|
||||
|
||||
|
||||
def test_format_timing_every_minutes(tmp_path) -> None:
|
||||
tool = _make_tool(tmp_path)
|
||||
s = CronSchedule(kind="every", every_ms=1_800_000)
|
||||
assert tool._format_timing(s) == "every 30m"
|
||||
|
||||
|
||||
def test_format_timing_every_seconds(tmp_path) -> None:
|
||||
tool = _make_tool(tmp_path)
|
||||
s = CronSchedule(kind="every", every_ms=30_000)
|
||||
assert tool._format_timing(s) == "every 30s"
|
||||
|
||||
|
||||
def test_format_timing_every_non_minute_seconds(tmp_path) -> None:
|
||||
tool = _make_tool(tmp_path)
|
||||
s = CronSchedule(kind="every", every_ms=90_000)
|
||||
assert tool._format_timing(s) == "every 90s"
|
||||
|
||||
|
||||
def test_format_timing_every_milliseconds(tmp_path) -> None:
|
||||
tool = _make_tool(tmp_path)
|
||||
s = CronSchedule(kind="every", every_ms=200)
|
||||
assert tool._format_timing(s) == "every 200ms"
|
||||
|
||||
|
||||
def test_format_timing_at(tmp_path) -> None:
|
||||
tool = _make_tool_with_tz(tmp_path, "Asia/Shanghai")
|
||||
s = CronSchedule(kind="at", at_ms=1773684000000)
|
||||
result = tool._format_timing(s)
|
||||
assert "Asia/Shanghai" in result
|
||||
assert result.startswith("at 2026-")
|
||||
|
||||
|
||||
def test_format_timing_fallback(tmp_path) -> None:
|
||||
tool = _make_tool(tmp_path)
|
||||
s = CronSchedule(kind="every") # no every_ms
|
||||
assert tool._format_timing(s) == "every"
|
||||
|
||||
|
||||
# -- _format_state tests --
|
||||
|
||||
|
||||
def test_format_state_empty(tmp_path) -> None:
|
||||
tool = _make_tool(tmp_path)
|
||||
state = CronJobState()
|
||||
assert tool._format_state(state, CronSchedule(kind="every")) == []
|
||||
|
||||
|
||||
def test_format_state_last_run_ok(tmp_path) -> None:
|
||||
tool = _make_tool(tmp_path)
|
||||
state = CronJobState(last_run_at_ms=1773673200000, last_status="ok")
|
||||
lines = tool._format_state(state, CronSchedule(kind="cron", expr="0 9 * * *", tz="UTC"))
|
||||
assert len(lines) == 1
|
||||
assert "Last run:" in lines[0]
|
||||
assert "ok" in lines[0]
|
||||
|
||||
|
||||
def test_format_state_last_run_with_error(tmp_path) -> None:
|
||||
tool = _make_tool(tmp_path)
|
||||
state = CronJobState(last_run_at_ms=1773673200000, last_status="error", last_error="timeout")
|
||||
lines = tool._format_state(state, CronSchedule(kind="cron", expr="0 9 * * *", tz="UTC"))
|
||||
assert len(lines) == 1
|
||||
assert "error" in lines[0]
|
||||
assert "timeout" in lines[0]
|
||||
|
||||
|
||||
def test_format_state_next_run_only(tmp_path) -> None:
|
||||
tool = _make_tool(tmp_path)
|
||||
state = CronJobState(next_run_at_ms=1773684000000)
|
||||
lines = tool._format_state(state, CronSchedule(kind="cron", expr="0 9 * * *", tz="UTC"))
|
||||
assert len(lines) == 1
|
||||
assert "Next run:" in lines[0]
|
||||
|
||||
|
||||
def test_format_state_both(tmp_path) -> None:
|
||||
tool = _make_tool(tmp_path)
|
||||
state = CronJobState(
|
||||
last_run_at_ms=1773673200000, last_status="ok", next_run_at_ms=1773684000000
|
||||
)
|
||||
lines = tool._format_state(state, CronSchedule(kind="cron", expr="0 9 * * *", tz="UTC"))
|
||||
assert len(lines) == 2
|
||||
assert "Last run:" in lines[0]
|
||||
assert "Next run:" in lines[1]
|
||||
|
||||
|
||||
def test_format_state_unknown_status(tmp_path) -> None:
|
||||
tool = _make_tool(tmp_path)
|
||||
state = CronJobState(last_run_at_ms=1773673200000, last_status=None)
|
||||
lines = tool._format_state(state, CronSchedule(kind="cron", expr="0 9 * * *", tz="UTC"))
|
||||
assert "unknown" in lines[0]
|
||||
|
||||
|
||||
# -- _list_jobs integration tests --
|
||||
|
||||
|
||||
def test_list_empty(tmp_path) -> None:
|
||||
tool = _make_tool(tmp_path)
|
||||
assert tool._list_jobs() == "No scheduled jobs."
|
||||
|
||||
|
||||
def test_list_cron_job_shows_expression_and_timezone(tmp_path) -> None:
|
||||
tool = _make_tool(tmp_path)
|
||||
tool._cron.add_job(
|
||||
name="Morning scan",
|
||||
schedule=CronSchedule(kind="cron", expr="0 9 * * 1-5", tz="America/Denver"),
|
||||
message="scan",
|
||||
)
|
||||
result = tool._list_jobs()
|
||||
assert "cron: 0 9 * * 1-5 (America/Denver)" in result
|
||||
|
||||
|
||||
def test_list_every_job_shows_human_interval(tmp_path) -> None:
|
||||
tool = _make_tool(tmp_path)
|
||||
tool._cron.add_job(
|
||||
name="Frequent check",
|
||||
schedule=CronSchedule(kind="every", every_ms=1_800_000),
|
||||
message="check",
|
||||
)
|
||||
result = tool._list_jobs()
|
||||
assert "every 30m" in result
|
||||
|
||||
|
||||
def test_list_every_job_hours(tmp_path) -> None:
|
||||
tool = _make_tool(tmp_path)
|
||||
tool._cron.add_job(
|
||||
name="Hourly check",
|
||||
schedule=CronSchedule(kind="every", every_ms=7_200_000),
|
||||
message="check",
|
||||
)
|
||||
result = tool._list_jobs()
|
||||
assert "every 2h" in result
|
||||
|
||||
|
||||
def test_list_every_job_seconds(tmp_path) -> None:
|
||||
tool = _make_tool(tmp_path)
|
||||
tool._cron.add_job(
|
||||
name="Fast check",
|
||||
schedule=CronSchedule(kind="every", every_ms=30_000),
|
||||
message="check",
|
||||
)
|
||||
result = tool._list_jobs()
|
||||
assert "every 30s" in result
|
||||
|
||||
|
||||
def test_list_every_job_non_minute_seconds(tmp_path) -> None:
|
||||
tool = _make_tool(tmp_path)
|
||||
tool._cron.add_job(
|
||||
name="Ninety-second check",
|
||||
schedule=CronSchedule(kind="every", every_ms=90_000),
|
||||
message="check",
|
||||
)
|
||||
result = tool._list_jobs()
|
||||
assert "every 90s" in result
|
||||
|
||||
|
||||
def test_list_every_job_milliseconds(tmp_path) -> None:
|
||||
tool = _make_tool(tmp_path)
|
||||
tool._cron.add_job(
|
||||
name="Sub-second check",
|
||||
schedule=CronSchedule(kind="every", every_ms=200),
|
||||
message="check",
|
||||
)
|
||||
result = tool._list_jobs()
|
||||
assert "every 200ms" in result
|
||||
|
||||
|
||||
def test_list_at_job_shows_iso_timestamp(tmp_path) -> None:
|
||||
tool = _make_tool_with_tz(tmp_path, "Asia/Shanghai")
|
||||
tool._cron.add_job(
|
||||
name="One-shot",
|
||||
schedule=CronSchedule(kind="at", at_ms=1773684000000),
|
||||
message="fire",
|
||||
)
|
||||
result = tool._list_jobs()
|
||||
assert "at 2026-" in result
|
||||
assert "Asia/Shanghai" in result
|
||||
|
||||
|
||||
def test_list_shows_last_run_state(tmp_path) -> None:
|
||||
tool = _make_tool(tmp_path)
|
||||
job = tool._cron.add_job(
|
||||
name="Stateful job",
|
||||
schedule=CronSchedule(kind="cron", expr="0 9 * * *", tz="UTC"),
|
||||
message="test",
|
||||
)
|
||||
# Simulate a completed run by updating state in the store
|
||||
job.state.last_run_at_ms = 1773673200000
|
||||
job.state.last_status = "ok"
|
||||
tool._cron._save_store()
|
||||
|
||||
result = tool._list_jobs()
|
||||
assert "Last run:" in result
|
||||
assert "ok" in result
|
||||
assert "(UTC)" in result
|
||||
|
||||
|
||||
def test_list_shows_error_message(tmp_path) -> None:
|
||||
tool = _make_tool(tmp_path)
|
||||
job = tool._cron.add_job(
|
||||
name="Failed job",
|
||||
schedule=CronSchedule(kind="cron", expr="0 9 * * *", tz="UTC"),
|
||||
message="test",
|
||||
)
|
||||
job.state.last_run_at_ms = 1773673200000
|
||||
job.state.last_status = "error"
|
||||
job.state.last_error = "timeout"
|
||||
tool._cron._save_store()
|
||||
|
||||
result = tool._list_jobs()
|
||||
assert "error" in result
|
||||
assert "timeout" in result
|
||||
|
||||
|
||||
def test_list_shows_next_run(tmp_path) -> None:
|
||||
tool = _make_tool(tmp_path)
|
||||
tool._cron.add_job(
|
||||
name="Upcoming job",
|
||||
schedule=CronSchedule(kind="cron", expr="0 9 * * *", tz="UTC"),
|
||||
message="test",
|
||||
)
|
||||
result = tool._list_jobs()
|
||||
assert "Next run:" in result
|
||||
assert "(UTC)" in result
|
||||
|
||||
|
||||
def test_add_cron_job_defaults_to_tool_timezone(tmp_path) -> None:
|
||||
tool = _make_tool_with_tz(tmp_path, "Asia/Shanghai")
|
||||
tool.set_context("telegram", "chat-1")
|
||||
|
||||
result = tool._add_job("Morning standup", None, "0 8 * * *", None, None)
|
||||
|
||||
assert result.startswith("Created job")
|
||||
job = tool._cron.list_jobs()[0]
|
||||
assert job.schedule.tz == "Asia/Shanghai"
|
||||
|
||||
|
||||
def test_add_at_job_uses_default_timezone_for_naive_datetime(tmp_path) -> None:
|
||||
tool = _make_tool_with_tz(tmp_path, "Asia/Shanghai")
|
||||
tool.set_context("telegram", "chat-1")
|
||||
|
||||
result = tool._add_job("Morning reminder", None, None, None, "2026-03-25T08:00:00")
|
||||
|
||||
assert result.startswith("Created job")
|
||||
job = tool._cron.list_jobs()[0]
|
||||
expected = int(datetime(2026, 3, 25, 0, 0, 0, tzinfo=timezone.utc).timestamp() * 1000)
|
||||
assert job.schedule.at_ms == expected
|
||||
|
||||
|
||||
def test_list_excludes_disabled_jobs(tmp_path) -> None:
|
||||
tool = _make_tool(tmp_path)
|
||||
job = tool._cron.add_job(
|
||||
name="Paused job",
|
||||
schedule=CronSchedule(kind="cron", expr="0 9 * * *", tz="UTC"),
|
||||
message="test",
|
||||
)
|
||||
tool._cron.enable_job(job.id, enabled=False)
|
||||
|
||||
result = tool._list_jobs()
|
||||
assert "Paused job" not in result
|
||||
assert result == "No scheduled jobs."
|
||||
@@ -0,0 +1,399 @@
|
||||
"""Test Azure OpenAI provider implementation (updated for model-based deployment names)."""
|
||||
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.providers.azure_openai_provider import AzureOpenAIProvider
|
||||
from nanobot.providers.base import LLMResponse
|
||||
|
||||
|
||||
def test_azure_openai_provider_init():
|
||||
"""Test AzureOpenAIProvider initialization without deployment_name."""
|
||||
provider = AzureOpenAIProvider(
|
||||
api_key="test-key",
|
||||
api_base="https://test-resource.openai.azure.com",
|
||||
default_model="gpt-4o-deployment",
|
||||
)
|
||||
|
||||
assert provider.api_key == "test-key"
|
||||
assert provider.api_base == "https://test-resource.openai.azure.com/"
|
||||
assert provider.default_model == "gpt-4o-deployment"
|
||||
assert provider.api_version == "2024-10-21"
|
||||
|
||||
|
||||
def test_azure_openai_provider_init_validation():
|
||||
"""Test AzureOpenAIProvider initialization validation."""
|
||||
# Missing api_key
|
||||
with pytest.raises(ValueError, match="Azure OpenAI api_key is required"):
|
||||
AzureOpenAIProvider(api_key="", api_base="https://test.com")
|
||||
|
||||
# Missing api_base
|
||||
with pytest.raises(ValueError, match="Azure OpenAI api_base is required"):
|
||||
AzureOpenAIProvider(api_key="test", api_base="")
|
||||
|
||||
|
||||
def test_build_chat_url():
|
||||
"""Test Azure OpenAI URL building with different deployment names."""
|
||||
provider = AzureOpenAIProvider(
|
||||
api_key="test-key",
|
||||
api_base="https://test-resource.openai.azure.com",
|
||||
default_model="gpt-4o",
|
||||
)
|
||||
|
||||
# Test various deployment names
|
||||
test_cases = [
|
||||
("gpt-4o-deployment", "https://test-resource.openai.azure.com/openai/deployments/gpt-4o-deployment/chat/completions?api-version=2024-10-21"),
|
||||
("gpt-35-turbo", "https://test-resource.openai.azure.com/openai/deployments/gpt-35-turbo/chat/completions?api-version=2024-10-21"),
|
||||
("custom-model", "https://test-resource.openai.azure.com/openai/deployments/custom-model/chat/completions?api-version=2024-10-21"),
|
||||
]
|
||||
|
||||
for deployment_name, expected_url in test_cases:
|
||||
url = provider._build_chat_url(deployment_name)
|
||||
assert url == expected_url
|
||||
|
||||
|
||||
def test_build_chat_url_api_base_without_slash():
|
||||
"""Test URL building when api_base doesn't end with slash."""
|
||||
provider = AzureOpenAIProvider(
|
||||
api_key="test-key",
|
||||
api_base="https://test-resource.openai.azure.com", # No trailing slash
|
||||
default_model="gpt-4o",
|
||||
)
|
||||
|
||||
url = provider._build_chat_url("test-deployment")
|
||||
expected = "https://test-resource.openai.azure.com/openai/deployments/test-deployment/chat/completions?api-version=2024-10-21"
|
||||
assert url == expected
|
||||
|
||||
|
||||
def test_build_headers():
|
||||
"""Test Azure OpenAI header building with api-key authentication."""
|
||||
provider = AzureOpenAIProvider(
|
||||
api_key="test-api-key-123",
|
||||
api_base="https://test-resource.openai.azure.com",
|
||||
default_model="gpt-4o",
|
||||
)
|
||||
|
||||
headers = provider._build_headers()
|
||||
assert headers["Content-Type"] == "application/json"
|
||||
assert headers["api-key"] == "test-api-key-123" # Azure OpenAI specific header
|
||||
assert "x-session-affinity" in headers
|
||||
|
||||
|
||||
def test_prepare_request_payload():
|
||||
"""Test request payload preparation with Azure OpenAI 2024-10-21 compliance."""
|
||||
provider = AzureOpenAIProvider(
|
||||
api_key="test-key",
|
||||
api_base="https://test-resource.openai.azure.com",
|
||||
default_model="gpt-4o",
|
||||
)
|
||||
|
||||
messages = [{"role": "user", "content": "Hello"}]
|
||||
payload = provider._prepare_request_payload("gpt-4o", messages, max_tokens=1500, temperature=0.8)
|
||||
|
||||
assert payload["messages"] == messages
|
||||
assert payload["max_completion_tokens"] == 1500 # Azure API 2024-10-21 uses max_completion_tokens
|
||||
assert payload["temperature"] == 0.8
|
||||
assert "tools" not in payload
|
||||
|
||||
# Test with tools
|
||||
tools = [{"type": "function", "function": {"name": "get_weather", "parameters": {}}}]
|
||||
payload_with_tools = provider._prepare_request_payload("gpt-4o", messages, tools=tools)
|
||||
assert payload_with_tools["tools"] == tools
|
||||
assert payload_with_tools["tool_choice"] == "auto"
|
||||
|
||||
# Test with reasoning_effort
|
||||
payload_with_reasoning = provider._prepare_request_payload(
|
||||
"gpt-5-chat", messages, reasoning_effort="medium"
|
||||
)
|
||||
assert payload_with_reasoning["reasoning_effort"] == "medium"
|
||||
assert "temperature" not in payload_with_reasoning
|
||||
|
||||
|
||||
def test_prepare_request_payload_sanitizes_messages():
|
||||
"""Test Azure payload strips non-standard message keys before sending."""
|
||||
provider = AzureOpenAIProvider(
|
||||
api_key="test-key",
|
||||
api_base="https://test-resource.openai.azure.com",
|
||||
default_model="gpt-4o",
|
||||
)
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "assistant",
|
||||
"tool_calls": [{"id": "call_123", "type": "function", "function": {"name": "x"}}],
|
||||
"reasoning_content": "hidden chain-of-thought",
|
||||
},
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": "call_123",
|
||||
"name": "x",
|
||||
"content": "ok",
|
||||
"extra_field": "should be removed",
|
||||
},
|
||||
]
|
||||
|
||||
payload = provider._prepare_request_payload("gpt-4o", messages)
|
||||
|
||||
assert payload["messages"] == [
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"tool_calls": [{"id": "call_123", "type": "function", "function": {"name": "x"}}],
|
||||
},
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": "call_123",
|
||||
"name": "x",
|
||||
"content": "ok",
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_success():
|
||||
"""Test successful chat request using model as deployment name."""
|
||||
provider = AzureOpenAIProvider(
|
||||
api_key="test-key",
|
||||
api_base="https://test-resource.openai.azure.com",
|
||||
default_model="gpt-4o-deployment",
|
||||
)
|
||||
|
||||
# Mock response data
|
||||
mock_response_data = {
|
||||
"choices": [{
|
||||
"message": {
|
||||
"content": "Hello! How can I help you today?",
|
||||
"role": "assistant"
|
||||
},
|
||||
"finish_reason": "stop"
|
||||
}],
|
||||
"usage": {
|
||||
"prompt_tokens": 12,
|
||||
"completion_tokens": 18,
|
||||
"total_tokens": 30
|
||||
}
|
||||
}
|
||||
|
||||
with patch("httpx.AsyncClient") as mock_client:
|
||||
mock_response = AsyncMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json = Mock(return_value=mock_response_data)
|
||||
|
||||
mock_context = AsyncMock()
|
||||
mock_context.post = AsyncMock(return_value=mock_response)
|
||||
mock_client.return_value.__aenter__.return_value = mock_context
|
||||
|
||||
# Test with specific model (deployment name)
|
||||
messages = [{"role": "user", "content": "Hello"}]
|
||||
result = await provider.chat(messages, model="custom-deployment")
|
||||
|
||||
assert isinstance(result, LLMResponse)
|
||||
assert result.content == "Hello! How can I help you today?"
|
||||
assert result.finish_reason == "stop"
|
||||
assert result.usage["prompt_tokens"] == 12
|
||||
assert result.usage["completion_tokens"] == 18
|
||||
assert result.usage["total_tokens"] == 30
|
||||
|
||||
# Verify URL was built with the provided model as deployment name
|
||||
call_args = mock_context.post.call_args
|
||||
expected_url = "https://test-resource.openai.azure.com/openai/deployments/custom-deployment/chat/completions?api-version=2024-10-21"
|
||||
assert call_args[0][0] == expected_url
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_uses_default_model_when_no_model_provided():
|
||||
"""Test that chat uses default_model when no model is specified."""
|
||||
provider = AzureOpenAIProvider(
|
||||
api_key="test-key",
|
||||
api_base="https://test-resource.openai.azure.com",
|
||||
default_model="default-deployment",
|
||||
)
|
||||
|
||||
mock_response_data = {
|
||||
"choices": [{
|
||||
"message": {"content": "Response", "role": "assistant"},
|
||||
"finish_reason": "stop"
|
||||
}],
|
||||
"usage": {"prompt_tokens": 5, "completion_tokens": 5, "total_tokens": 10}
|
||||
}
|
||||
|
||||
with patch("httpx.AsyncClient") as mock_client:
|
||||
mock_response = AsyncMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json = Mock(return_value=mock_response_data)
|
||||
|
||||
mock_context = AsyncMock()
|
||||
mock_context.post = AsyncMock(return_value=mock_response)
|
||||
mock_client.return_value.__aenter__.return_value = mock_context
|
||||
|
||||
messages = [{"role": "user", "content": "Test"}]
|
||||
await provider.chat(messages) # No model specified
|
||||
|
||||
# Verify URL was built with default model as deployment name
|
||||
call_args = mock_context.post.call_args
|
||||
expected_url = "https://test-resource.openai.azure.com/openai/deployments/default-deployment/chat/completions?api-version=2024-10-21"
|
||||
assert call_args[0][0] == expected_url
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_with_tool_calls():
|
||||
"""Test chat request with tool calls in response."""
|
||||
provider = AzureOpenAIProvider(
|
||||
api_key="test-key",
|
||||
api_base="https://test-resource.openai.azure.com",
|
||||
default_model="gpt-4o",
|
||||
)
|
||||
|
||||
# Mock response with tool calls
|
||||
mock_response_data = {
|
||||
"choices": [{
|
||||
"message": {
|
||||
"content": None,
|
||||
"role": "assistant",
|
||||
"tool_calls": [{
|
||||
"id": "call_12345",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"arguments": '{"location": "San Francisco"}'
|
||||
}
|
||||
}]
|
||||
},
|
||||
"finish_reason": "tool_calls"
|
||||
}],
|
||||
"usage": {
|
||||
"prompt_tokens": 20,
|
||||
"completion_tokens": 15,
|
||||
"total_tokens": 35
|
||||
}
|
||||
}
|
||||
|
||||
with patch("httpx.AsyncClient") as mock_client:
|
||||
mock_response = AsyncMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json = Mock(return_value=mock_response_data)
|
||||
|
||||
mock_context = AsyncMock()
|
||||
mock_context.post = AsyncMock(return_value=mock_response)
|
||||
mock_client.return_value.__aenter__.return_value = mock_context
|
||||
|
||||
messages = [{"role": "user", "content": "What's the weather?"}]
|
||||
tools = [{"type": "function", "function": {"name": "get_weather", "parameters": {}}}]
|
||||
result = await provider.chat(messages, tools=tools, model="weather-model")
|
||||
|
||||
assert isinstance(result, LLMResponse)
|
||||
assert result.content is None
|
||||
assert result.finish_reason == "tool_calls"
|
||||
assert len(result.tool_calls) == 1
|
||||
assert result.tool_calls[0].name == "get_weather"
|
||||
assert result.tool_calls[0].arguments == {"location": "San Francisco"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_api_error():
|
||||
"""Test chat request API error handling."""
|
||||
provider = AzureOpenAIProvider(
|
||||
api_key="test-key",
|
||||
api_base="https://test-resource.openai.azure.com",
|
||||
default_model="gpt-4o",
|
||||
)
|
||||
|
||||
with patch("httpx.AsyncClient") as mock_client:
|
||||
mock_response = AsyncMock()
|
||||
mock_response.status_code = 401
|
||||
mock_response.text = "Invalid authentication credentials"
|
||||
|
||||
mock_context = AsyncMock()
|
||||
mock_context.post = AsyncMock(return_value=mock_response)
|
||||
mock_client.return_value.__aenter__.return_value = mock_context
|
||||
|
||||
messages = [{"role": "user", "content": "Hello"}]
|
||||
result = await provider.chat(messages)
|
||||
|
||||
assert isinstance(result, LLMResponse)
|
||||
assert "Azure OpenAI API Error 401" in result.content
|
||||
assert "Invalid authentication credentials" in result.content
|
||||
assert result.finish_reason == "error"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_connection_error():
|
||||
"""Test chat request connection error handling."""
|
||||
provider = AzureOpenAIProvider(
|
||||
api_key="test-key",
|
||||
api_base="https://test-resource.openai.azure.com",
|
||||
default_model="gpt-4o",
|
||||
)
|
||||
|
||||
with patch("httpx.AsyncClient") as mock_client:
|
||||
mock_context = AsyncMock()
|
||||
mock_context.post = AsyncMock(side_effect=Exception("Connection failed"))
|
||||
mock_client.return_value.__aenter__.return_value = mock_context
|
||||
|
||||
messages = [{"role": "user", "content": "Hello"}]
|
||||
result = await provider.chat(messages)
|
||||
|
||||
assert isinstance(result, LLMResponse)
|
||||
assert "Error calling Azure OpenAI: Exception('Connection failed')" in result.content
|
||||
assert result.finish_reason == "error"
|
||||
|
||||
|
||||
def test_parse_response_malformed():
|
||||
"""Test response parsing with malformed data."""
|
||||
provider = AzureOpenAIProvider(
|
||||
api_key="test-key",
|
||||
api_base="https://test-resource.openai.azure.com",
|
||||
default_model="gpt-4o",
|
||||
)
|
||||
|
||||
# Test with missing choices
|
||||
malformed_response = {"usage": {"prompt_tokens": 10}}
|
||||
result = provider._parse_response(malformed_response)
|
||||
|
||||
assert isinstance(result, LLMResponse)
|
||||
assert "Error parsing Azure OpenAI response" in result.content
|
||||
assert result.finish_reason == "error"
|
||||
|
||||
|
||||
def test_get_default_model():
|
||||
"""Test get_default_model method."""
|
||||
provider = AzureOpenAIProvider(
|
||||
api_key="test-key",
|
||||
api_base="https://test-resource.openai.azure.com",
|
||||
default_model="my-custom-deployment",
|
||||
)
|
||||
|
||||
assert provider.get_default_model() == "my-custom-deployment"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Run basic tests
|
||||
print("Running basic Azure OpenAI provider tests...")
|
||||
|
||||
# Test initialization
|
||||
provider = AzureOpenAIProvider(
|
||||
api_key="test-key",
|
||||
api_base="https://test-resource.openai.azure.com",
|
||||
default_model="gpt-4o-deployment",
|
||||
)
|
||||
print("✅ Provider initialization successful")
|
||||
|
||||
# Test URL building
|
||||
url = provider._build_chat_url("my-deployment")
|
||||
expected = "https://test-resource.openai.azure.com/openai/deployments/my-deployment/chat/completions?api-version=2024-10-21"
|
||||
assert url == expected
|
||||
print("✅ URL building works correctly")
|
||||
|
||||
# Test headers
|
||||
headers = provider._build_headers()
|
||||
assert headers["api-key"] == "test-key"
|
||||
assert headers["Content-Type"] == "application/json"
|
||||
print("✅ Header building works correctly")
|
||||
|
||||
# Test payload preparation
|
||||
messages = [{"role": "user", "content": "Test"}]
|
||||
payload = provider._prepare_request_payload("gpt-4o-deployment", messages, max_tokens=1000)
|
||||
assert payload["max_completion_tokens"] == 1000 # Azure 2024-10-21 format
|
||||
print("✅ Payload preparation works correctly")
|
||||
|
||||
print("✅ All basic tests passed! Updated test file is working correctly.")
|
||||
@@ -0,0 +1,55 @@
|
||||
"""Tests for OpenAICompatProvider handling custom/direct endpoints."""
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import patch
|
||||
|
||||
from nanobot.providers.openai_compat_provider import OpenAICompatProvider
|
||||
|
||||
|
||||
def test_custom_provider_parse_handles_empty_choices() -> None:
|
||||
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
|
||||
provider = OpenAICompatProvider()
|
||||
response = SimpleNamespace(choices=[])
|
||||
|
||||
result = provider._parse(response)
|
||||
|
||||
assert result.finish_reason == "error"
|
||||
assert "empty choices" in result.content
|
||||
|
||||
|
||||
def test_custom_provider_parse_accepts_plain_string_response() -> None:
|
||||
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
|
||||
provider = OpenAICompatProvider()
|
||||
|
||||
result = provider._parse("hello from backend")
|
||||
|
||||
assert result.finish_reason == "stop"
|
||||
assert result.content == "hello from backend"
|
||||
|
||||
|
||||
def test_custom_provider_parse_accepts_dict_response() -> None:
|
||||
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
|
||||
provider = OpenAICompatProvider()
|
||||
|
||||
result = provider._parse({
|
||||
"choices": [{
|
||||
"message": {"content": "hello from dict"},
|
||||
"finish_reason": "stop",
|
||||
}],
|
||||
"usage": {
|
||||
"prompt_tokens": 1,
|
||||
"completion_tokens": 2,
|
||||
"total_tokens": 3,
|
||||
},
|
||||
})
|
||||
|
||||
assert result.finish_reason == "stop"
|
||||
assert result.content == "hello from dict"
|
||||
assert result.usage["total_tokens"] == 3
|
||||
|
||||
|
||||
def test_custom_provider_parse_chunks_accepts_plain_text_chunks() -> None:
|
||||
result = OpenAICompatProvider._parse_chunks(["hello ", "world"])
|
||||
|
||||
assert result.finish_reason == "stop"
|
||||
assert result.content == "hello world"
|
||||
@@ -0,0 +1,216 @@
|
||||
"""Tests for OpenAICompatProvider spec-driven behavior.
|
||||
|
||||
Validates that:
|
||||
- OpenRouter (no strip) keeps model names intact.
|
||||
- AiHubMix (strip_model_prefix=True) strips provider prefixes.
|
||||
- Standard providers pass model names through as-is.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.providers.openai_compat_provider import OpenAICompatProvider
|
||||
from nanobot.providers.registry import find_by_name
|
||||
|
||||
|
||||
def _fake_chat_response(content: str = "ok") -> SimpleNamespace:
|
||||
"""Build a minimal OpenAI chat completion response."""
|
||||
message = SimpleNamespace(
|
||||
content=content,
|
||||
tool_calls=None,
|
||||
reasoning_content=None,
|
||||
)
|
||||
choice = SimpleNamespace(message=message, finish_reason="stop")
|
||||
usage = SimpleNamespace(prompt_tokens=10, completion_tokens=5, total_tokens=15)
|
||||
return SimpleNamespace(choices=[choice], usage=usage)
|
||||
|
||||
|
||||
def _fake_tool_call_response() -> SimpleNamespace:
|
||||
"""Build a minimal chat response that includes Gemini-style extra_content."""
|
||||
function = SimpleNamespace(
|
||||
name="exec",
|
||||
arguments='{"cmd":"ls"}',
|
||||
provider_specific_fields={"inner": "value"},
|
||||
)
|
||||
tool_call = SimpleNamespace(
|
||||
id="call_123",
|
||||
index=0,
|
||||
type="function",
|
||||
function=function,
|
||||
extra_content={"google": {"thought_signature": "signed-token"}},
|
||||
)
|
||||
message = SimpleNamespace(
|
||||
content=None,
|
||||
tool_calls=[tool_call],
|
||||
reasoning_content=None,
|
||||
)
|
||||
choice = SimpleNamespace(message=message, finish_reason="tool_calls")
|
||||
usage = SimpleNamespace(prompt_tokens=10, completion_tokens=5, total_tokens=15)
|
||||
return SimpleNamespace(choices=[choice], usage=usage)
|
||||
|
||||
|
||||
def test_openrouter_spec_is_gateway() -> None:
|
||||
spec = find_by_name("openrouter")
|
||||
assert spec is not None
|
||||
assert spec.is_gateway is True
|
||||
assert spec.default_api_base == "https://openrouter.ai/api/v1"
|
||||
|
||||
|
||||
def test_openrouter_sets_default_attribution_headers() -> None:
|
||||
spec = find_by_name("openrouter")
|
||||
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI") as MockClient:
|
||||
OpenAICompatProvider(
|
||||
api_key="sk-or-test-key",
|
||||
api_base="https://openrouter.ai/api/v1",
|
||||
default_model="anthropic/claude-sonnet-4-5",
|
||||
spec=spec,
|
||||
)
|
||||
|
||||
headers = MockClient.call_args.kwargs["default_headers"]
|
||||
assert headers["HTTP-Referer"] == "https://github.com/HKUDS/nanobot"
|
||||
assert headers["X-OpenRouter-Title"] == "nanobot"
|
||||
assert headers["X-OpenRouter-Categories"] == "cli-agent,personal-agent"
|
||||
assert "x-session-affinity" in headers
|
||||
|
||||
|
||||
def test_openrouter_user_headers_override_default_attribution() -> None:
|
||||
spec = find_by_name("openrouter")
|
||||
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI") as MockClient:
|
||||
OpenAICompatProvider(
|
||||
api_key="sk-or-test-key",
|
||||
api_base="https://openrouter.ai/api/v1",
|
||||
default_model="anthropic/claude-sonnet-4-5",
|
||||
extra_headers={
|
||||
"HTTP-Referer": "https://nanobot.ai",
|
||||
"X-OpenRouter-Title": "Nanobot Pro",
|
||||
"X-Custom-App": "enabled",
|
||||
},
|
||||
spec=spec,
|
||||
)
|
||||
|
||||
headers = MockClient.call_args.kwargs["default_headers"]
|
||||
assert headers["HTTP-Referer"] == "https://nanobot.ai"
|
||||
assert headers["X-OpenRouter-Title"] == "Nanobot Pro"
|
||||
assert headers["X-OpenRouter-Categories"] == "cli-agent,personal-agent"
|
||||
assert headers["X-Custom-App"] == "enabled"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_openrouter_keeps_model_name_intact() -> None:
|
||||
"""OpenRouter gateway keeps the full model name (gateway does its own routing)."""
|
||||
mock_create = AsyncMock(return_value=_fake_chat_response())
|
||||
spec = find_by_name("openrouter")
|
||||
|
||||
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI") as MockClient:
|
||||
client_instance = MockClient.return_value
|
||||
client_instance.chat.completions.create = mock_create
|
||||
|
||||
provider = OpenAICompatProvider(
|
||||
api_key="sk-or-test-key",
|
||||
api_base="https://openrouter.ai/api/v1",
|
||||
default_model="anthropic/claude-sonnet-4-5",
|
||||
spec=spec,
|
||||
)
|
||||
await provider.chat(
|
||||
messages=[{"role": "user", "content": "hello"}],
|
||||
model="anthropic/claude-sonnet-4-5",
|
||||
)
|
||||
|
||||
call_kwargs = mock_create.call_args.kwargs
|
||||
assert call_kwargs["model"] == "anthropic/claude-sonnet-4-5"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aihubmix_strips_model_prefix() -> None:
|
||||
"""AiHubMix strips the provider prefix (strip_model_prefix=True)."""
|
||||
mock_create = AsyncMock(return_value=_fake_chat_response())
|
||||
spec = find_by_name("aihubmix")
|
||||
|
||||
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI") as MockClient:
|
||||
client_instance = MockClient.return_value
|
||||
client_instance.chat.completions.create = mock_create
|
||||
|
||||
provider = OpenAICompatProvider(
|
||||
api_key="sk-aihub-test-key",
|
||||
api_base="https://aihubmix.com/v1",
|
||||
default_model="claude-sonnet-4-5",
|
||||
spec=spec,
|
||||
)
|
||||
await provider.chat(
|
||||
messages=[{"role": "user", "content": "hello"}],
|
||||
model="anthropic/claude-sonnet-4-5",
|
||||
)
|
||||
|
||||
call_kwargs = mock_create.call_args.kwargs
|
||||
assert call_kwargs["model"] == "claude-sonnet-4-5"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_standard_provider_passes_model_through() -> None:
|
||||
"""Standard provider (e.g. deepseek) passes model name through as-is."""
|
||||
mock_create = AsyncMock(return_value=_fake_chat_response())
|
||||
spec = find_by_name("deepseek")
|
||||
|
||||
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI") as MockClient:
|
||||
client_instance = MockClient.return_value
|
||||
client_instance.chat.completions.create = mock_create
|
||||
|
||||
provider = OpenAICompatProvider(
|
||||
api_key="sk-deepseek-test-key",
|
||||
default_model="deepseek-chat",
|
||||
spec=spec,
|
||||
)
|
||||
await provider.chat(
|
||||
messages=[{"role": "user", "content": "hello"}],
|
||||
model="deepseek-chat",
|
||||
)
|
||||
|
||||
call_kwargs = mock_create.call_args.kwargs
|
||||
assert call_kwargs["model"] == "deepseek-chat"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_openai_compat_preserves_extra_content_on_tool_calls() -> None:
|
||||
"""Gemini extra_content (thought signatures) must survive parse→serialize round-trip."""
|
||||
mock_create = AsyncMock(return_value=_fake_tool_call_response())
|
||||
spec = find_by_name("gemini")
|
||||
|
||||
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI") as MockClient:
|
||||
client_instance = MockClient.return_value
|
||||
client_instance.chat.completions.create = mock_create
|
||||
|
||||
provider = OpenAICompatProvider(
|
||||
api_key="test-key",
|
||||
api_base="https://generativelanguage.googleapis.com/v1beta/openai/",
|
||||
default_model="google/gemini-3.1-pro-preview",
|
||||
spec=spec,
|
||||
)
|
||||
result = await provider.chat(
|
||||
messages=[{"role": "user", "content": "run exec"}],
|
||||
model="google/gemini-3.1-pro-preview",
|
||||
)
|
||||
|
||||
assert len(result.tool_calls) == 1
|
||||
tool_call = result.tool_calls[0]
|
||||
assert tool_call.extra_content == {"google": {"thought_signature": "signed-token"}}
|
||||
assert tool_call.function_provider_specific_fields == {"inner": "value"}
|
||||
|
||||
serialized = tool_call.to_openai_tool_call()
|
||||
assert serialized["extra_content"] == {"google": {"thought_signature": "signed-token"}}
|
||||
assert serialized["function"]["provider_specific_fields"] == {"inner": "value"}
|
||||
|
||||
|
||||
def test_openai_model_passthrough() -> None:
|
||||
"""OpenAI models pass through unchanged."""
|
||||
spec = find_by_name("openai")
|
||||
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
|
||||
provider = OpenAICompatProvider(
|
||||
api_key="sk-test-key",
|
||||
default_model="gpt-4o",
|
||||
spec=spec,
|
||||
)
|
||||
assert provider.get_default_model() == "gpt-4o"
|
||||
@@ -0,0 +1,20 @@
|
||||
"""Tests for the Mistral provider registration."""
|
||||
|
||||
from nanobot.config.schema import ProvidersConfig
|
||||
from nanobot.providers.registry import PROVIDERS
|
||||
|
||||
|
||||
def test_mistral_config_field_exists():
|
||||
"""ProvidersConfig should have a mistral field."""
|
||||
config = ProvidersConfig()
|
||||
assert hasattr(config, "mistral")
|
||||
|
||||
|
||||
def test_mistral_provider_in_registry():
|
||||
"""Mistral should be registered in the provider registry."""
|
||||
specs = {s.name: s for s in PROVIDERS}
|
||||
assert "mistral" in specs
|
||||
|
||||
mistral = specs["mistral"]
|
||||
assert mistral.env_key == "MISTRAL_API_KEY"
|
||||
assert mistral.default_api_base == "https://api.mistral.ai/v1"
|
||||
@@ -0,0 +1,213 @@
|
||||
import asyncio
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.providers.base import GenerationSettings, LLMProvider, LLMResponse
|
||||
|
||||
|
||||
class ScriptedProvider(LLMProvider):
|
||||
def __init__(self, responses):
|
||||
super().__init__()
|
||||
self._responses = list(responses)
|
||||
self.calls = 0
|
||||
self.last_kwargs: dict = {}
|
||||
|
||||
async def chat(self, *args, **kwargs) -> LLMResponse:
|
||||
self.calls += 1
|
||||
self.last_kwargs = kwargs
|
||||
response = self._responses.pop(0)
|
||||
if isinstance(response, BaseException):
|
||||
raise response
|
||||
return response
|
||||
|
||||
def get_default_model(self) -> str:
|
||||
return "test-model"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_with_retry_retries_transient_error_then_succeeds(monkeypatch) -> None:
|
||||
provider = ScriptedProvider([
|
||||
LLMResponse(content="429 rate limit", finish_reason="error"),
|
||||
LLMResponse(content="ok"),
|
||||
])
|
||||
delays: list[int] = []
|
||||
|
||||
async def _fake_sleep(delay: int) -> None:
|
||||
delays.append(delay)
|
||||
|
||||
monkeypatch.setattr("nanobot.providers.base.asyncio.sleep", _fake_sleep)
|
||||
|
||||
response = await provider.chat_with_retry(messages=[{"role": "user", "content": "hello"}])
|
||||
|
||||
assert response.finish_reason == "stop"
|
||||
assert response.content == "ok"
|
||||
assert provider.calls == 2
|
||||
assert delays == [1]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_with_retry_does_not_retry_non_transient_error(monkeypatch) -> None:
|
||||
provider = ScriptedProvider([
|
||||
LLMResponse(content="401 unauthorized", finish_reason="error"),
|
||||
])
|
||||
delays: list[int] = []
|
||||
|
||||
async def _fake_sleep(delay: int) -> None:
|
||||
delays.append(delay)
|
||||
|
||||
monkeypatch.setattr("nanobot.providers.base.asyncio.sleep", _fake_sleep)
|
||||
|
||||
response = await provider.chat_with_retry(messages=[{"role": "user", "content": "hello"}])
|
||||
|
||||
assert response.content == "401 unauthorized"
|
||||
assert provider.calls == 1
|
||||
assert delays == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_with_retry_returns_final_error_after_retries(monkeypatch) -> None:
|
||||
provider = ScriptedProvider([
|
||||
LLMResponse(content="429 rate limit a", finish_reason="error"),
|
||||
LLMResponse(content="429 rate limit b", finish_reason="error"),
|
||||
LLMResponse(content="429 rate limit c", finish_reason="error"),
|
||||
LLMResponse(content="503 final server error", finish_reason="error"),
|
||||
])
|
||||
delays: list[int] = []
|
||||
|
||||
async def _fake_sleep(delay: int) -> None:
|
||||
delays.append(delay)
|
||||
|
||||
monkeypatch.setattr("nanobot.providers.base.asyncio.sleep", _fake_sleep)
|
||||
|
||||
response = await provider.chat_with_retry(messages=[{"role": "user", "content": "hello"}])
|
||||
|
||||
assert response.content == "503 final server error"
|
||||
assert provider.calls == 4
|
||||
assert delays == [1, 2, 4]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_with_retry_preserves_cancelled_error() -> None:
|
||||
provider = ScriptedProvider([asyncio.CancelledError()])
|
||||
|
||||
with pytest.raises(asyncio.CancelledError):
|
||||
await provider.chat_with_retry(messages=[{"role": "user", "content": "hello"}])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_with_retry_uses_provider_generation_defaults() -> None:
|
||||
"""When callers omit generation params, provider.generation defaults are used."""
|
||||
provider = ScriptedProvider([LLMResponse(content="ok")])
|
||||
provider.generation = GenerationSettings(temperature=0.2, max_tokens=321, reasoning_effort="high")
|
||||
|
||||
await provider.chat_with_retry(messages=[{"role": "user", "content": "hello"}])
|
||||
|
||||
assert provider.last_kwargs["temperature"] == 0.2
|
||||
assert provider.last_kwargs["max_tokens"] == 321
|
||||
assert provider.last_kwargs["reasoning_effort"] == "high"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_with_retry_explicit_override_beats_defaults() -> None:
|
||||
"""Explicit kwargs should override provider.generation defaults."""
|
||||
provider = ScriptedProvider([LLMResponse(content="ok")])
|
||||
provider.generation = GenerationSettings(temperature=0.2, max_tokens=321, reasoning_effort="high")
|
||||
|
||||
await provider.chat_with_retry(
|
||||
messages=[{"role": "user", "content": "hello"}],
|
||||
temperature=0.9,
|
||||
max_tokens=9999,
|
||||
reasoning_effort="low",
|
||||
)
|
||||
|
||||
assert provider.last_kwargs["temperature"] == 0.9
|
||||
assert provider.last_kwargs["max_tokens"] == 9999
|
||||
assert provider.last_kwargs["reasoning_effort"] == "low"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Image fallback tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_IMAGE_MSG = [
|
||||
{"role": "user", "content": [
|
||||
{"type": "text", "text": "describe this"},
|
||||
{"type": "image_url", "image_url": {"url": "data:image/png;base64,abc"}, "_meta": {"path": "/media/test.png"}},
|
||||
]},
|
||||
]
|
||||
|
||||
_IMAGE_MSG_NO_META = [
|
||||
{"role": "user", "content": [
|
||||
{"type": "text", "text": "describe this"},
|
||||
{"type": "image_url", "image_url": {"url": "data:image/png;base64,abc"}},
|
||||
]},
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_transient_error_with_images_retries_without_images() -> None:
|
||||
"""Any non-transient error retries once with images stripped when images are present."""
|
||||
provider = ScriptedProvider([
|
||||
LLMResponse(content="API调用参数有误,请检查文档", finish_reason="error"),
|
||||
LLMResponse(content="ok, no image"),
|
||||
])
|
||||
|
||||
response = await provider.chat_with_retry(messages=_IMAGE_MSG)
|
||||
|
||||
assert response.content == "ok, no image"
|
||||
assert provider.calls == 2
|
||||
msgs_on_retry = provider.last_kwargs["messages"]
|
||||
for msg in msgs_on_retry:
|
||||
content = msg.get("content")
|
||||
if isinstance(content, list):
|
||||
assert all(b.get("type") != "image_url" for b in content)
|
||||
assert any("[image: /media/test.png]" in (b.get("text") or "") for b in content)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_transient_error_without_images_no_retry() -> None:
|
||||
"""Non-transient errors without image content are returned immediately."""
|
||||
provider = ScriptedProvider([
|
||||
LLMResponse(content="401 unauthorized", finish_reason="error"),
|
||||
])
|
||||
|
||||
response = await provider.chat_with_retry(
|
||||
messages=[{"role": "user", "content": "hello"}],
|
||||
)
|
||||
|
||||
assert provider.calls == 1
|
||||
assert response.finish_reason == "error"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_image_fallback_returns_error_on_second_failure() -> None:
|
||||
"""If the image-stripped retry also fails, return that error."""
|
||||
provider = ScriptedProvider([
|
||||
LLMResponse(content="some model error", finish_reason="error"),
|
||||
LLMResponse(content="still failing", finish_reason="error"),
|
||||
])
|
||||
|
||||
response = await provider.chat_with_retry(messages=_IMAGE_MSG)
|
||||
|
||||
assert provider.calls == 2
|
||||
assert response.content == "still failing"
|
||||
assert response.finish_reason == "error"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_image_fallback_without_meta_uses_default_placeholder() -> None:
|
||||
"""When _meta is absent, fallback placeholder is '[image omitted]'."""
|
||||
provider = ScriptedProvider([
|
||||
LLMResponse(content="error", finish_reason="error"),
|
||||
LLMResponse(content="ok"),
|
||||
])
|
||||
|
||||
response = await provider.chat_with_retry(messages=_IMAGE_MSG_NO_META)
|
||||
|
||||
assert response.content == "ok"
|
||||
assert provider.calls == 2
|
||||
msgs_on_retry = provider.last_kwargs["messages"]
|
||||
for msg in msgs_on_retry:
|
||||
content = msg.get("content")
|
||||
if isinstance(content, list):
|
||||
assert any("[image omitted]" in (b.get("text") or "") for b in content)
|
||||
@@ -0,0 +1,40 @@
|
||||
"""Tests for lazy provider exports from nanobot.providers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import sys
|
||||
|
||||
|
||||
def test_importing_providers_package_is_lazy(monkeypatch) -> None:
|
||||
monkeypatch.delitem(sys.modules, "nanobot.providers", raising=False)
|
||||
monkeypatch.delitem(sys.modules, "nanobot.providers.anthropic_provider", raising=False)
|
||||
monkeypatch.delitem(sys.modules, "nanobot.providers.openai_compat_provider", raising=False)
|
||||
monkeypatch.delitem(sys.modules, "nanobot.providers.openai_codex_provider", raising=False)
|
||||
monkeypatch.delitem(sys.modules, "nanobot.providers.azure_openai_provider", raising=False)
|
||||
|
||||
providers = importlib.import_module("nanobot.providers")
|
||||
|
||||
assert "nanobot.providers.anthropic_provider" not in sys.modules
|
||||
assert "nanobot.providers.openai_compat_provider" not in sys.modules
|
||||
assert "nanobot.providers.openai_codex_provider" not in sys.modules
|
||||
assert "nanobot.providers.azure_openai_provider" not in sys.modules
|
||||
assert providers.__all__ == [
|
||||
"LLMProvider",
|
||||
"LLMResponse",
|
||||
"AnthropicProvider",
|
||||
"OpenAICompatProvider",
|
||||
"OpenAICodexProvider",
|
||||
"AzureOpenAIProvider",
|
||||
]
|
||||
|
||||
|
||||
def test_explicit_provider_import_still_works(monkeypatch) -> None:
|
||||
monkeypatch.delitem(sys.modules, "nanobot.providers", raising=False)
|
||||
monkeypatch.delitem(sys.modules, "nanobot.providers.anthropic_provider", raising=False)
|
||||
|
||||
namespace: dict[str, object] = {}
|
||||
exec("from nanobot.providers import AnthropicProvider", namespace)
|
||||
|
||||
assert namespace["AnthropicProvider"].__name__ == "AnthropicProvider"
|
||||
assert "nanobot.providers.anthropic_provider" in sys.modules
|
||||
@@ -0,0 +1,101 @@
|
||||
"""Tests for nanobot.security.network — SSRF protection and internal URL detection."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import socket
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.security.network import contains_internal_url, validate_url_target
|
||||
|
||||
|
||||
def _fake_resolve(host: str, results: list[str]):
|
||||
"""Return a getaddrinfo mock that maps the given host to fake IP results."""
|
||||
def _resolver(hostname, port, family=0, type_=0):
|
||||
if hostname == host:
|
||||
return [(socket.AF_INET, socket.SOCK_STREAM, 0, "", (ip, 0)) for ip in results]
|
||||
raise socket.gaierror(f"cannot resolve {hostname}")
|
||||
return _resolver
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# validate_url_target — scheme / domain basics
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_rejects_non_http_scheme():
|
||||
ok, err = validate_url_target("ftp://example.com/file")
|
||||
assert not ok
|
||||
assert "http" in err.lower()
|
||||
|
||||
|
||||
def test_rejects_missing_domain():
|
||||
ok, err = validate_url_target("http://")
|
||||
assert not ok
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# validate_url_target — blocked private/internal IPs
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.parametrize("ip,label", [
|
||||
("127.0.0.1", "loopback"),
|
||||
("127.0.0.2", "loopback_alt"),
|
||||
("10.0.0.1", "rfc1918_10"),
|
||||
("172.16.5.1", "rfc1918_172"),
|
||||
("192.168.1.1", "rfc1918_192"),
|
||||
("169.254.169.254", "metadata"),
|
||||
("0.0.0.0", "zero"),
|
||||
])
|
||||
def test_blocks_private_ipv4(ip: str, label: str):
|
||||
with patch("nanobot.security.network.socket.getaddrinfo", _fake_resolve("evil.com", [ip])):
|
||||
ok, err = validate_url_target(f"http://evil.com/path")
|
||||
assert not ok, f"Should block {label} ({ip})"
|
||||
assert "private" in err.lower() or "blocked" in err.lower()
|
||||
|
||||
|
||||
def test_blocks_ipv6_loopback():
|
||||
def _resolver(hostname, port, family=0, type_=0):
|
||||
return [(socket.AF_INET6, socket.SOCK_STREAM, 0, "", ("::1", 0, 0, 0))]
|
||||
with patch("nanobot.security.network.socket.getaddrinfo", _resolver):
|
||||
ok, err = validate_url_target("http://evil.com/")
|
||||
assert not ok
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# validate_url_target — allows public IPs
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_allows_public_ip():
|
||||
with patch("nanobot.security.network.socket.getaddrinfo", _fake_resolve("example.com", ["93.184.216.34"])):
|
||||
ok, err = validate_url_target("http://example.com/page")
|
||||
assert ok, f"Should allow public IP, got: {err}"
|
||||
|
||||
|
||||
def test_allows_normal_https():
|
||||
with patch("nanobot.security.network.socket.getaddrinfo", _fake_resolve("github.com", ["140.82.121.3"])):
|
||||
ok, err = validate_url_target("https://github.com/HKUDS/nanobot")
|
||||
assert ok
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# contains_internal_url — shell command scanning
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_detects_curl_metadata():
|
||||
with patch("nanobot.security.network.socket.getaddrinfo", _fake_resolve("169.254.169.254", ["169.254.169.254"])):
|
||||
assert contains_internal_url('curl -s http://169.254.169.254/computeMetadata/v1/')
|
||||
|
||||
|
||||
def test_detects_wget_localhost():
|
||||
with patch("nanobot.security.network.socket.getaddrinfo", _fake_resolve("localhost", ["127.0.0.1"])):
|
||||
assert contains_internal_url("wget http://localhost:8080/secret")
|
||||
|
||||
|
||||
def test_allows_normal_curl():
|
||||
with patch("nanobot.security.network.socket.getaddrinfo", _fake_resolve("example.com", ["93.184.216.34"])):
|
||||
assert not contains_internal_url("curl https://example.com/api/data")
|
||||
|
||||
|
||||
def test_no_urls_returns_false():
|
||||
assert not contains_internal_url("echo hello && ls -la")
|
||||
@@ -0,0 +1,56 @@
|
||||
#!/usr/bin/env bash
|
||||
set -euo pipefail
|
||||
cd "$(dirname "$0")/.." || exit 1
|
||||
|
||||
IMAGE_NAME="nanobot-test"
|
||||
|
||||
echo "=== Building Docker image ==="
|
||||
docker build -t "$IMAGE_NAME" .
|
||||
|
||||
echo ""
|
||||
echo "=== Running 'nanobot onboard' ==="
|
||||
docker run --name nanobot-test-run "$IMAGE_NAME" onboard
|
||||
|
||||
echo ""
|
||||
echo "=== Running 'nanobot status' ==="
|
||||
STATUS_OUTPUT=$(docker commit nanobot-test-run nanobot-test-onboarded > /dev/null && \
|
||||
docker run --rm nanobot-test-onboarded status 2>&1) || true
|
||||
|
||||
echo "$STATUS_OUTPUT"
|
||||
|
||||
echo ""
|
||||
echo "=== Validating output ==="
|
||||
PASS=true
|
||||
|
||||
check() {
|
||||
if echo "$STATUS_OUTPUT" | grep -q "$1"; then
|
||||
echo " PASS: found '$1'"
|
||||
else
|
||||
echo " FAIL: missing '$1'"
|
||||
PASS=false
|
||||
fi
|
||||
}
|
||||
|
||||
check "nanobot Status"
|
||||
check "Config:"
|
||||
check "Workspace:"
|
||||
check "Model:"
|
||||
check "OpenRouter API:"
|
||||
check "Anthropic API:"
|
||||
check "OpenAI API:"
|
||||
|
||||
echo ""
|
||||
if $PASS; then
|
||||
echo "=== All checks passed ==="
|
||||
else
|
||||
echo "=== Some checks FAILED ==="
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Cleanup
|
||||
echo ""
|
||||
echo "=== Cleanup ==="
|
||||
docker rm -f nanobot-test-run 2>/dev/null || true
|
||||
docker rmi -f nanobot-test-onboarded 2>/dev/null || true
|
||||
docker rmi -f "$IMAGE_NAME" 2>/dev/null || true
|
||||
echo "Done."
|
||||
@@ -0,0 +1,69 @@
|
||||
"""Tests for exec tool internal URL blocking."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import socket
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.agent.tools.shell import ExecTool
|
||||
|
||||
|
||||
def _fake_resolve_private(hostname, port, family=0, type_=0):
|
||||
return [(socket.AF_INET, socket.SOCK_STREAM, 0, "", ("169.254.169.254", 0))]
|
||||
|
||||
|
||||
def _fake_resolve_localhost(hostname, port, family=0, type_=0):
|
||||
return [(socket.AF_INET, socket.SOCK_STREAM, 0, "", ("127.0.0.1", 0))]
|
||||
|
||||
|
||||
def _fake_resolve_public(hostname, port, family=0, type_=0):
|
||||
return [(socket.AF_INET, socket.SOCK_STREAM, 0, "", ("93.184.216.34", 0))]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exec_blocks_curl_metadata():
|
||||
tool = ExecTool()
|
||||
with patch("nanobot.security.network.socket.getaddrinfo", _fake_resolve_private):
|
||||
result = await tool.execute(
|
||||
command='curl -s -H "Metadata-Flavor: Google" http://169.254.169.254/computeMetadata/v1/'
|
||||
)
|
||||
assert "Error" in result
|
||||
assert "internal" in result.lower() or "private" in result.lower()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exec_blocks_wget_localhost():
|
||||
tool = ExecTool()
|
||||
with patch("nanobot.security.network.socket.getaddrinfo", _fake_resolve_localhost):
|
||||
result = await tool.execute(command="wget http://localhost:8080/secret -O /tmp/out")
|
||||
assert "Error" in result
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exec_allows_normal_commands():
|
||||
tool = ExecTool(timeout=5)
|
||||
result = await tool.execute(command="echo hello")
|
||||
assert "hello" in result
|
||||
assert "Error" not in result.split("\n")[0]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exec_allows_curl_to_public_url():
|
||||
"""Commands with public URLs should not be blocked by the internal URL check."""
|
||||
tool = ExecTool()
|
||||
with patch("nanobot.security.network.socket.getaddrinfo", _fake_resolve_public):
|
||||
guard_result = tool._guard_command("curl https://example.com/api", "/tmp")
|
||||
assert guard_result is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exec_blocks_chained_internal_url():
|
||||
"""Internal URLs buried in chained commands should still be caught."""
|
||||
tool = ExecTool()
|
||||
with patch("nanobot.security.network.socket.getaddrinfo", _fake_resolve_private):
|
||||
result = await tool.execute(
|
||||
command="echo start && curl http://169.254.169.254/latest/meta-data/ && echo done"
|
||||
)
|
||||
assert "Error" in result
|
||||
@@ -0,0 +1,394 @@
|
||||
"""Tests for enhanced filesystem tools: ReadFileTool, EditFileTool, ListDirTool."""
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.agent.tools.filesystem import (
|
||||
EditFileTool,
|
||||
ListDirTool,
|
||||
ReadFileTool,
|
||||
_find_match,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ReadFileTool
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestReadFileTool:
|
||||
|
||||
@pytest.fixture()
|
||||
def tool(self, tmp_path):
|
||||
return ReadFileTool(workspace=tmp_path)
|
||||
|
||||
@pytest.fixture()
|
||||
def sample_file(self, tmp_path):
|
||||
f = tmp_path / "sample.txt"
|
||||
f.write_text("\n".join(f"line {i}" for i in range(1, 21)), encoding="utf-8")
|
||||
return f
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_basic_read_has_line_numbers(self, tool, sample_file):
|
||||
result = await tool.execute(path=str(sample_file))
|
||||
assert "1| line 1" in result
|
||||
assert "20| line 20" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_offset_and_limit(self, tool, sample_file):
|
||||
result = await tool.execute(path=str(sample_file), offset=5, limit=3)
|
||||
assert "5| line 5" in result
|
||||
assert "7| line 7" in result
|
||||
assert "8| line 8" not in result
|
||||
assert "Use offset=8 to continue" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_offset_beyond_end(self, tool, sample_file):
|
||||
result = await tool.execute(path=str(sample_file), offset=999)
|
||||
assert "Error" in result
|
||||
assert "beyond end" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_end_of_file_marker(self, tool, sample_file):
|
||||
result = await tool.execute(path=str(sample_file), offset=1, limit=9999)
|
||||
assert "End of file" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_file(self, tool, tmp_path):
|
||||
f = tmp_path / "empty.txt"
|
||||
f.write_text("", encoding="utf-8")
|
||||
result = await tool.execute(path=str(f))
|
||||
assert "Empty file" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_image_file_returns_multimodal_blocks(self, tool, tmp_path):
|
||||
f = tmp_path / "pixel.png"
|
||||
f.write_bytes(b"\x89PNG\r\n\x1a\nfake-png-data")
|
||||
|
||||
result = await tool.execute(path=str(f))
|
||||
|
||||
assert isinstance(result, list)
|
||||
assert result[0]["type"] == "image_url"
|
||||
assert result[0]["image_url"]["url"].startswith("data:image/png;base64,")
|
||||
assert result[0]["_meta"]["path"] == str(f)
|
||||
assert result[1] == {"type": "text", "text": f"(Image file: {f})"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_file_not_found(self, tool, tmp_path):
|
||||
result = await tool.execute(path=str(tmp_path / "nope.txt"))
|
||||
assert "Error" in result
|
||||
assert "not found" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_missing_path_returns_clear_error(self, tool):
|
||||
result = await tool.execute()
|
||||
assert result == "Error reading file: Unknown path"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_char_budget_trims(self, tool, tmp_path):
|
||||
"""When the selected slice exceeds _MAX_CHARS the output is trimmed."""
|
||||
f = tmp_path / "big.txt"
|
||||
# Each line is ~110 chars, 2000 lines ≈ 220 KB > 128 KB limit
|
||||
f.write_text("\n".join("x" * 110 for _ in range(2000)), encoding="utf-8")
|
||||
result = await tool.execute(path=str(f))
|
||||
assert len(result) <= ReadFileTool._MAX_CHARS + 500 # small margin for footer
|
||||
assert "Use offset=" in result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _find_match (unit tests for the helper)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestFindMatch:
|
||||
|
||||
def test_exact_match(self):
|
||||
match, count = _find_match("hello world", "world")
|
||||
assert match == "world"
|
||||
assert count == 1
|
||||
|
||||
def test_exact_no_match(self):
|
||||
match, count = _find_match("hello world", "xyz")
|
||||
assert match is None
|
||||
assert count == 0
|
||||
|
||||
def test_crlf_normalisation(self):
|
||||
# Caller normalises CRLF before calling _find_match, so test with
|
||||
# pre-normalised content to verify exact match still works.
|
||||
content = "line1\nline2\nline3"
|
||||
old_text = "line1\nline2\nline3"
|
||||
match, count = _find_match(content, old_text)
|
||||
assert match is not None
|
||||
assert count == 1
|
||||
|
||||
def test_line_trim_fallback(self):
|
||||
content = " def foo():\n pass\n"
|
||||
old_text = "def foo():\n pass"
|
||||
match, count = _find_match(content, old_text)
|
||||
assert match is not None
|
||||
assert count == 1
|
||||
# The returned match should be the *original* indented text
|
||||
assert " def foo():" in match
|
||||
|
||||
def test_line_trim_multiple_candidates(self):
|
||||
content = " a\n b\n a\n b\n"
|
||||
old_text = "a\nb"
|
||||
match, count = _find_match(content, old_text)
|
||||
assert count == 2
|
||||
|
||||
def test_empty_old_text(self):
|
||||
match, count = _find_match("hello", "")
|
||||
# Empty string is always "in" any string via exact match
|
||||
assert match == ""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# EditFileTool
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestEditFileTool:
|
||||
|
||||
@pytest.fixture()
|
||||
def tool(self, tmp_path):
|
||||
return EditFileTool(workspace=tmp_path)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exact_match(self, tool, tmp_path):
|
||||
f = tmp_path / "a.py"
|
||||
f.write_text("hello world", encoding="utf-8")
|
||||
result = await tool.execute(path=str(f), old_text="world", new_text="earth")
|
||||
assert "Successfully" in result
|
||||
assert f.read_text() == "hello earth"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_crlf_normalisation(self, tool, tmp_path):
|
||||
f = tmp_path / "crlf.py"
|
||||
f.write_bytes(b"line1\r\nline2\r\nline3")
|
||||
result = await tool.execute(
|
||||
path=str(f), old_text="line1\nline2", new_text="LINE1\nLINE2",
|
||||
)
|
||||
assert "Successfully" in result
|
||||
raw = f.read_bytes()
|
||||
assert b"LINE1" in raw
|
||||
# CRLF line endings should be preserved throughout the file
|
||||
assert b"\r\n" in raw
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_trim_fallback(self, tool, tmp_path):
|
||||
f = tmp_path / "indent.py"
|
||||
f.write_text(" def foo():\n pass\n", encoding="utf-8")
|
||||
result = await tool.execute(
|
||||
path=str(f), old_text="def foo():\n pass", new_text="def bar():\n return 1",
|
||||
)
|
||||
assert "Successfully" in result
|
||||
assert "bar" in f.read_text()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ambiguous_match(self, tool, tmp_path):
|
||||
f = tmp_path / "dup.py"
|
||||
f.write_text("aaa\nbbb\naaa\nbbb\n", encoding="utf-8")
|
||||
result = await tool.execute(path=str(f), old_text="aaa\nbbb", new_text="xxx")
|
||||
assert "appears" in result.lower() or "Warning" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_replace_all(self, tool, tmp_path):
|
||||
f = tmp_path / "multi.py"
|
||||
f.write_text("foo bar foo bar foo", encoding="utf-8")
|
||||
result = await tool.execute(
|
||||
path=str(f), old_text="foo", new_text="baz", replace_all=True,
|
||||
)
|
||||
assert "Successfully" in result
|
||||
assert f.read_text() == "baz bar baz bar baz"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_not_found(self, tool, tmp_path):
|
||||
f = tmp_path / "nf.py"
|
||||
f.write_text("hello", encoding="utf-8")
|
||||
result = await tool.execute(path=str(f), old_text="xyz", new_text="abc")
|
||||
assert "Error" in result
|
||||
assert "not found" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_missing_new_text_returns_clear_error(self, tool, tmp_path):
|
||||
f = tmp_path / "a.py"
|
||||
f.write_text("hello", encoding="utf-8")
|
||||
result = await tool.execute(path=str(f), old_text="hello")
|
||||
assert result == "Error editing file: Unknown new_text"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ListDirTool
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestListDirTool:
|
||||
|
||||
@pytest.fixture()
|
||||
def tool(self, tmp_path):
|
||||
return ListDirTool(workspace=tmp_path)
|
||||
|
||||
@pytest.fixture()
|
||||
def populated_dir(self, tmp_path):
|
||||
(tmp_path / "src").mkdir()
|
||||
(tmp_path / "src" / "main.py").write_text("pass")
|
||||
(tmp_path / "src" / "utils.py").write_text("pass")
|
||||
(tmp_path / "README.md").write_text("hi")
|
||||
(tmp_path / ".git").mkdir()
|
||||
(tmp_path / ".git" / "config").write_text("x")
|
||||
(tmp_path / "node_modules").mkdir()
|
||||
(tmp_path / "node_modules" / "pkg").mkdir()
|
||||
return tmp_path
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_basic_list(self, tool, populated_dir):
|
||||
result = await tool.execute(path=str(populated_dir))
|
||||
assert "README.md" in result
|
||||
assert "src" in result
|
||||
# .git and node_modules should be ignored
|
||||
assert ".git" not in result
|
||||
assert "node_modules" not in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_recursive(self, tool, populated_dir):
|
||||
result = await tool.execute(path=str(populated_dir), recursive=True)
|
||||
# Normalize path separators for cross-platform compatibility
|
||||
normalized = result.replace("\\", "/")
|
||||
assert "src/main.py" in normalized
|
||||
assert "src/utils.py" in normalized
|
||||
assert "README.md" in result
|
||||
# Ignored dirs should not appear
|
||||
assert ".git" not in result
|
||||
assert "node_modules" not in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_max_entries_truncation(self, tool, tmp_path):
|
||||
for i in range(10):
|
||||
(tmp_path / f"file_{i}.txt").write_text("x")
|
||||
result = await tool.execute(path=str(tmp_path), max_entries=3)
|
||||
assert "truncated" in result
|
||||
assert "3 of 10" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_dir(self, tool, tmp_path):
|
||||
d = tmp_path / "empty"
|
||||
d.mkdir()
|
||||
result = await tool.execute(path=str(d))
|
||||
assert "empty" in result.lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_not_found(self, tool, tmp_path):
|
||||
result = await tool.execute(path=str(tmp_path / "nope"))
|
||||
assert "Error" in result
|
||||
assert "not found" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_missing_path_returns_clear_error(self, tool):
|
||||
result = await tool.execute()
|
||||
assert result == "Error listing directory: Unknown path"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Workspace restriction + extra_allowed_dirs
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestWorkspaceRestriction:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_read_blocked_outside_workspace(self, tmp_path):
|
||||
workspace = tmp_path / "ws"
|
||||
workspace.mkdir()
|
||||
outside = tmp_path / "outside"
|
||||
outside.mkdir()
|
||||
secret = outside / "secret.txt"
|
||||
secret.write_text("top secret")
|
||||
|
||||
tool = ReadFileTool(workspace=workspace, allowed_dir=workspace)
|
||||
result = await tool.execute(path=str(secret))
|
||||
assert "Error" in result
|
||||
assert "outside" in result.lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_read_allowed_with_extra_dir(self, tmp_path):
|
||||
workspace = tmp_path / "ws"
|
||||
workspace.mkdir()
|
||||
skills_dir = tmp_path / "skills"
|
||||
skills_dir.mkdir()
|
||||
skill_file = skills_dir / "test_skill" / "SKILL.md"
|
||||
skill_file.parent.mkdir()
|
||||
skill_file.write_text("# Test Skill\nDo something.")
|
||||
|
||||
tool = ReadFileTool(
|
||||
workspace=workspace, allowed_dir=workspace,
|
||||
extra_allowed_dirs=[skills_dir],
|
||||
)
|
||||
result = await tool.execute(path=str(skill_file))
|
||||
assert "Test Skill" in result
|
||||
assert "Error" not in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extra_dirs_does_not_widen_write(self, tmp_path):
|
||||
from nanobot.agent.tools.filesystem import WriteFileTool
|
||||
|
||||
workspace = tmp_path / "ws"
|
||||
workspace.mkdir()
|
||||
outside = tmp_path / "outside"
|
||||
outside.mkdir()
|
||||
|
||||
tool = WriteFileTool(workspace=workspace, allowed_dir=workspace)
|
||||
result = await tool.execute(path=str(outside / "hack.txt"), content="pwned")
|
||||
assert "Error" in result
|
||||
assert "outside" in result.lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_read_still_blocked_for_unrelated_dir(self, tmp_path):
|
||||
workspace = tmp_path / "ws"
|
||||
workspace.mkdir()
|
||||
skills_dir = tmp_path / "skills"
|
||||
skills_dir.mkdir()
|
||||
unrelated = tmp_path / "other"
|
||||
unrelated.mkdir()
|
||||
secret = unrelated / "secret.txt"
|
||||
secret.write_text("nope")
|
||||
|
||||
tool = ReadFileTool(
|
||||
workspace=workspace, allowed_dir=workspace,
|
||||
extra_allowed_dirs=[skills_dir],
|
||||
)
|
||||
result = await tool.execute(path=str(secret))
|
||||
assert "Error" in result
|
||||
assert "outside" in result.lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_workspace_file_still_readable_with_extra_dirs(self, tmp_path):
|
||||
"""Adding extra_allowed_dirs must not break normal workspace reads."""
|
||||
workspace = tmp_path / "ws"
|
||||
workspace.mkdir()
|
||||
ws_file = workspace / "README.md"
|
||||
ws_file.write_text("hello from workspace")
|
||||
skills_dir = tmp_path / "skills"
|
||||
skills_dir.mkdir()
|
||||
|
||||
tool = ReadFileTool(
|
||||
workspace=workspace, allowed_dir=workspace,
|
||||
extra_allowed_dirs=[skills_dir],
|
||||
)
|
||||
result = await tool.execute(path=str(ws_file))
|
||||
assert "hello from workspace" in result
|
||||
assert "Error" not in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_edit_blocked_in_extra_dir(self, tmp_path):
|
||||
"""edit_file must not be able to modify files in extra_allowed_dirs."""
|
||||
workspace = tmp_path / "ws"
|
||||
workspace.mkdir()
|
||||
skills_dir = tmp_path / "skills"
|
||||
skills_dir.mkdir()
|
||||
skill_file = skills_dir / "weather" / "SKILL.md"
|
||||
skill_file.parent.mkdir()
|
||||
skill_file.write_text("# Weather\nOriginal content.")
|
||||
|
||||
tool = EditFileTool(workspace=workspace, allowed_dir=workspace)
|
||||
result = await tool.execute(
|
||||
path=str(skill_file),
|
||||
old_text="Original content.",
|
||||
new_text="Hacked content.",
|
||||
)
|
||||
assert "Error" in result
|
||||
assert "outside" in result.lower()
|
||||
assert skill_file.read_text() == "# Weather\nOriginal content."
|
||||
@@ -0,0 +1,345 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from contextlib import AsyncExitStack, asynccontextmanager
|
||||
import sys
|
||||
from types import ModuleType, SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.agent.tools.mcp import MCPToolWrapper, connect_mcp_servers
|
||||
from nanobot.agent.tools.registry import ToolRegistry
|
||||
from nanobot.config.schema import MCPServerConfig
|
||||
|
||||
|
||||
class _FakeTextContent:
|
||||
def __init__(self, text: str) -> None:
|
||||
self.text = text
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def fake_mcp_runtime() -> dict[str, object | None]:
|
||||
return {"session": None}
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _fake_mcp_module(
|
||||
monkeypatch: pytest.MonkeyPatch, fake_mcp_runtime: dict[str, object | None]
|
||||
) -> None:
|
||||
mod = ModuleType("mcp")
|
||||
mod.types = SimpleNamespace(TextContent=_FakeTextContent)
|
||||
|
||||
class _FakeStdioServerParameters:
|
||||
def __init__(self, command: str, args: list[str], env: dict | None = None) -> None:
|
||||
self.command = command
|
||||
self.args = args
|
||||
self.env = env
|
||||
|
||||
class _FakeClientSession:
|
||||
def __init__(self, _read: object, _write: object) -> None:
|
||||
self._session = fake_mcp_runtime["session"]
|
||||
|
||||
async def __aenter__(self) -> object:
|
||||
return self._session
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb) -> bool:
|
||||
return False
|
||||
|
||||
@asynccontextmanager
|
||||
async def _fake_stdio_client(_params: object):
|
||||
yield object(), object()
|
||||
|
||||
@asynccontextmanager
|
||||
async def _fake_sse_client(_url: str, httpx_client_factory=None):
|
||||
yield object(), object()
|
||||
|
||||
@asynccontextmanager
|
||||
async def _fake_streamable_http_client(_url: str, http_client=None):
|
||||
yield object(), object(), object()
|
||||
|
||||
mod.ClientSession = _FakeClientSession
|
||||
mod.StdioServerParameters = _FakeStdioServerParameters
|
||||
monkeypatch.setitem(sys.modules, "mcp", mod)
|
||||
|
||||
client_mod = ModuleType("mcp.client")
|
||||
stdio_mod = ModuleType("mcp.client.stdio")
|
||||
stdio_mod.stdio_client = _fake_stdio_client
|
||||
sse_mod = ModuleType("mcp.client.sse")
|
||||
sse_mod.sse_client = _fake_sse_client
|
||||
streamable_http_mod = ModuleType("mcp.client.streamable_http")
|
||||
streamable_http_mod.streamable_http_client = _fake_streamable_http_client
|
||||
|
||||
monkeypatch.setitem(sys.modules, "mcp.client", client_mod)
|
||||
monkeypatch.setitem(sys.modules, "mcp.client.stdio", stdio_mod)
|
||||
monkeypatch.setitem(sys.modules, "mcp.client.sse", sse_mod)
|
||||
monkeypatch.setitem(sys.modules, "mcp.client.streamable_http", streamable_http_mod)
|
||||
|
||||
|
||||
def _make_wrapper(session: object, *, timeout: float = 0.1) -> MCPToolWrapper:
|
||||
tool_def = SimpleNamespace(
|
||||
name="demo",
|
||||
description="demo tool",
|
||||
inputSchema={"type": "object", "properties": {}},
|
||||
)
|
||||
return MCPToolWrapper(session, "test", tool_def, tool_timeout=timeout)
|
||||
|
||||
|
||||
def test_wrapper_preserves_non_nullable_unions() -> None:
|
||||
tool_def = SimpleNamespace(
|
||||
name="demo",
|
||||
description="demo tool",
|
||||
inputSchema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"value": {
|
||||
"anyOf": [{"type": "string"}, {"type": "integer"}],
|
||||
}
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
wrapper = MCPToolWrapper(SimpleNamespace(call_tool=None), "test", tool_def)
|
||||
|
||||
assert wrapper.parameters["properties"]["value"]["anyOf"] == [
|
||||
{"type": "string"},
|
||||
{"type": "integer"},
|
||||
]
|
||||
|
||||
|
||||
def test_wrapper_normalizes_nullable_property_type_union() -> None:
|
||||
tool_def = SimpleNamespace(
|
||||
name="demo",
|
||||
description="demo tool",
|
||||
inputSchema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {"type": ["string", "null"]},
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
wrapper = MCPToolWrapper(SimpleNamespace(call_tool=None), "test", tool_def)
|
||||
|
||||
assert wrapper.parameters["properties"]["name"] == {"type": "string", "nullable": True}
|
||||
|
||||
|
||||
def test_wrapper_normalizes_nullable_property_anyof() -> None:
|
||||
tool_def = SimpleNamespace(
|
||||
name="demo",
|
||||
description="demo tool",
|
||||
inputSchema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {
|
||||
"anyOf": [{"type": "string"}, {"type": "null"}],
|
||||
"description": "optional name",
|
||||
},
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
wrapper = MCPToolWrapper(SimpleNamespace(call_tool=None), "test", tool_def)
|
||||
|
||||
assert wrapper.parameters["properties"]["name"] == {
|
||||
"type": "string",
|
||||
"description": "optional name",
|
||||
"nullable": True,
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_returns_text_blocks() -> None:
|
||||
async def call_tool(_name: str, arguments: dict) -> object:
|
||||
assert arguments == {"value": 1}
|
||||
return SimpleNamespace(content=[_FakeTextContent("hello"), 42])
|
||||
|
||||
wrapper = _make_wrapper(SimpleNamespace(call_tool=call_tool))
|
||||
|
||||
result = await wrapper.execute(value=1)
|
||||
|
||||
assert result == "hello\n42"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_returns_timeout_message() -> None:
|
||||
async def call_tool(_name: str, arguments: dict) -> object:
|
||||
await asyncio.sleep(1)
|
||||
return SimpleNamespace(content=[])
|
||||
|
||||
wrapper = _make_wrapper(SimpleNamespace(call_tool=call_tool), timeout=0.01)
|
||||
|
||||
result = await wrapper.execute()
|
||||
|
||||
assert result == "(MCP tool call timed out after 0.01s)"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_handles_server_cancelled_error() -> None:
|
||||
async def call_tool(_name: str, arguments: dict) -> object:
|
||||
raise asyncio.CancelledError()
|
||||
|
||||
wrapper = _make_wrapper(SimpleNamespace(call_tool=call_tool))
|
||||
|
||||
result = await wrapper.execute()
|
||||
|
||||
assert result == "(MCP tool call was cancelled)"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_re_raises_external_cancellation() -> None:
|
||||
started = asyncio.Event()
|
||||
|
||||
async def call_tool(_name: str, arguments: dict) -> object:
|
||||
started.set()
|
||||
await asyncio.sleep(60)
|
||||
return SimpleNamespace(content=[])
|
||||
|
||||
wrapper = _make_wrapper(SimpleNamespace(call_tool=call_tool), timeout=10)
|
||||
task = asyncio.create_task(wrapper.execute())
|
||||
await started.wait()
|
||||
|
||||
task.cancel()
|
||||
|
||||
with pytest.raises(asyncio.CancelledError):
|
||||
await task
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_handles_generic_exception() -> None:
|
||||
async def call_tool(_name: str, arguments: dict) -> object:
|
||||
raise RuntimeError("boom")
|
||||
|
||||
wrapper = _make_wrapper(SimpleNamespace(call_tool=call_tool))
|
||||
|
||||
result = await wrapper.execute()
|
||||
|
||||
assert result == "(MCP tool call failed: RuntimeError)"
|
||||
|
||||
|
||||
def _make_tool_def(name: str) -> SimpleNamespace:
|
||||
return SimpleNamespace(
|
||||
name=name,
|
||||
description=f"{name} tool",
|
||||
inputSchema={"type": "object", "properties": {}},
|
||||
)
|
||||
|
||||
|
||||
def _make_fake_session(tool_names: list[str]) -> SimpleNamespace:
|
||||
async def initialize() -> None:
|
||||
return None
|
||||
|
||||
async def list_tools() -> SimpleNamespace:
|
||||
return SimpleNamespace(tools=[_make_tool_def(name) for name in tool_names])
|
||||
|
||||
return SimpleNamespace(initialize=initialize, list_tools=list_tools)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_connect_mcp_servers_enabled_tools_supports_raw_names(
|
||||
fake_mcp_runtime: dict[str, object | None],
|
||||
) -> None:
|
||||
fake_mcp_runtime["session"] = _make_fake_session(["demo", "other"])
|
||||
registry = ToolRegistry()
|
||||
stack = AsyncExitStack()
|
||||
await stack.__aenter__()
|
||||
try:
|
||||
await connect_mcp_servers(
|
||||
{"test": MCPServerConfig(command="fake", enabled_tools=["demo"])},
|
||||
registry,
|
||||
stack,
|
||||
)
|
||||
finally:
|
||||
await stack.aclose()
|
||||
|
||||
assert registry.tool_names == ["mcp_test_demo"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_connect_mcp_servers_enabled_tools_defaults_to_all(
|
||||
fake_mcp_runtime: dict[str, object | None],
|
||||
) -> None:
|
||||
fake_mcp_runtime["session"] = _make_fake_session(["demo", "other"])
|
||||
registry = ToolRegistry()
|
||||
stack = AsyncExitStack()
|
||||
await stack.__aenter__()
|
||||
try:
|
||||
await connect_mcp_servers(
|
||||
{"test": MCPServerConfig(command="fake")},
|
||||
registry,
|
||||
stack,
|
||||
)
|
||||
finally:
|
||||
await stack.aclose()
|
||||
|
||||
assert registry.tool_names == ["mcp_test_demo", "mcp_test_other"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_connect_mcp_servers_enabled_tools_supports_wrapped_names(
|
||||
fake_mcp_runtime: dict[str, object | None],
|
||||
) -> None:
|
||||
fake_mcp_runtime["session"] = _make_fake_session(["demo", "other"])
|
||||
registry = ToolRegistry()
|
||||
stack = AsyncExitStack()
|
||||
await stack.__aenter__()
|
||||
try:
|
||||
await connect_mcp_servers(
|
||||
{"test": MCPServerConfig(command="fake", enabled_tools=["mcp_test_demo"])},
|
||||
registry,
|
||||
stack,
|
||||
)
|
||||
finally:
|
||||
await stack.aclose()
|
||||
|
||||
assert registry.tool_names == ["mcp_test_demo"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_connect_mcp_servers_enabled_tools_empty_list_registers_none(
|
||||
fake_mcp_runtime: dict[str, object | None],
|
||||
) -> None:
|
||||
fake_mcp_runtime["session"] = _make_fake_session(["demo", "other"])
|
||||
registry = ToolRegistry()
|
||||
stack = AsyncExitStack()
|
||||
await stack.__aenter__()
|
||||
try:
|
||||
await connect_mcp_servers(
|
||||
{"test": MCPServerConfig(command="fake", enabled_tools=[])},
|
||||
registry,
|
||||
stack,
|
||||
)
|
||||
finally:
|
||||
await stack.aclose()
|
||||
|
||||
assert registry.tool_names == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_connect_mcp_servers_enabled_tools_warns_on_unknown_entries(
|
||||
fake_mcp_runtime: dict[str, object | None], monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
fake_mcp_runtime["session"] = _make_fake_session(["demo"])
|
||||
registry = ToolRegistry()
|
||||
warnings: list[str] = []
|
||||
|
||||
def _warning(message: str, *args: object) -> None:
|
||||
warnings.append(message.format(*args))
|
||||
|
||||
monkeypatch.setattr("nanobot.agent.tools.mcp.logger.warning", _warning)
|
||||
|
||||
stack = AsyncExitStack()
|
||||
await stack.__aenter__()
|
||||
try:
|
||||
await connect_mcp_servers(
|
||||
{"test": MCPServerConfig(command="fake", enabled_tools=["unknown"])},
|
||||
registry,
|
||||
stack,
|
||||
)
|
||||
finally:
|
||||
await stack.aclose()
|
||||
|
||||
assert registry.tool_names == []
|
||||
assert warnings
|
||||
assert "enabledTools entries not found: unknown" in warnings[-1]
|
||||
assert "Available raw names: demo" in warnings[-1]
|
||||
assert "Available wrapped names: mcp_test_demo" in warnings[-1]
|
||||
@@ -0,0 +1,10 @@
|
||||
import pytest
|
||||
|
||||
from nanobot.agent.tools.message import MessageTool
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_message_tool_returns_error_when_no_target_context() -> None:
|
||||
tool = MessageTool()
|
||||
result = await tool.execute(content="test")
|
||||
assert result == "Error: No target channel/chat specified"
|
||||
@@ -0,0 +1,132 @@
|
||||
"""Test message tool suppress logic for final replies."""
|
||||
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.agent.loop import AgentLoop
|
||||
from nanobot.agent.tools.message import MessageTool
|
||||
from nanobot.bus.events import InboundMessage, OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.providers.base import LLMResponse, ToolCallRequest
|
||||
|
||||
|
||||
def _make_loop(tmp_path: Path) -> AgentLoop:
|
||||
bus = MessageBus()
|
||||
provider = MagicMock()
|
||||
provider.get_default_model.return_value = "test-model"
|
||||
return AgentLoop(bus=bus, provider=provider, workspace=tmp_path, model="test-model")
|
||||
|
||||
|
||||
class TestMessageToolSuppressLogic:
|
||||
"""Final reply suppressed only when message tool sends to the same target."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_suppress_when_sent_to_same_target(self, tmp_path: Path) -> None:
|
||||
loop = _make_loop(tmp_path)
|
||||
tool_call = ToolCallRequest(
|
||||
id="call1", name="message",
|
||||
arguments={"content": "Hello", "channel": "feishu", "chat_id": "chat123"},
|
||||
)
|
||||
calls = iter([
|
||||
LLMResponse(content="", tool_calls=[tool_call]),
|
||||
LLMResponse(content="Done", tool_calls=[]),
|
||||
])
|
||||
loop.provider.chat_with_retry = AsyncMock(side_effect=lambda *a, **kw: next(calls))
|
||||
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||
|
||||
sent: list[OutboundMessage] = []
|
||||
mt = loop.tools.get("message")
|
||||
if isinstance(mt, MessageTool):
|
||||
mt.set_send_callback(AsyncMock(side_effect=lambda m: sent.append(m)))
|
||||
|
||||
msg = InboundMessage(channel="feishu", sender_id="user1", chat_id="chat123", content="Send")
|
||||
result = await loop._process_message(msg)
|
||||
|
||||
assert len(sent) == 1
|
||||
assert result is None # suppressed
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_not_suppress_when_sent_to_different_target(self, tmp_path: Path) -> None:
|
||||
loop = _make_loop(tmp_path)
|
||||
tool_call = ToolCallRequest(
|
||||
id="call1", name="message",
|
||||
arguments={"content": "Email content", "channel": "email", "chat_id": "user@example.com"},
|
||||
)
|
||||
calls = iter([
|
||||
LLMResponse(content="", tool_calls=[tool_call]),
|
||||
LLMResponse(content="I've sent the email.", tool_calls=[]),
|
||||
])
|
||||
loop.provider.chat_with_retry = AsyncMock(side_effect=lambda *a, **kw: next(calls))
|
||||
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||
|
||||
sent: list[OutboundMessage] = []
|
||||
mt = loop.tools.get("message")
|
||||
if isinstance(mt, MessageTool):
|
||||
mt.set_send_callback(AsyncMock(side_effect=lambda m: sent.append(m)))
|
||||
|
||||
msg = InboundMessage(channel="feishu", sender_id="user1", chat_id="chat123", content="Send email")
|
||||
result = await loop._process_message(msg)
|
||||
|
||||
assert len(sent) == 1
|
||||
assert sent[0].channel == "email"
|
||||
assert result is not None # not suppressed
|
||||
assert result.channel == "feishu"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_not_suppress_when_no_message_tool_used(self, tmp_path: Path) -> None:
|
||||
loop = _make_loop(tmp_path)
|
||||
loop.provider.chat_with_retry = AsyncMock(return_value=LLMResponse(content="Hello!", tool_calls=[]))
|
||||
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||
|
||||
msg = InboundMessage(channel="feishu", sender_id="user1", chat_id="chat123", content="Hi")
|
||||
result = await loop._process_message(msg)
|
||||
|
||||
assert result is not None
|
||||
assert "Hello" in result.content
|
||||
|
||||
async def test_progress_hides_internal_reasoning(self, tmp_path: Path) -> None:
|
||||
loop = _make_loop(tmp_path)
|
||||
tool_call = ToolCallRequest(id="call1", name="read_file", arguments={"path": "foo.txt"})
|
||||
calls = iter([
|
||||
LLMResponse(
|
||||
content="Visible<think>hidden</think>",
|
||||
tool_calls=[tool_call],
|
||||
reasoning_content="secret reasoning",
|
||||
thinking_blocks=[{"signature": "sig", "thought": "secret thought"}],
|
||||
),
|
||||
LLMResponse(content="Done", tool_calls=[]),
|
||||
])
|
||||
loop.provider.chat_with_retry = AsyncMock(side_effect=lambda *a, **kw: next(calls))
|
||||
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||
loop.tools.execute = AsyncMock(return_value="ok")
|
||||
|
||||
progress: list[tuple[str, bool]] = []
|
||||
|
||||
async def on_progress(content: str, *, tool_hint: bool = False) -> None:
|
||||
progress.append((content, tool_hint))
|
||||
|
||||
final_content, _, _ = await loop._run_agent_loop([], on_progress=on_progress)
|
||||
|
||||
assert final_content == "Done"
|
||||
assert progress == [
|
||||
("Visible", False),
|
||||
('read_file("foo.txt")', True),
|
||||
]
|
||||
|
||||
|
||||
class TestMessageToolTurnTracking:
|
||||
|
||||
def test_sent_in_turn_tracks_same_target(self) -> None:
|
||||
tool = MessageTool()
|
||||
tool.set_context("feishu", "chat1")
|
||||
assert not tool._sent_in_turn
|
||||
tool._sent_in_turn = True
|
||||
assert tool._sent_in_turn
|
||||
|
||||
def test_start_turn_resets(self) -> None:
|
||||
tool = MessageTool()
|
||||
tool._sent_in_turn = True
|
||||
tool.start_turn()
|
||||
assert not tool._sent_in_turn
|
||||
@@ -0,0 +1,481 @@
|
||||
from typing import Any
|
||||
|
||||
from nanobot.agent.tools.base import Tool
|
||||
from nanobot.agent.tools.registry import ToolRegistry
|
||||
from nanobot.agent.tools.shell import ExecTool
|
||||
|
||||
|
||||
class SampleTool(Tool):
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "sample"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "sample tool"
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {"type": "string", "minLength": 2},
|
||||
"count": {"type": "integer", "minimum": 1, "maximum": 10},
|
||||
"mode": {"type": "string", "enum": ["fast", "full"]},
|
||||
"meta": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"tag": {"type": "string"},
|
||||
"flags": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
},
|
||||
},
|
||||
"required": ["tag"],
|
||||
},
|
||||
},
|
||||
"required": ["query", "count"],
|
||||
}
|
||||
|
||||
async def execute(self, **kwargs: Any) -> str:
|
||||
return "ok"
|
||||
|
||||
|
||||
def test_validate_params_missing_required() -> None:
|
||||
tool = SampleTool()
|
||||
errors = tool.validate_params({"query": "hi"})
|
||||
assert "missing required count" in "; ".join(errors)
|
||||
|
||||
|
||||
def test_validate_params_type_and_range() -> None:
|
||||
tool = SampleTool()
|
||||
errors = tool.validate_params({"query": "hi", "count": 0})
|
||||
assert any("count must be >= 1" in e for e in errors)
|
||||
|
||||
errors = tool.validate_params({"query": "hi", "count": "2"})
|
||||
assert any("count should be integer" in e for e in errors)
|
||||
|
||||
|
||||
def test_validate_params_enum_and_min_length() -> None:
|
||||
tool = SampleTool()
|
||||
errors = tool.validate_params({"query": "h", "count": 2, "mode": "slow"})
|
||||
assert any("query must be at least 2 chars" in e for e in errors)
|
||||
assert any("mode must be one of" in e for e in errors)
|
||||
|
||||
|
||||
def test_validate_params_nested_object_and_array() -> None:
|
||||
tool = SampleTool()
|
||||
errors = tool.validate_params(
|
||||
{
|
||||
"query": "hi",
|
||||
"count": 2,
|
||||
"meta": {"flags": [1, "ok"]},
|
||||
}
|
||||
)
|
||||
assert any("missing required meta.tag" in e for e in errors)
|
||||
assert any("meta.flags[0] should be string" in e for e in errors)
|
||||
|
||||
|
||||
def test_validate_params_ignores_unknown_fields() -> None:
|
||||
tool = SampleTool()
|
||||
errors = tool.validate_params({"query": "hi", "count": 2, "extra": "x"})
|
||||
assert errors == []
|
||||
|
||||
|
||||
async def test_registry_returns_validation_error() -> None:
|
||||
reg = ToolRegistry()
|
||||
reg.register(SampleTool())
|
||||
result = await reg.execute("sample", {"query": "hi"})
|
||||
assert "Invalid parameters" in result
|
||||
|
||||
|
||||
def test_exec_extract_absolute_paths_keeps_full_windows_path() -> None:
|
||||
cmd = r"type C:\user\workspace\txt"
|
||||
paths = ExecTool._extract_absolute_paths(cmd)
|
||||
assert paths == [r"C:\user\workspace\txt"]
|
||||
|
||||
|
||||
def test_exec_extract_absolute_paths_ignores_relative_posix_segments() -> None:
|
||||
cmd = ".venv/bin/python script.py"
|
||||
paths = ExecTool._extract_absolute_paths(cmd)
|
||||
assert "/bin/python" not in paths
|
||||
|
||||
|
||||
def test_exec_extract_absolute_paths_captures_posix_absolute_paths() -> None:
|
||||
cmd = "cat /tmp/data.txt > /tmp/out.txt"
|
||||
paths = ExecTool._extract_absolute_paths(cmd)
|
||||
assert "/tmp/data.txt" in paths
|
||||
assert "/tmp/out.txt" in paths
|
||||
|
||||
|
||||
def test_exec_extract_absolute_paths_captures_home_paths() -> None:
|
||||
cmd = "cat ~/.nanobot/config.json > ~/out.txt"
|
||||
paths = ExecTool._extract_absolute_paths(cmd)
|
||||
assert "~/.nanobot/config.json" in paths
|
||||
assert "~/out.txt" in paths
|
||||
|
||||
|
||||
def test_exec_extract_absolute_paths_captures_quoted_paths() -> None:
|
||||
cmd = 'cat "/tmp/data.txt" "~/.nanobot/config.json"'
|
||||
paths = ExecTool._extract_absolute_paths(cmd)
|
||||
assert "/tmp/data.txt" in paths
|
||||
assert "~/.nanobot/config.json" in paths
|
||||
|
||||
|
||||
def test_exec_guard_blocks_home_path_outside_workspace(tmp_path) -> None:
|
||||
tool = ExecTool(restrict_to_workspace=True)
|
||||
error = tool._guard_command("cat ~/.nanobot/config.json", str(tmp_path))
|
||||
assert error == "Error: Command blocked by safety guard (path outside working dir)"
|
||||
|
||||
|
||||
def test_exec_guard_blocks_quoted_home_path_outside_workspace(tmp_path) -> None:
|
||||
tool = ExecTool(restrict_to_workspace=True)
|
||||
error = tool._guard_command('cat "~/.nanobot/config.json"', str(tmp_path))
|
||||
assert error == "Error: Command blocked by safety guard (path outside working dir)"
|
||||
|
||||
|
||||
# --- cast_params tests ---
|
||||
|
||||
|
||||
class CastTestTool(Tool):
|
||||
"""Minimal tool for testing cast_params."""
|
||||
|
||||
def __init__(self, schema: dict[str, Any]) -> None:
|
||||
self._schema = schema
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "cast_test"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "test tool for casting"
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return self._schema
|
||||
|
||||
async def execute(self, **kwargs: Any) -> str:
|
||||
return "ok"
|
||||
|
||||
|
||||
def test_cast_params_string_to_int() -> None:
|
||||
tool = CastTestTool(
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {"count": {"type": "integer"}},
|
||||
}
|
||||
)
|
||||
result = tool.cast_params({"count": "42"})
|
||||
assert result["count"] == 42
|
||||
assert isinstance(result["count"], int)
|
||||
|
||||
|
||||
def test_cast_params_string_to_number() -> None:
|
||||
tool = CastTestTool(
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {"rate": {"type": "number"}},
|
||||
}
|
||||
)
|
||||
result = tool.cast_params({"rate": "3.14"})
|
||||
assert result["rate"] == 3.14
|
||||
assert isinstance(result["rate"], float)
|
||||
|
||||
|
||||
def test_cast_params_string_to_bool() -> None:
|
||||
tool = CastTestTool(
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {"enabled": {"type": "boolean"}},
|
||||
}
|
||||
)
|
||||
assert tool.cast_params({"enabled": "true"})["enabled"] is True
|
||||
assert tool.cast_params({"enabled": "false"})["enabled"] is False
|
||||
assert tool.cast_params({"enabled": "1"})["enabled"] is True
|
||||
|
||||
|
||||
def test_cast_params_array_items() -> None:
|
||||
tool = CastTestTool(
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"nums": {"type": "array", "items": {"type": "integer"}},
|
||||
},
|
||||
}
|
||||
)
|
||||
result = tool.cast_params({"nums": ["1", "2", "3"]})
|
||||
assert result["nums"] == [1, 2, 3]
|
||||
|
||||
|
||||
def test_cast_params_nested_object() -> None:
|
||||
tool = CastTestTool(
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"config": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"port": {"type": "integer"},
|
||||
"debug": {"type": "boolean"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
)
|
||||
result = tool.cast_params({"config": {"port": "8080", "debug": "true"}})
|
||||
assert result["config"]["port"] == 8080
|
||||
assert result["config"]["debug"] is True
|
||||
|
||||
|
||||
def test_cast_params_bool_not_cast_to_int() -> None:
|
||||
"""Booleans should not be silently cast to integers."""
|
||||
tool = CastTestTool(
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {"count": {"type": "integer"}},
|
||||
}
|
||||
)
|
||||
result = tool.cast_params({"count": True})
|
||||
assert result["count"] is True
|
||||
errors = tool.validate_params(result)
|
||||
assert any("count should be integer" in e for e in errors)
|
||||
|
||||
|
||||
def test_cast_params_preserves_empty_string() -> None:
|
||||
"""Empty strings should be preserved for string type."""
|
||||
tool = CastTestTool(
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {"name": {"type": "string"}},
|
||||
}
|
||||
)
|
||||
result = tool.cast_params({"name": ""})
|
||||
assert result["name"] == ""
|
||||
|
||||
|
||||
def test_cast_params_bool_string_false() -> None:
|
||||
"""Test that 'false', '0', 'no' strings convert to False."""
|
||||
tool = CastTestTool(
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {"flag": {"type": "boolean"}},
|
||||
}
|
||||
)
|
||||
assert tool.cast_params({"flag": "false"})["flag"] is False
|
||||
assert tool.cast_params({"flag": "False"})["flag"] is False
|
||||
assert tool.cast_params({"flag": "0"})["flag"] is False
|
||||
assert tool.cast_params({"flag": "no"})["flag"] is False
|
||||
assert tool.cast_params({"flag": "NO"})["flag"] is False
|
||||
|
||||
|
||||
def test_cast_params_bool_string_invalid() -> None:
|
||||
"""Invalid boolean strings should not be cast."""
|
||||
tool = CastTestTool(
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {"flag": {"type": "boolean"}},
|
||||
}
|
||||
)
|
||||
# Invalid strings should be preserved (validation will catch them)
|
||||
result = tool.cast_params({"flag": "random"})
|
||||
assert result["flag"] == "random"
|
||||
result = tool.cast_params({"flag": "maybe"})
|
||||
assert result["flag"] == "maybe"
|
||||
|
||||
|
||||
def test_cast_params_invalid_string_to_int() -> None:
|
||||
"""Invalid strings should not be cast to integer."""
|
||||
tool = CastTestTool(
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {"count": {"type": "integer"}},
|
||||
}
|
||||
)
|
||||
result = tool.cast_params({"count": "abc"})
|
||||
assert result["count"] == "abc" # Original value preserved
|
||||
result = tool.cast_params({"count": "12.5.7"})
|
||||
assert result["count"] == "12.5.7"
|
||||
|
||||
|
||||
def test_cast_params_invalid_string_to_number() -> None:
|
||||
"""Invalid strings should not be cast to number."""
|
||||
tool = CastTestTool(
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {"rate": {"type": "number"}},
|
||||
}
|
||||
)
|
||||
result = tool.cast_params({"rate": "not_a_number"})
|
||||
assert result["rate"] == "not_a_number"
|
||||
|
||||
|
||||
def test_validate_params_bool_not_accepted_as_number() -> None:
|
||||
"""Booleans should not pass number validation."""
|
||||
tool = CastTestTool(
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {"rate": {"type": "number"}},
|
||||
}
|
||||
)
|
||||
errors = tool.validate_params({"rate": False})
|
||||
assert any("rate should be number" in e for e in errors)
|
||||
|
||||
|
||||
def test_cast_params_none_values() -> None:
|
||||
"""Test None handling for different types."""
|
||||
tool = CastTestTool(
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {"type": "string"},
|
||||
"count": {"type": "integer"},
|
||||
"items": {"type": "array"},
|
||||
"config": {"type": "object"},
|
||||
},
|
||||
}
|
||||
)
|
||||
result = tool.cast_params(
|
||||
{
|
||||
"name": None,
|
||||
"count": None,
|
||||
"items": None,
|
||||
"config": None,
|
||||
}
|
||||
)
|
||||
# None should be preserved for all types
|
||||
assert result["name"] is None
|
||||
assert result["count"] is None
|
||||
assert result["items"] is None
|
||||
assert result["config"] is None
|
||||
|
||||
|
||||
def test_cast_params_single_value_not_auto_wrapped_to_array() -> None:
|
||||
"""Single values should NOT be automatically wrapped into arrays."""
|
||||
tool = CastTestTool(
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {"items": {"type": "array"}},
|
||||
}
|
||||
)
|
||||
# Non-array values should be preserved (validation will catch them)
|
||||
result = tool.cast_params({"items": 5})
|
||||
assert result["items"] == 5 # Not wrapped to [5]
|
||||
result = tool.cast_params({"items": "text"})
|
||||
assert result["items"] == "text" # Not wrapped to ["text"]
|
||||
|
||||
|
||||
# --- ExecTool enhancement tests ---
|
||||
|
||||
|
||||
async def test_exec_always_returns_exit_code() -> None:
|
||||
"""Exit code should appear in output even on success (exit 0)."""
|
||||
tool = ExecTool()
|
||||
result = await tool.execute(command="echo hello")
|
||||
assert "Exit code: 0" in result
|
||||
assert "hello" in result
|
||||
|
||||
|
||||
async def test_exec_head_tail_truncation() -> None:
|
||||
"""Long output should preserve both head and tail."""
|
||||
tool = ExecTool()
|
||||
# Generate output that exceeds _MAX_OUTPUT (10_000 chars)
|
||||
# Use python to generate output to avoid command line length limits
|
||||
result = await tool.execute(
|
||||
command="python -c \"print('A' * 6000 + '\\n' + 'B' * 6000)\""
|
||||
)
|
||||
assert "chars truncated" in result
|
||||
# Head portion should start with As
|
||||
assert result.startswith("A")
|
||||
# Tail portion should end with the exit code which comes after Bs
|
||||
assert "Exit code:" in result
|
||||
|
||||
|
||||
async def test_exec_timeout_parameter() -> None:
|
||||
"""LLM-supplied timeout should override the constructor default."""
|
||||
tool = ExecTool(timeout=60)
|
||||
# A very short timeout should cause the command to be killed
|
||||
result = await tool.execute(command="sleep 10", timeout=1)
|
||||
assert "timed out" in result
|
||||
assert "1 seconds" in result
|
||||
|
||||
|
||||
async def test_exec_timeout_capped_at_max() -> None:
|
||||
"""Timeout values above _MAX_TIMEOUT should be clamped."""
|
||||
tool = ExecTool()
|
||||
# Should not raise — just clamp to 600
|
||||
result = await tool.execute(command="echo ok", timeout=9999)
|
||||
assert "Exit code: 0" in result
|
||||
|
||||
|
||||
# --- _resolve_type and nullable param tests ---
|
||||
|
||||
|
||||
def test_resolve_type_simple_string() -> None:
|
||||
"""Simple string type passes through unchanged."""
|
||||
assert Tool._resolve_type("string") == "string"
|
||||
|
||||
|
||||
def test_resolve_type_union_with_null() -> None:
|
||||
"""Union type ['string', 'null'] resolves to 'string'."""
|
||||
assert Tool._resolve_type(["string", "null"]) == "string"
|
||||
|
||||
|
||||
def test_resolve_type_only_null() -> None:
|
||||
"""Union type ['null'] resolves to None (no non-null type)."""
|
||||
assert Tool._resolve_type(["null"]) is None
|
||||
|
||||
|
||||
def test_resolve_type_none_input() -> None:
|
||||
"""None input passes through as None."""
|
||||
assert Tool._resolve_type(None) is None
|
||||
|
||||
|
||||
def test_validate_nullable_param_accepts_string() -> None:
|
||||
"""Nullable string param should accept a string value."""
|
||||
tool = CastTestTool(
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {"name": {"type": ["string", "null"]}},
|
||||
}
|
||||
)
|
||||
errors = tool.validate_params({"name": "hello"})
|
||||
assert errors == []
|
||||
|
||||
|
||||
def test_validate_nullable_param_accepts_none() -> None:
|
||||
"""Nullable string param should accept None."""
|
||||
tool = CastTestTool(
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {"name": {"type": ["string", "null"]}},
|
||||
}
|
||||
)
|
||||
errors = tool.validate_params({"name": None})
|
||||
assert errors == []
|
||||
|
||||
|
||||
def test_validate_nullable_flag_accepts_none() -> None:
|
||||
"""OpenAI-normalized nullable params should still accept None locally."""
|
||||
tool = CastTestTool(
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {"name": {"type": "string", "nullable": True}},
|
||||
}
|
||||
)
|
||||
errors = tool.validate_params({"name": None})
|
||||
assert errors == []
|
||||
|
||||
|
||||
def test_cast_nullable_param_no_crash() -> None:
|
||||
"""cast_params should not crash on nullable type (the original bug)."""
|
||||
tool = CastTestTool(
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {"name": {"type": ["string", "null"]}},
|
||||
}
|
||||
)
|
||||
result = tool.cast_params({"name": "hello"})
|
||||
assert result["name"] == "hello"
|
||||
result = tool.cast_params({"name": None})
|
||||
assert result["name"] is None
|
||||
@@ -0,0 +1,113 @@
|
||||
"""Tests for web_fetch SSRF protection and untrusted content marking."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import socket
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.agent.tools.web import WebFetchTool
|
||||
|
||||
|
||||
def _fake_resolve_private(hostname, port, family=0, type_=0):
|
||||
return [(socket.AF_INET, socket.SOCK_STREAM, 0, "", ("169.254.169.254", 0))]
|
||||
|
||||
|
||||
def _fake_resolve_public(hostname, port, family=0, type_=0):
|
||||
return [(socket.AF_INET, socket.SOCK_STREAM, 0, "", ("93.184.216.34", 0))]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_web_fetch_blocks_private_ip():
|
||||
tool = WebFetchTool()
|
||||
with patch("nanobot.security.network.socket.getaddrinfo", _fake_resolve_private):
|
||||
result = await tool.execute(url="http://169.254.169.254/computeMetadata/v1/")
|
||||
data = json.loads(result)
|
||||
assert "error" in data
|
||||
assert "private" in data["error"].lower() or "blocked" in data["error"].lower()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_web_fetch_blocks_localhost():
|
||||
tool = WebFetchTool()
|
||||
def _resolve_localhost(hostname, port, family=0, type_=0):
|
||||
return [(socket.AF_INET, socket.SOCK_STREAM, 0, "", ("127.0.0.1", 0))]
|
||||
with patch("nanobot.security.network.socket.getaddrinfo", _resolve_localhost):
|
||||
result = await tool.execute(url="http://localhost/admin")
|
||||
data = json.loads(result)
|
||||
assert "error" in data
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_web_fetch_result_contains_untrusted_flag():
|
||||
"""When fetch succeeds, result JSON must include untrusted=True and the banner."""
|
||||
tool = WebFetchTool()
|
||||
|
||||
fake_html = "<html><head><title>Test</title></head><body><p>Hello world</p></body></html>"
|
||||
|
||||
import httpx
|
||||
|
||||
class FakeResponse:
|
||||
status_code = 200
|
||||
url = "https://example.com/page"
|
||||
text = fake_html
|
||||
headers = {"content-type": "text/html"}
|
||||
def raise_for_status(self): pass
|
||||
def json(self): return {}
|
||||
|
||||
async def _fake_get(self, url, **kwargs):
|
||||
return FakeResponse()
|
||||
|
||||
with patch("nanobot.security.network.socket.getaddrinfo", _fake_resolve_public), \
|
||||
patch("httpx.AsyncClient.get", _fake_get):
|
||||
result = await tool.execute(url="https://example.com/page")
|
||||
|
||||
data = json.loads(result)
|
||||
assert data.get("untrusted") is True
|
||||
assert "[External content" in data.get("text", "")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_web_fetch_blocks_private_redirect_before_returning_image(monkeypatch):
|
||||
tool = WebFetchTool()
|
||||
|
||||
class FakeStreamResponse:
|
||||
headers = {"content-type": "image/png"}
|
||||
url = "http://127.0.0.1/secret.png"
|
||||
content = b"\x89PNG\r\n\x1a\n"
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
async def aread(self):
|
||||
return self.content
|
||||
|
||||
def raise_for_status(self):
|
||||
return None
|
||||
|
||||
class FakeClient:
|
||||
def __init__(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
def stream(self, method, url, headers=None):
|
||||
return FakeStreamResponse()
|
||||
|
||||
monkeypatch.setattr("nanobot.agent.tools.web.httpx.AsyncClient", FakeClient)
|
||||
|
||||
with patch("nanobot.security.network.socket.getaddrinfo", _fake_resolve_public):
|
||||
result = await tool.execute(url="https://example.com/image.png")
|
||||
|
||||
data = json.loads(result)
|
||||
assert "error" in data
|
||||
assert "redirect blocked" in data["error"].lower()
|
||||
@@ -0,0 +1,162 @@
|
||||
"""Tests for multi-provider web search."""
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from nanobot.agent.tools.web import WebSearchTool
|
||||
from nanobot.config.schema import WebSearchConfig
|
||||
|
||||
|
||||
def _tool(provider: str = "brave", api_key: str = "", base_url: str = "") -> WebSearchTool:
|
||||
return WebSearchTool(config=WebSearchConfig(provider=provider, api_key=api_key, base_url=base_url))
|
||||
|
||||
|
||||
def _response(status: int = 200, json: dict | None = None) -> httpx.Response:
|
||||
"""Build a mock httpx.Response with a dummy request attached."""
|
||||
r = httpx.Response(status, json=json)
|
||||
r._request = httpx.Request("GET", "https://mock")
|
||||
return r
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_brave_search(monkeypatch):
|
||||
async def mock_get(self, url, **kw):
|
||||
assert "brave" in url
|
||||
assert kw["headers"]["X-Subscription-Token"] == "brave-key"
|
||||
return _response(json={
|
||||
"web": {"results": [{"title": "NanoBot", "url": "https://example.com", "description": "AI assistant"}]}
|
||||
})
|
||||
|
||||
monkeypatch.setattr(httpx.AsyncClient, "get", mock_get)
|
||||
tool = _tool(provider="brave", api_key="brave-key")
|
||||
result = await tool.execute(query="nanobot", count=1)
|
||||
assert "NanoBot" in result
|
||||
assert "https://example.com" in result
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tavily_search(monkeypatch):
|
||||
async def mock_post(self, url, **kw):
|
||||
assert "tavily" in url
|
||||
assert kw["headers"]["Authorization"] == "Bearer tavily-key"
|
||||
return _response(json={
|
||||
"results": [{"title": "OpenClaw", "url": "https://openclaw.io", "content": "Framework"}]
|
||||
})
|
||||
|
||||
monkeypatch.setattr(httpx.AsyncClient, "post", mock_post)
|
||||
tool = _tool(provider="tavily", api_key="tavily-key")
|
||||
result = await tool.execute(query="openclaw")
|
||||
assert "OpenClaw" in result
|
||||
assert "https://openclaw.io" in result
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_searxng_search(monkeypatch):
|
||||
async def mock_get(self, url, **kw):
|
||||
assert "searx.example" in url
|
||||
return _response(json={
|
||||
"results": [{"title": "Result", "url": "https://example.com", "content": "SearXNG result"}]
|
||||
})
|
||||
|
||||
monkeypatch.setattr(httpx.AsyncClient, "get", mock_get)
|
||||
tool = _tool(provider="searxng", base_url="https://searx.example")
|
||||
result = await tool.execute(query="test")
|
||||
assert "Result" in result
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_duckduckgo_search(monkeypatch):
|
||||
class MockDDGS:
|
||||
def __init__(self, **kw):
|
||||
pass
|
||||
|
||||
def text(self, query, max_results=5):
|
||||
return [{"title": "DDG Result", "href": "https://ddg.example", "body": "From DuckDuckGo"}]
|
||||
|
||||
monkeypatch.setattr("nanobot.agent.tools.web.DDGS", MockDDGS, raising=False)
|
||||
import nanobot.agent.tools.web as web_mod
|
||||
monkeypatch.setattr(web_mod, "DDGS", MockDDGS, raising=False)
|
||||
|
||||
from ddgs import DDGS
|
||||
monkeypatch.setattr("ddgs.DDGS", MockDDGS)
|
||||
|
||||
tool = _tool(provider="duckduckgo")
|
||||
result = await tool.execute(query="hello")
|
||||
assert "DDG Result" in result
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_brave_fallback_to_duckduckgo_when_no_key(monkeypatch):
|
||||
class MockDDGS:
|
||||
def __init__(self, **kw):
|
||||
pass
|
||||
|
||||
def text(self, query, max_results=5):
|
||||
return [{"title": "Fallback", "href": "https://ddg.example", "body": "DuckDuckGo fallback"}]
|
||||
|
||||
monkeypatch.setattr("ddgs.DDGS", MockDDGS)
|
||||
monkeypatch.delenv("BRAVE_API_KEY", raising=False)
|
||||
|
||||
tool = _tool(provider="brave", api_key="")
|
||||
result = await tool.execute(query="test")
|
||||
assert "Fallback" in result
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_jina_search(monkeypatch):
|
||||
async def mock_get(self, url, **kw):
|
||||
assert "s.jina.ai" in str(url)
|
||||
assert kw["headers"]["Authorization"] == "Bearer jina-key"
|
||||
return _response(json={
|
||||
"data": [{"title": "Jina Result", "url": "https://jina.ai", "content": "AI search"}]
|
||||
})
|
||||
|
||||
monkeypatch.setattr(httpx.AsyncClient, "get", mock_get)
|
||||
tool = _tool(provider="jina", api_key="jina-key")
|
||||
result = await tool.execute(query="test")
|
||||
assert "Jina Result" in result
|
||||
assert "https://jina.ai" in result
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unknown_provider():
|
||||
tool = _tool(provider="unknown")
|
||||
result = await tool.execute(query="test")
|
||||
assert "unknown" in result
|
||||
assert "Error" in result
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_default_provider_is_brave(monkeypatch):
|
||||
async def mock_get(self, url, **kw):
|
||||
assert "brave" in url
|
||||
return _response(json={"web": {"results": []}})
|
||||
|
||||
monkeypatch.setattr(httpx.AsyncClient, "get", mock_get)
|
||||
tool = _tool(provider="", api_key="test-key")
|
||||
result = await tool.execute(query="test")
|
||||
assert "No results" in result
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_searxng_no_base_url_falls_back(monkeypatch):
|
||||
class MockDDGS:
|
||||
def __init__(self, **kw):
|
||||
pass
|
||||
|
||||
def text(self, query, max_results=5):
|
||||
return [{"title": "Fallback", "href": "https://ddg.example", "body": "fallback"}]
|
||||
|
||||
monkeypatch.setattr("ddgs.DDGS", MockDDGS)
|
||||
monkeypatch.delenv("SEARXNG_BASE_URL", raising=False)
|
||||
|
||||
tool = _tool(provider="searxng", base_url="")
|
||||
result = await tool.execute(query="test")
|
||||
assert "Fallback" in result
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_searxng_invalid_url():
|
||||
tool = _tool(provider="searxng", base_url="not-a-url")
|
||||
result = await tool.execute(query="test")
|
||||
assert "Error" in result
|
||||
Reference in New Issue
Block a user