Update 2026-05-13 16:43:53

This commit is contained in:
yi
2026-05-13 16:43:53 +08:00
parent 6af5c584f4
commit afd7c5fe85
490 changed files with 850 additions and 922 deletions
@@ -0,0 +1,619 @@
"""Test session management with cache-friendly message handling."""
import asyncio
from unittest.mock import AsyncMock, MagicMock
import pytest
from pathlib import Path
from nanobot.session.manager import Session, SessionManager
# Test constants
MEMORY_WINDOW = 50
KEEP_COUNT = MEMORY_WINDOW // 2 # 25
def create_session_with_messages(key: str, count: int, role: str = "user") -> Session:
"""Create a session and add the specified number of messages.
Args:
key: Session identifier
count: Number of messages to add
role: Message role (default: "user")
Returns:
Session with the specified messages
"""
session = Session(key=key)
for i in range(count):
session.add_message(role, f"msg{i}")
return session
def assert_messages_content(messages: list, start_index: int, end_index: int) -> None:
"""Assert that messages contain expected content from start to end index.
Args:
messages: List of message dictionaries
start_index: Expected first message index
end_index: Expected last message index
"""
assert len(messages) > 0
assert messages[0]["content"] == f"msg{start_index}"
assert messages[-1]["content"] == f"msg{end_index}"
def get_old_messages(session: Session, last_consolidated: int, keep_count: int) -> list:
"""Extract messages that would be consolidated using the standard slice logic.
Args:
session: The session containing messages
last_consolidated: Index of last consolidated message
keep_count: Number of recent messages to keep
Returns:
List of messages that would be consolidated
"""
return session.messages[last_consolidated:-keep_count]
class TestSessionLastConsolidated:
"""Test last_consolidated tracking to avoid duplicate processing."""
def test_initial_last_consolidated_zero(self) -> None:
"""Test that new session starts with last_consolidated=0."""
session = Session(key="test:initial")
assert session.last_consolidated == 0
def test_last_consolidated_persistence(self, tmp_path) -> None:
"""Test that last_consolidated persists across save/load."""
manager = SessionManager(Path(tmp_path))
session1 = create_session_with_messages("test:persist", 20)
session1.last_consolidated = 15
manager.save(session1)
session2 = manager.get_or_create("test:persist")
assert session2.last_consolidated == 15
assert len(session2.messages) == 20
def test_clear_resets_last_consolidated(self) -> None:
"""Test that clear() resets last_consolidated to 0."""
session = create_session_with_messages("test:clear", 10)
session.last_consolidated = 5
session.clear()
assert len(session.messages) == 0
assert session.last_consolidated == 0
class TestSessionImmutableHistory:
"""Test Session message immutability for cache efficiency."""
def test_initial_state(self) -> None:
"""Test that new session has empty messages list."""
session = Session(key="test:initial")
assert len(session.messages) == 0
def test_add_messages_appends_only(self) -> None:
"""Test that adding messages only appends, never modifies."""
session = Session(key="test:preserve")
session.add_message("user", "msg1")
session.add_message("assistant", "resp1")
session.add_message("user", "msg2")
assert len(session.messages) == 3
assert session.messages[0]["content"] == "msg1"
def test_get_history_returns_most_recent(self) -> None:
"""Test get_history returns the most recent messages."""
session = Session(key="test:history")
for i in range(10):
session.add_message("user", f"msg{i}")
session.add_message("assistant", f"resp{i}")
history = session.get_history(max_messages=6)
assert len(history) == 6
assert history[0]["content"] == "msg7"
assert history[-1]["content"] == "resp9"
def test_get_history_with_all_messages(self) -> None:
"""Test get_history with max_messages larger than actual."""
session = create_session_with_messages("test:all", 5)
history = session.get_history(max_messages=100)
assert len(history) == 5
assert history[0]["content"] == "msg0"
def test_get_history_stable_for_same_session(self) -> None:
"""Test that get_history returns same content for same max_messages."""
session = create_session_with_messages("test:stable", 20)
history1 = session.get_history(max_messages=10)
history2 = session.get_history(max_messages=10)
assert history1 == history2
def test_messages_list_never_modified(self) -> None:
"""Test that messages list is never modified after creation."""
session = create_session_with_messages("test:immutable", 5)
original_len = len(session.messages)
session.get_history(max_messages=2)
assert len(session.messages) == original_len
for _ in range(10):
session.get_history(max_messages=3)
assert len(session.messages) == original_len
class TestSessionPersistence:
"""Test Session persistence and reload."""
@pytest.fixture
def temp_manager(self, tmp_path):
return SessionManager(Path(tmp_path))
def test_persistence_roundtrip(self, temp_manager):
"""Test that messages persist across save/load."""
session1 = create_session_with_messages("test:persistence", 20)
temp_manager.save(session1)
session2 = temp_manager.get_or_create("test:persistence")
assert len(session2.messages) == 20
assert session2.messages[0]["content"] == "msg0"
assert session2.messages[-1]["content"] == "msg19"
def test_get_history_after_reload(self, temp_manager):
"""Test that get_history works correctly after reload."""
session1 = create_session_with_messages("test:reload", 30)
temp_manager.save(session1)
session2 = temp_manager.get_or_create("test:reload")
history = session2.get_history(max_messages=10)
assert len(history) == 10
assert history[0]["content"] == "msg20"
assert history[-1]["content"] == "msg29"
def test_clear_resets_session(self, temp_manager):
"""Test that clear() properly resets session."""
session = create_session_with_messages("test:clear", 10)
assert len(session.messages) == 10
session.clear()
assert len(session.messages) == 0
class TestConsolidationTriggerConditions:
"""Test consolidation trigger conditions and logic."""
def test_consolidation_needed_when_messages_exceed_window(self):
"""Test consolidation logic: should trigger when messages exceed the window."""
session = create_session_with_messages("test:trigger", 60)
total_messages = len(session.messages)
messages_to_process = total_messages - session.last_consolidated
assert total_messages > MEMORY_WINDOW
assert messages_to_process > 0
expected_consolidate_count = total_messages - KEEP_COUNT
assert expected_consolidate_count == 35
def test_consolidation_skipped_when_within_keep_count(self):
"""Test consolidation skipped when total messages <= keep_count."""
session = create_session_with_messages("test:skip", 20)
total_messages = len(session.messages)
assert total_messages <= KEEP_COUNT
old_messages = get_old_messages(session, session.last_consolidated, KEEP_COUNT)
assert len(old_messages) == 0
def test_consolidation_skipped_when_no_new_messages(self):
"""Test consolidation skipped when messages_to_process <= 0."""
session = create_session_with_messages("test:already_consolidated", 40)
session.last_consolidated = len(session.messages) - KEEP_COUNT # 15
# Add a few more messages
for i in range(40, 42):
session.add_message("user", f"msg{i}")
total_messages = len(session.messages)
messages_to_process = total_messages - session.last_consolidated
assert messages_to_process > 0
# Simulate last_consolidated catching up
session.last_consolidated = total_messages - KEEP_COUNT
old_messages = get_old_messages(session, session.last_consolidated, KEEP_COUNT)
assert len(old_messages) == 0
class TestLastConsolidatedEdgeCases:
"""Test last_consolidated edge cases and data corruption scenarios."""
def test_last_consolidated_exceeds_message_count(self):
"""Test behavior when last_consolidated > len(messages) (data corruption)."""
session = create_session_with_messages("test:corruption", 10)
session.last_consolidated = 20
total_messages = len(session.messages)
messages_to_process = total_messages - session.last_consolidated
assert messages_to_process <= 0
old_messages = get_old_messages(session, session.last_consolidated, 5)
assert len(old_messages) == 0
def test_last_consolidated_negative_value(self):
"""Test behavior with negative last_consolidated (invalid state)."""
session = create_session_with_messages("test:negative", 10)
session.last_consolidated = -5
keep_count = 3
old_messages = get_old_messages(session, session.last_consolidated, keep_count)
# messages[-5:-3] with 10 messages gives indices 5,6
assert len(old_messages) == 2
assert old_messages[0]["content"] == "msg5"
assert old_messages[-1]["content"] == "msg6"
def test_messages_added_after_consolidation(self):
"""Test correct behavior when new messages arrive after consolidation."""
session = create_session_with_messages("test:new_messages", 40)
session.last_consolidated = len(session.messages) - KEEP_COUNT # 15
# Add new messages after consolidation
for i in range(40, 50):
session.add_message("user", f"msg{i}")
total_messages = len(session.messages)
old_messages = get_old_messages(session, session.last_consolidated, KEEP_COUNT)
expected_consolidate_count = total_messages - KEEP_COUNT - session.last_consolidated
assert len(old_messages) == expected_consolidate_count
assert_messages_content(old_messages, 15, 24)
def test_slice_behavior_when_indices_overlap(self):
"""Test slice behavior when last_consolidated >= total - keep_count."""
session = create_session_with_messages("test:overlap", 30)
session.last_consolidated = 12
old_messages = get_old_messages(session, session.last_consolidated, 20)
assert len(old_messages) == 0
class TestArchiveAllMode:
"""Test archive_all mode (used by /new command)."""
def test_archive_all_consolidates_everything(self):
"""Test archive_all=True consolidates all messages."""
session = create_session_with_messages("test:archive_all", 50)
archive_all = True
if archive_all:
old_messages = session.messages
assert len(old_messages) == 50
assert session.last_consolidated == 0
def test_archive_all_resets_last_consolidated(self):
"""Test that archive_all mode resets last_consolidated to 0."""
session = create_session_with_messages("test:reset", 40)
session.last_consolidated = 15
archive_all = True
if archive_all:
session.last_consolidated = 0
assert session.last_consolidated == 0
assert len(session.messages) == 40
def test_archive_all_vs_normal_consolidation(self):
"""Test difference between archive_all and normal consolidation."""
# Normal consolidation
session1 = create_session_with_messages("test:normal", 60)
session1.last_consolidated = len(session1.messages) - KEEP_COUNT
# archive_all mode
session2 = create_session_with_messages("test:all", 60)
session2.last_consolidated = 0
assert session1.last_consolidated == 35
assert len(session1.messages) == 60
assert session2.last_consolidated == 0
assert len(session2.messages) == 60
class TestCacheImmutability:
"""Test that consolidation doesn't modify session.messages (cache safety)."""
def test_consolidation_does_not_modify_messages_list(self):
"""Test that consolidation leaves messages list unchanged."""
session = create_session_with_messages("test:immutable", 50)
original_messages = session.messages.copy()
original_len = len(session.messages)
session.last_consolidated = original_len - KEEP_COUNT
assert len(session.messages) == original_len
assert session.messages == original_messages
def test_get_history_does_not_modify_messages(self):
"""Test that get_history doesn't modify messages list."""
session = create_session_with_messages("test:history_immutable", 40)
original_messages = [m.copy() for m in session.messages]
for _ in range(5):
history = session.get_history(max_messages=10)
assert len(history) == 10
assert len(session.messages) == 40
for i, msg in enumerate(session.messages):
assert msg["content"] == original_messages[i]["content"]
def test_consolidation_only_updates_last_consolidated(self):
"""Test that consolidation only updates last_consolidated field."""
session = create_session_with_messages("test:field_only", 60)
original_messages = session.messages.copy()
original_key = session.key
original_metadata = session.metadata.copy()
session.last_consolidated = len(session.messages) - KEEP_COUNT
assert session.messages == original_messages
assert session.key == original_key
assert session.metadata == original_metadata
assert session.last_consolidated == 35
class TestSliceLogic:
"""Test the slice logic: messages[last_consolidated:-keep_count]."""
def test_slice_extracts_correct_range(self):
"""Test that slice extracts the correct message range."""
session = create_session_with_messages("test:slice", 60)
old_messages = get_old_messages(session, 0, KEEP_COUNT)
assert len(old_messages) == 35
assert_messages_content(old_messages, 0, 34)
remaining = session.messages[-KEEP_COUNT:]
assert len(remaining) == 25
assert_messages_content(remaining, 35, 59)
def test_slice_with_partial_consolidation(self):
"""Test slice when some messages already consolidated."""
session = create_session_with_messages("test:partial", 70)
last_consolidated = 30
old_messages = get_old_messages(session, last_consolidated, KEEP_COUNT)
assert len(old_messages) == 15
assert_messages_content(old_messages, 30, 44)
def test_slice_with_various_keep_counts(self):
"""Test slice behavior with different keep_count values."""
session = create_session_with_messages("test:keep_counts", 50)
test_cases = [(10, 40), (20, 30), (30, 20), (40, 10)]
for keep_count, expected_count in test_cases:
old_messages = session.messages[0:-keep_count]
assert len(old_messages) == expected_count
def test_slice_when_keep_count_exceeds_messages(self):
"""Test slice when keep_count > len(messages)."""
session = create_session_with_messages("test:exceed", 10)
old_messages = session.messages[0:-20]
assert len(old_messages) == 0
class TestEmptyAndBoundarySessions:
"""Test empty sessions and boundary conditions."""
def test_empty_session_consolidation(self):
"""Test consolidation behavior with empty session."""
session = Session(key="test:empty")
assert len(session.messages) == 0
assert session.last_consolidated == 0
messages_to_process = len(session.messages) - session.last_consolidated
assert messages_to_process == 0
old_messages = get_old_messages(session, session.last_consolidated, KEEP_COUNT)
assert len(old_messages) == 0
def test_single_message_session(self):
"""Test consolidation with single message."""
session = Session(key="test:single")
session.add_message("user", "only message")
assert len(session.messages) == 1
old_messages = get_old_messages(session, session.last_consolidated, KEEP_COUNT)
assert len(old_messages) == 0
def test_exactly_keep_count_messages(self):
"""Test session with exactly keep_count messages."""
session = create_session_with_messages("test:exact", KEEP_COUNT)
assert len(session.messages) == KEEP_COUNT
old_messages = get_old_messages(session, session.last_consolidated, KEEP_COUNT)
assert len(old_messages) == 0
def test_just_over_keep_count(self):
"""Test session with one message over keep_count."""
session = create_session_with_messages("test:over", KEEP_COUNT + 1)
assert len(session.messages) == 26
old_messages = get_old_messages(session, session.last_consolidated, KEEP_COUNT)
assert len(old_messages) == 1
assert old_messages[0]["content"] == "msg0"
def test_very_large_session(self):
"""Test consolidation with very large message count."""
session = create_session_with_messages("test:large", 1000)
assert len(session.messages) == 1000
old_messages = get_old_messages(session, session.last_consolidated, KEEP_COUNT)
assert len(old_messages) == 975
assert_messages_content(old_messages, 0, 974)
remaining = session.messages[-KEEP_COUNT:]
assert len(remaining) == 25
assert_messages_content(remaining, 975, 999)
def test_session_with_gaps_in_consolidation(self):
"""Test session with potential gaps in consolidation history."""
session = create_session_with_messages("test:gaps", 50)
session.last_consolidated = 10
# Add more messages
for i in range(50, 60):
session.add_message("user", f"msg{i}")
old_messages = get_old_messages(session, session.last_consolidated, KEEP_COUNT)
expected_count = 60 - KEEP_COUNT - 10
assert len(old_messages) == expected_count
assert_messages_content(old_messages, 10, 34)
class TestNewCommandArchival:
"""Test /new archival behavior with the simplified consolidation flow."""
@staticmethod
def _make_loop(tmp_path: Path):
from nanobot.agent.loop import AgentLoop
from nanobot.bus.queue import MessageBus
from nanobot.providers.base import LLMResponse
bus = MessageBus()
provider = MagicMock()
provider.get_default_model.return_value = "test-model"
provider.estimate_prompt_tokens.return_value = (10_000, "test")
loop = AgentLoop(
bus=bus,
provider=provider,
workspace=tmp_path,
model="test-model",
context_window_tokens=1,
)
loop.provider.chat_with_retry = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
loop.tools.get_definitions = MagicMock(return_value=[])
return loop
@pytest.mark.asyncio
async def test_new_clears_session_immediately_even_if_archive_fails(self, tmp_path: Path) -> None:
"""/new clears session immediately; archive_messages retries until raw dump."""
from nanobot.bus.events import InboundMessage
loop = self._make_loop(tmp_path)
session = loop.sessions.get_or_create("cli:test")
for i in range(5):
session.add_message("user", f"msg{i}")
session.add_message("assistant", f"resp{i}")
loop.sessions.save(session)
call_count = 0
async def _failing_consolidate(_messages) -> bool:
nonlocal call_count
call_count += 1
return False
loop.memory_consolidator.consolidate_messages = _failing_consolidate # type: ignore[method-assign]
new_msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new")
response = await loop._process_message(new_msg)
assert response is not None
assert "new session started" in response.content.lower()
session_after = loop.sessions.get_or_create("cli:test")
assert len(session_after.messages) == 0
await loop.close_mcp()
assert call_count == 3 # retried up to raw-archive threshold
@pytest.mark.asyncio
async def test_new_archives_only_unconsolidated_messages(self, tmp_path: Path) -> None:
from nanobot.bus.events import InboundMessage
loop = self._make_loop(tmp_path)
session = loop.sessions.get_or_create("cli:test")
for i in range(15):
session.add_message("user", f"msg{i}")
session.add_message("assistant", f"resp{i}")
session.last_consolidated = len(session.messages) - 3
loop.sessions.save(session)
archived_count = -1
async def _fake_consolidate(messages) -> bool:
nonlocal archived_count
archived_count = len(messages)
return True
loop.memory_consolidator.consolidate_messages = _fake_consolidate # type: ignore[method-assign]
new_msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new")
response = await loop._process_message(new_msg)
assert response is not None
assert "new session started" in response.content.lower()
await loop.close_mcp()
assert archived_count == 3
@pytest.mark.asyncio
async def test_new_clears_session_and_responds(self, tmp_path: Path) -> None:
from nanobot.bus.events import InboundMessage
loop = self._make_loop(tmp_path)
session = loop.sessions.get_or_create("cli:test")
for i in range(3):
session.add_message("user", f"msg{i}")
session.add_message("assistant", f"resp{i}")
loop.sessions.save(session)
async def _ok_consolidate(_messages) -> bool:
return True
loop.memory_consolidator.consolidate_messages = _ok_consolidate # type: ignore[method-assign]
new_msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new")
response = await loop._process_message(new_msg)
assert response is not None
assert "new session started" in response.content.lower()
assert loop.sessions.get_or_create("cli:test").messages == []
@pytest.mark.asyncio
async def test_close_mcp_drains_background_tasks(self, tmp_path: Path) -> None:
"""close_mcp waits for background tasks to complete."""
from nanobot.bus.events import InboundMessage
loop = self._make_loop(tmp_path)
session = loop.sessions.get_or_create("cli:test")
for i in range(3):
session.add_message("user", f"msg{i}")
session.add_message("assistant", f"resp{i}")
loop.sessions.save(session)
archived = asyncio.Event()
async def _slow_consolidate(_messages) -> bool:
await asyncio.sleep(0.1)
archived.set()
return True
loop.memory_consolidator.consolidate_messages = _slow_consolidate # type: ignore[method-assign]
new_msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new")
await loop._process_message(new_msg)
assert not archived.is_set()
await loop.close_mcp()
assert archived.is_set()
@@ -0,0 +1,73 @@
"""Tests for cache-friendly prompt construction."""
from __future__ import annotations
from datetime import datetime as real_datetime
from importlib.resources import files as pkg_files
from pathlib import Path
import datetime as datetime_module
from nanobot.agent.context import ContextBuilder
class _FakeDatetime(real_datetime):
current = real_datetime(2026, 2, 24, 13, 59)
@classmethod
def now(cls, tz=None): # type: ignore[override]
return cls.current
def _make_workspace(tmp_path: Path) -> Path:
workspace = tmp_path / "workspace"
workspace.mkdir(parents=True)
return workspace
def test_bootstrap_files_are_backed_by_templates() -> None:
template_dir = pkg_files("nanobot") / "templates"
for filename in ContextBuilder.BOOTSTRAP_FILES:
assert (template_dir / filename).is_file(), f"missing bootstrap template: {filename}"
def test_system_prompt_stays_stable_when_clock_changes(tmp_path, monkeypatch) -> None:
"""System prompt should not change just because wall clock minute changes."""
monkeypatch.setattr(datetime_module, "datetime", _FakeDatetime)
workspace = _make_workspace(tmp_path)
builder = ContextBuilder(workspace)
_FakeDatetime.current = real_datetime(2026, 2, 24, 13, 59)
prompt1 = builder.build_system_prompt()
_FakeDatetime.current = real_datetime(2026, 2, 24, 14, 0)
prompt2 = builder.build_system_prompt()
assert prompt1 == prompt2
def test_runtime_context_is_separate_untrusted_user_message(tmp_path) -> None:
"""Runtime metadata should be merged with the user message."""
workspace = _make_workspace(tmp_path)
builder = ContextBuilder(workspace)
messages = builder.build_messages(
history=[],
current_message="Return exactly: OK",
channel="cli",
chat_id="direct",
)
assert messages[0]["role"] == "system"
assert "## Current Session" not in messages[0]["content"]
# Runtime context is now merged with user message into a single message
assert messages[-1]["role"] == "user"
user_content = messages[-1]["content"]
assert isinstance(user_content, str)
assert ContextBuilder._RUNTIME_CONTEXT_TAG in user_content
assert "Current Time:" in user_content
assert "Channel: cli" in user_content
assert "Chat ID: direct" in user_content
assert "Return exactly: OK" in user_content
+63
View File
@@ -0,0 +1,63 @@
import pytest
from nanobot.utils.evaluator import evaluate_response
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
class DummyProvider(LLMProvider):
def __init__(self, responses: list[LLMResponse]):
super().__init__()
self._responses = list(responses)
async def chat(self, *args, **kwargs) -> LLMResponse:
if self._responses:
return self._responses.pop(0)
return LLMResponse(content="", tool_calls=[])
def get_default_model(self) -> str:
return "test-model"
def _eval_tool_call(should_notify: bool, reason: str = "") -> LLMResponse:
return LLMResponse(
content="",
tool_calls=[
ToolCallRequest(
id="eval_1",
name="evaluate_notification",
arguments={"should_notify": should_notify, "reason": reason},
)
],
)
@pytest.mark.asyncio
async def test_should_notify_true() -> None:
provider = DummyProvider([_eval_tool_call(True, "user asked to be reminded")])
result = await evaluate_response("Task completed with results", "check emails", provider, "m")
assert result is True
@pytest.mark.asyncio
async def test_should_notify_false() -> None:
provider = DummyProvider([_eval_tool_call(False, "routine check, nothing new")])
result = await evaluate_response("All clear, no updates", "check status", provider, "m")
assert result is False
@pytest.mark.asyncio
async def test_fallback_on_error() -> None:
class FailingProvider(DummyProvider):
async def chat(self, *args, **kwargs) -> LLMResponse:
raise RuntimeError("provider down")
provider = FailingProvider([])
result = await evaluate_response("some response", "some task", provider, "m")
assert result is True
@pytest.mark.asyncio
async def test_no_tool_call_fallback() -> None:
provider = DummyProvider([LLMResponse(content="I think you should notify", tool_calls=[])])
result = await evaluate_response("some response", "some task", provider, "m")
assert result is True
@@ -0,0 +1,200 @@
"""Tests for Gemini thought_signature round-trip through extra_content.
The Gemini OpenAI-compatibility API returns tool calls with an extra_content
field: ``{"google": {"thought_signature": "..."}}``. This MUST survive the
parse → serialize round-trip so the model can continue reasoning.
"""
from types import SimpleNamespace
from unittest.mock import patch
from nanobot.providers.base import ToolCallRequest
from nanobot.providers.openai_compat_provider import OpenAICompatProvider
GEMINI_EXTRA = {"google": {"thought_signature": "sig-abc-123"}}
# ── ToolCallRequest serialization ──────────────────────────────────────
def test_tool_call_request_serializes_extra_content() -> None:
tc = ToolCallRequest(
id="abc123xyz",
name="read_file",
arguments={"path": "todo.md"},
extra_content=GEMINI_EXTRA,
)
payload = tc.to_openai_tool_call()
assert payload["extra_content"] == GEMINI_EXTRA
assert payload["function"]["arguments"] == '{"path": "todo.md"}'
def test_tool_call_request_serializes_provider_fields() -> None:
tc = ToolCallRequest(
id="abc123xyz",
name="read_file",
arguments={"path": "todo.md"},
provider_specific_fields={"custom_key": "custom_val"},
function_provider_specific_fields={"inner": "value"},
)
payload = tc.to_openai_tool_call()
assert payload["provider_specific_fields"] == {"custom_key": "custom_val"}
assert payload["function"]["provider_specific_fields"] == {"inner": "value"}
def test_tool_call_request_omits_absent_extras() -> None:
tc = ToolCallRequest(id="x", name="fn", arguments={})
payload = tc.to_openai_tool_call()
assert "extra_content" not in payload
assert "provider_specific_fields" not in payload
assert "provider_specific_fields" not in payload["function"]
# ── _parse: SDK-object branch ──────────────────────────────────────────
def _make_sdk_response_with_extra_content():
"""Simulate a Gemini response via the OpenAI SDK (SimpleNamespace)."""
fn = SimpleNamespace(name="get_weather", arguments='{"city":"Tokyo"}')
tc = SimpleNamespace(
id="call_1",
index=0,
type="function",
function=fn,
extra_content=GEMINI_EXTRA,
)
msg = SimpleNamespace(
content=None,
tool_calls=[tc],
reasoning_content=None,
)
choice = SimpleNamespace(message=msg, finish_reason="tool_calls")
usage = SimpleNamespace(prompt_tokens=10, completion_tokens=5, total_tokens=15)
return SimpleNamespace(choices=[choice], usage=usage)
def test_parse_sdk_object_preserves_extra_content() -> None:
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
provider = OpenAICompatProvider()
result = provider._parse(_make_sdk_response_with_extra_content())
assert len(result.tool_calls) == 1
tc = result.tool_calls[0]
assert tc.name == "get_weather"
assert tc.extra_content == GEMINI_EXTRA
payload = tc.to_openai_tool_call()
assert payload["extra_content"] == GEMINI_EXTRA
# ── _parse: dict/mapping branch ───────────────────────────────────────
def test_parse_dict_preserves_extra_content() -> None:
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
provider = OpenAICompatProvider()
response_dict = {
"choices": [{
"message": {
"content": None,
"tool_calls": [{
"id": "call_1",
"type": "function",
"function": {"name": "get_weather", "arguments": '{"city":"Tokyo"}'},
"extra_content": GEMINI_EXTRA,
}],
},
"finish_reason": "tool_calls",
}],
"usage": {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15},
}
result = provider._parse(response_dict)
assert len(result.tool_calls) == 1
tc = result.tool_calls[0]
assert tc.name == "get_weather"
assert tc.extra_content == GEMINI_EXTRA
payload = tc.to_openai_tool_call()
assert payload["extra_content"] == GEMINI_EXTRA
# ── _parse_chunks: streaming round-trip ───────────────────────────────
def test_parse_chunks_sdk_preserves_extra_content() -> None:
fn_delta = SimpleNamespace(name="get_weather", arguments='{"city":"Tokyo"}')
tc_delta = SimpleNamespace(
id="call_1",
index=0,
function=fn_delta,
extra_content=GEMINI_EXTRA,
)
delta = SimpleNamespace(content=None, tool_calls=[tc_delta])
choice = SimpleNamespace(finish_reason="tool_calls", delta=delta)
chunk = SimpleNamespace(choices=[choice], usage=None)
result = OpenAICompatProvider._parse_chunks([chunk])
assert len(result.tool_calls) == 1
tc = result.tool_calls[0]
assert tc.extra_content == GEMINI_EXTRA
payload = tc.to_openai_tool_call()
assert payload["extra_content"] == GEMINI_EXTRA
def test_parse_chunks_dict_preserves_extra_content() -> None:
chunk = {
"choices": [{
"finish_reason": "tool_calls",
"delta": {
"content": None,
"tool_calls": [{
"index": 0,
"id": "call_1",
"function": {"name": "get_weather", "arguments": '{"city":"Tokyo"}'},
"extra_content": GEMINI_EXTRA,
}],
},
}],
}
result = OpenAICompatProvider._parse_chunks([chunk])
assert len(result.tool_calls) == 1
tc = result.tool_calls[0]
assert tc.extra_content == GEMINI_EXTRA
payload = tc.to_openai_tool_call()
assert payload["extra_content"] == GEMINI_EXTRA
# ── Model switching: stale extras shouldn't break other providers ─────
def test_stale_extra_content_in_tool_calls_survives_sanitize() -> None:
"""When switching from Gemini to OpenAI, extra_content inside tool_calls
should survive message sanitization (it lives inside the tool_call dict,
not at message level, so it bypasses _ALLOWED_MSG_KEYS filtering)."""
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
provider = OpenAICompatProvider()
messages = [{
"role": "assistant",
"content": None,
"tool_calls": [{
"id": "call_1",
"type": "function",
"function": {"name": "fn", "arguments": "{}"},
"extra_content": GEMINI_EXTRA,
}],
}]
sanitized = provider._sanitize_messages(messages)
assert sanitized[0]["tool_calls"][0]["extra_content"] == GEMINI_EXTRA
@@ -0,0 +1,289 @@
import asyncio
import pytest
from nanobot.heartbeat.service import HeartbeatService
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
class DummyProvider(LLMProvider):
def __init__(self, responses: list[LLMResponse]):
super().__init__()
self._responses = list(responses)
self.calls = 0
async def chat(self, *args, **kwargs) -> LLMResponse:
self.calls += 1
if self._responses:
return self._responses.pop(0)
return LLMResponse(content="", tool_calls=[])
def get_default_model(self) -> str:
return "test-model"
@pytest.mark.asyncio
async def test_start_is_idempotent(tmp_path) -> None:
provider = DummyProvider([])
service = HeartbeatService(
workspace=tmp_path,
provider=provider,
model="openai/gpt-4o-mini",
interval_s=9999,
enabled=True,
)
await service.start()
first_task = service._task
await service.start()
assert service._task is first_task
service.stop()
await asyncio.sleep(0)
@pytest.mark.asyncio
async def test_decide_returns_skip_when_no_tool_call(tmp_path) -> None:
provider = DummyProvider([LLMResponse(content="no tool call", tool_calls=[])])
service = HeartbeatService(
workspace=tmp_path,
provider=provider,
model="openai/gpt-4o-mini",
)
action, tasks = await service._decide("heartbeat content")
assert action == "skip"
assert tasks == ""
@pytest.mark.asyncio
async def test_trigger_now_executes_when_decision_is_run(tmp_path) -> None:
(tmp_path / "HEARTBEAT.md").write_text("- [ ] do thing", encoding="utf-8")
provider = DummyProvider([
LLMResponse(
content="",
tool_calls=[
ToolCallRequest(
id="hb_1",
name="heartbeat",
arguments={"action": "run", "tasks": "check open tasks"},
)
],
)
])
called_with: list[str] = []
async def _on_execute(tasks: str) -> str:
called_with.append(tasks)
return "done"
service = HeartbeatService(
workspace=tmp_path,
provider=provider,
model="openai/gpt-4o-mini",
on_execute=_on_execute,
)
result = await service.trigger_now()
assert result == "done"
assert called_with == ["check open tasks"]
@pytest.mark.asyncio
async def test_trigger_now_returns_none_when_decision_is_skip(tmp_path) -> None:
(tmp_path / "HEARTBEAT.md").write_text("- [ ] do thing", encoding="utf-8")
provider = DummyProvider([
LLMResponse(
content="",
tool_calls=[
ToolCallRequest(
id="hb_1",
name="heartbeat",
arguments={"action": "skip"},
)
],
)
])
async def _on_execute(tasks: str) -> str:
return tasks
service = HeartbeatService(
workspace=tmp_path,
provider=provider,
model="openai/gpt-4o-mini",
on_execute=_on_execute,
)
assert await service.trigger_now() is None
@pytest.mark.asyncio
async def test_tick_notifies_when_evaluator_says_yes(tmp_path, monkeypatch) -> None:
"""Phase 1 run -> Phase 2 execute -> Phase 3 evaluate=notify -> on_notify called."""
(tmp_path / "HEARTBEAT.md").write_text("- [ ] check deployments", encoding="utf-8")
provider = DummyProvider([
LLMResponse(
content="",
tool_calls=[
ToolCallRequest(
id="hb_1",
name="heartbeat",
arguments={"action": "run", "tasks": "check deployments"},
)
],
),
])
executed: list[str] = []
notified: list[str] = []
async def _on_execute(tasks: str) -> str:
executed.append(tasks)
return "deployment failed on staging"
async def _on_notify(response: str) -> None:
notified.append(response)
service = HeartbeatService(
workspace=tmp_path,
provider=provider,
model="openai/gpt-4o-mini",
on_execute=_on_execute,
on_notify=_on_notify,
)
async def _eval_notify(*a, **kw):
return True
monkeypatch.setattr("nanobot.utils.evaluator.evaluate_response", _eval_notify)
await service._tick()
assert executed == ["check deployments"]
assert notified == ["deployment failed on staging"]
@pytest.mark.asyncio
async def test_tick_suppresses_when_evaluator_says_no(tmp_path, monkeypatch) -> None:
"""Phase 1 run -> Phase 2 execute -> Phase 3 evaluate=silent -> on_notify NOT called."""
(tmp_path / "HEARTBEAT.md").write_text("- [ ] check status", encoding="utf-8")
provider = DummyProvider([
LLMResponse(
content="",
tool_calls=[
ToolCallRequest(
id="hb_1",
name="heartbeat",
arguments={"action": "run", "tasks": "check status"},
)
],
),
])
executed: list[str] = []
notified: list[str] = []
async def _on_execute(tasks: str) -> str:
executed.append(tasks)
return "everything is fine, no issues"
async def _on_notify(response: str) -> None:
notified.append(response)
service = HeartbeatService(
workspace=tmp_path,
provider=provider,
model="openai/gpt-4o-mini",
on_execute=_on_execute,
on_notify=_on_notify,
)
async def _eval_silent(*a, **kw):
return False
monkeypatch.setattr("nanobot.utils.evaluator.evaluate_response", _eval_silent)
await service._tick()
assert executed == ["check status"]
assert notified == []
@pytest.mark.asyncio
async def test_decide_retries_transient_error_then_succeeds(tmp_path, monkeypatch) -> None:
provider = DummyProvider([
LLMResponse(content="429 rate limit", finish_reason="error"),
LLMResponse(
content="",
tool_calls=[
ToolCallRequest(
id="hb_1",
name="heartbeat",
arguments={"action": "run", "tasks": "check open tasks"},
)
],
),
])
delays: list[int] = []
async def _fake_sleep(delay: int) -> None:
delays.append(delay)
monkeypatch.setattr(asyncio, "sleep", _fake_sleep)
service = HeartbeatService(
workspace=tmp_path,
provider=provider,
model="openai/gpt-4o-mini",
)
action, tasks = await service._decide("heartbeat content")
assert action == "run"
assert tasks == "check open tasks"
assert provider.calls == 2
assert delays == [1]
@pytest.mark.asyncio
async def test_decide_prompt_includes_current_time(tmp_path) -> None:
"""Phase 1 user prompt must contain current time so the LLM can judge task urgency."""
captured_messages: list[dict] = []
class CapturingProvider(LLMProvider):
async def chat(self, *, messages=None, **kwargs) -> LLMResponse:
if messages:
captured_messages.extend(messages)
return LLMResponse(
content="",
tool_calls=[
ToolCallRequest(
id="hb_1", name="heartbeat",
arguments={"action": "skip"},
)
],
)
def get_default_model(self) -> str:
return "test-model"
service = HeartbeatService(
workspace=tmp_path,
provider=CapturingProvider(),
model="test-model",
)
await service._decide("- [ ] check servers at 10:00 UTC")
user_msg = captured_messages[1]
assert user_msg["role"] == "user"
assert "Current Time:" in user_msg["content"]
@@ -0,0 +1,196 @@
from unittest.mock import AsyncMock, MagicMock
import pytest
from nanobot.agent.loop import AgentLoop
import nanobot.agent.memory as memory_module
from nanobot.bus.queue import MessageBus
from nanobot.providers.base import LLMResponse
def _make_loop(tmp_path, *, estimated_tokens: int, context_window_tokens: int) -> AgentLoop:
from nanobot.providers.base import GenerationSettings
provider = MagicMock()
provider.get_default_model.return_value = "test-model"
provider.generation = GenerationSettings(max_tokens=0)
provider.estimate_prompt_tokens.return_value = (estimated_tokens, "test-counter")
_response = LLMResponse(content="ok", tool_calls=[])
provider.chat_with_retry = AsyncMock(return_value=_response)
provider.chat_stream_with_retry = AsyncMock(return_value=_response)
loop = AgentLoop(
bus=MessageBus(),
provider=provider,
workspace=tmp_path,
model="test-model",
context_window_tokens=context_window_tokens,
)
loop.tools.get_definitions = MagicMock(return_value=[])
loop.memory_consolidator._SAFETY_BUFFER = 0
return loop
@pytest.mark.asyncio
async def test_prompt_below_threshold_does_not_consolidate(tmp_path) -> None:
loop = _make_loop(tmp_path, estimated_tokens=100, context_window_tokens=200)
loop.memory_consolidator.consolidate_messages = AsyncMock(return_value=True) # type: ignore[method-assign]
await loop.process_direct("hello", session_key="cli:test")
loop.memory_consolidator.consolidate_messages.assert_not_awaited()
@pytest.mark.asyncio
async def test_prompt_above_threshold_triggers_consolidation(tmp_path, monkeypatch) -> None:
loop = _make_loop(tmp_path, estimated_tokens=1000, context_window_tokens=200)
loop.memory_consolidator.consolidate_messages = AsyncMock(return_value=True) # type: ignore[method-assign]
session = loop.sessions.get_or_create("cli:test")
session.messages = [
{"role": "user", "content": "u1", "timestamp": "2026-01-01T00:00:00"},
{"role": "assistant", "content": "a1", "timestamp": "2026-01-01T00:00:01"},
{"role": "user", "content": "u2", "timestamp": "2026-01-01T00:00:02"},
]
loop.sessions.save(session)
monkeypatch.setattr(memory_module, "estimate_message_tokens", lambda _message: 500)
await loop.process_direct("hello", session_key="cli:test")
assert loop.memory_consolidator.consolidate_messages.await_count >= 1
@pytest.mark.asyncio
async def test_prompt_above_threshold_archives_until_next_user_boundary(tmp_path, monkeypatch) -> None:
loop = _make_loop(tmp_path, estimated_tokens=1000, context_window_tokens=200)
loop.memory_consolidator.consolidate_messages = AsyncMock(return_value=True) # type: ignore[method-assign]
session = loop.sessions.get_or_create("cli:test")
session.messages = [
{"role": "user", "content": "u1", "timestamp": "2026-01-01T00:00:00"},
{"role": "assistant", "content": "a1", "timestamp": "2026-01-01T00:00:01"},
{"role": "user", "content": "u2", "timestamp": "2026-01-01T00:00:02"},
{"role": "assistant", "content": "a2", "timestamp": "2026-01-01T00:00:03"},
{"role": "user", "content": "u3", "timestamp": "2026-01-01T00:00:04"},
]
loop.sessions.save(session)
token_map = {"u1": 120, "a1": 120, "u2": 120, "a2": 120, "u3": 120}
monkeypatch.setattr(memory_module, "estimate_message_tokens", lambda message: token_map[message["content"]])
await loop.memory_consolidator.maybe_consolidate_by_tokens(session)
archived_chunk = loop.memory_consolidator.consolidate_messages.await_args.args[0]
assert [message["content"] for message in archived_chunk] == ["u1", "a1", "u2", "a2"]
assert session.last_consolidated == 4
@pytest.mark.asyncio
async def test_consolidation_loops_until_target_met(tmp_path, monkeypatch) -> None:
"""Verify maybe_consolidate_by_tokens keeps looping until under threshold."""
loop = _make_loop(tmp_path, estimated_tokens=0, context_window_tokens=200)
loop.memory_consolidator.consolidate_messages = AsyncMock(return_value=True) # type: ignore[method-assign]
session = loop.sessions.get_or_create("cli:test")
session.messages = [
{"role": "user", "content": "u1", "timestamp": "2026-01-01T00:00:00"},
{"role": "assistant", "content": "a1", "timestamp": "2026-01-01T00:00:01"},
{"role": "user", "content": "u2", "timestamp": "2026-01-01T00:00:02"},
{"role": "assistant", "content": "a2", "timestamp": "2026-01-01T00:00:03"},
{"role": "user", "content": "u3", "timestamp": "2026-01-01T00:00:04"},
{"role": "assistant", "content": "a3", "timestamp": "2026-01-01T00:00:05"},
{"role": "user", "content": "u4", "timestamp": "2026-01-01T00:00:06"},
]
loop.sessions.save(session)
call_count = [0]
def mock_estimate(_session):
call_count[0] += 1
if call_count[0] == 1:
return (500, "test")
if call_count[0] == 2:
return (300, "test")
return (80, "test")
loop.memory_consolidator.estimate_session_prompt_tokens = mock_estimate # type: ignore[method-assign]
monkeypatch.setattr(memory_module, "estimate_message_tokens", lambda _m: 100)
await loop.memory_consolidator.maybe_consolidate_by_tokens(session)
assert loop.memory_consolidator.consolidate_messages.await_count == 2
assert session.last_consolidated == 6
@pytest.mark.asyncio
async def test_consolidation_continues_below_trigger_until_half_target(tmp_path, monkeypatch) -> None:
"""Once triggered, consolidation should continue until it drops below half threshold."""
loop = _make_loop(tmp_path, estimated_tokens=0, context_window_tokens=200)
loop.memory_consolidator.consolidate_messages = AsyncMock(return_value=True) # type: ignore[method-assign]
session = loop.sessions.get_or_create("cli:test")
session.messages = [
{"role": "user", "content": "u1", "timestamp": "2026-01-01T00:00:00"},
{"role": "assistant", "content": "a1", "timestamp": "2026-01-01T00:00:01"},
{"role": "user", "content": "u2", "timestamp": "2026-01-01T00:00:02"},
{"role": "assistant", "content": "a2", "timestamp": "2026-01-01T00:00:03"},
{"role": "user", "content": "u3", "timestamp": "2026-01-01T00:00:04"},
{"role": "assistant", "content": "a3", "timestamp": "2026-01-01T00:00:05"},
{"role": "user", "content": "u4", "timestamp": "2026-01-01T00:00:06"},
]
loop.sessions.save(session)
call_count = [0]
def mock_estimate(_session):
call_count[0] += 1
if call_count[0] == 1:
return (500, "test")
if call_count[0] == 2:
return (150, "test")
return (80, "test")
loop.memory_consolidator.estimate_session_prompt_tokens = mock_estimate # type: ignore[method-assign]
monkeypatch.setattr(memory_module, "estimate_message_tokens", lambda _m: 100)
await loop.memory_consolidator.maybe_consolidate_by_tokens(session)
assert loop.memory_consolidator.consolidate_messages.await_count == 2
assert session.last_consolidated == 6
@pytest.mark.asyncio
async def test_preflight_consolidation_before_llm_call(tmp_path, monkeypatch) -> None:
"""Verify preflight consolidation runs before the LLM call in process_direct."""
order: list[str] = []
loop = _make_loop(tmp_path, estimated_tokens=0, context_window_tokens=200)
async def track_consolidate(messages):
order.append("consolidate")
return True
loop.memory_consolidator.consolidate_messages = track_consolidate # type: ignore[method-assign]
async def track_llm(*args, **kwargs):
order.append("llm")
return LLMResponse(content="ok", tool_calls=[])
loop.provider.chat_with_retry = track_llm
loop.provider.chat_stream_with_retry = track_llm
session = loop.sessions.get_or_create("cli:test")
session.messages = [
{"role": "user", "content": "u1", "timestamp": "2026-01-01T00:00:00"},
{"role": "assistant", "content": "a1", "timestamp": "2026-01-01T00:00:01"},
{"role": "user", "content": "u2", "timestamp": "2026-01-01T00:00:02"},
]
loop.sessions.save(session)
monkeypatch.setattr(memory_module, "estimate_message_tokens", lambda _m: 500)
call_count = [0]
def mock_estimate(_session):
call_count[0] += 1
return (1000 if call_count[0] <= 1 else 80, "test")
loop.memory_consolidator.estimate_session_prompt_tokens = mock_estimate # type: ignore[method-assign]
await loop.process_direct("hello", session_key="cli:test")
assert "consolidate" in order
assert "llm" in order
assert order.index("consolidate") < order.index("llm")
@@ -0,0 +1,27 @@
from pathlib import Path
from unittest.mock import MagicMock
from nanobot.agent.loop import AgentLoop
from nanobot.agent.tools.cron import CronTool
from nanobot.bus.queue import MessageBus
from nanobot.cron.service import CronService
def test_agent_loop_registers_cron_tool_with_configured_timezone(tmp_path: Path) -> None:
bus = MessageBus()
provider = MagicMock()
provider.get_default_model.return_value = "test-model"
loop = AgentLoop(
bus=bus,
provider=provider,
workspace=tmp_path,
model="test-model",
cron_service=CronService(tmp_path / "cron" / "jobs.json"),
timezone="Asia/Shanghai",
)
cron_tool = loop.tools.get("cron")
assert isinstance(cron_tool, CronTool)
assert cron_tool._default_timezone == "Asia/Shanghai"
@@ -0,0 +1,74 @@
from nanobot.agent.context import ContextBuilder
from nanobot.agent.loop import AgentLoop
from nanobot.session.manager import Session
def _mk_loop() -> AgentLoop:
loop = AgentLoop.__new__(AgentLoop)
loop._TOOL_RESULT_MAX_CHARS = AgentLoop._TOOL_RESULT_MAX_CHARS
return loop
def test_save_turn_skips_multimodal_user_when_only_runtime_context() -> None:
loop = _mk_loop()
session = Session(key="test:runtime-only")
runtime = ContextBuilder._RUNTIME_CONTEXT_TAG + "\nCurrent Time: now (UTC)"
loop._save_turn(
session,
[{"role": "user", "content": [{"type": "text", "text": runtime}]}],
skip=0,
)
assert session.messages == []
def test_save_turn_keeps_image_placeholder_with_path_after_runtime_strip() -> None:
loop = _mk_loop()
session = Session(key="test:image")
runtime = ContextBuilder._RUNTIME_CONTEXT_TAG + "\nCurrent Time: now (UTC)"
loop._save_turn(
session,
[{
"role": "user",
"content": [
{"type": "text", "text": runtime},
{"type": "image_url", "image_url": {"url": "data:image/png;base64,abc"}, "_meta": {"path": "/media/feishu/photo.jpg"}},
],
}],
skip=0,
)
assert session.messages[0]["content"] == [{"type": "text", "text": "[image: /media/feishu/photo.jpg]"}]
def test_save_turn_keeps_image_placeholder_without_meta() -> None:
loop = _mk_loop()
session = Session(key="test:image-no-meta")
runtime = ContextBuilder._RUNTIME_CONTEXT_TAG + "\nCurrent Time: now (UTC)"
loop._save_turn(
session,
[{
"role": "user",
"content": [
{"type": "text", "text": runtime},
{"type": "image_url", "image_url": {"url": "data:image/png;base64,abc"}},
],
}],
skip=0,
)
assert session.messages[0]["content"] == [{"type": "text", "text": "[image]"}]
def test_save_turn_keeps_tool_results_under_16k() -> None:
loop = _mk_loop()
session = Session(key="test:tool-result")
content = "x" * 12_000
loop._save_turn(
session,
[{"role": "tool", "tool_call_id": "call_1", "name": "read_file", "content": content}],
skip=0,
)
assert session.messages[0]["content"] == content
@@ -0,0 +1,478 @@
"""Test MemoryStore.consolidate() handles non-string tool call arguments.
Regression test for https://github.com/HKUDS/nanobot/issues/1042
When memory consolidation receives dict values instead of strings from the LLM
tool call response, it should serialize them to JSON instead of raising TypeError.
"""
import json
from pathlib import Path
from unittest.mock import AsyncMock
import pytest
from nanobot.agent.memory import MemoryStore
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
def _make_messages(message_count: int = 30):
"""Create a list of mock messages."""
return [
{"role": "user", "content": f"msg{i}", "timestamp": "2026-01-01 00:00"}
for i in range(message_count)
]
def _make_tool_response(history_entry, memory_update):
"""Create an LLMResponse with a save_memory tool call."""
return LLMResponse(
content=None,
tool_calls=[
ToolCallRequest(
id="call_1",
name="save_memory",
arguments={
"history_entry": history_entry,
"memory_update": memory_update,
},
)
],
)
class ScriptedProvider(LLMProvider):
def __init__(self, responses: list[LLMResponse]):
super().__init__()
self._responses = list(responses)
self.calls = 0
async def chat(self, *args, **kwargs) -> LLMResponse:
self.calls += 1
if self._responses:
return self._responses.pop(0)
return LLMResponse(content="", tool_calls=[])
def get_default_model(self) -> str:
return "test-model"
class TestMemoryConsolidationTypeHandling:
"""Test that consolidation handles various argument types correctly."""
@pytest.mark.asyncio
async def test_string_arguments_work(self, tmp_path: Path) -> None:
"""Normal case: LLM returns string arguments."""
store = MemoryStore(tmp_path)
provider = AsyncMock()
provider.chat = AsyncMock(
return_value=_make_tool_response(
history_entry="[2026-01-01] User discussed testing.",
memory_update="# Memory\nUser likes testing.",
)
)
provider.chat_with_retry = provider.chat
messages = _make_messages(message_count=60)
result = await store.consolidate(messages, provider, "test-model")
assert result is True
assert store.history_file.exists()
assert "[2026-01-01] User discussed testing." in store.history_file.read_text()
assert "User likes testing." in store.memory_file.read_text()
@pytest.mark.asyncio
async def test_dict_arguments_serialized_to_json(self, tmp_path: Path) -> None:
"""Issue #1042: LLM returns dict instead of string — must not raise TypeError."""
store = MemoryStore(tmp_path)
provider = AsyncMock()
provider.chat = AsyncMock(
return_value=_make_tool_response(
history_entry={"timestamp": "2026-01-01", "summary": "User discussed testing."},
memory_update={"facts": ["User likes testing"], "topics": ["testing"]},
)
)
provider.chat_with_retry = provider.chat
messages = _make_messages(message_count=60)
result = await store.consolidate(messages, provider, "test-model")
assert result is True
assert store.history_file.exists()
history_content = store.history_file.read_text()
parsed = json.loads(history_content.strip())
assert parsed["summary"] == "User discussed testing."
memory_content = store.memory_file.read_text()
parsed_mem = json.loads(memory_content)
assert "User likes testing" in parsed_mem["facts"]
@pytest.mark.asyncio
async def test_string_arguments_as_raw_json(self, tmp_path: Path) -> None:
"""Some providers return arguments as a JSON string instead of parsed dict."""
store = MemoryStore(tmp_path)
provider = AsyncMock()
response = LLMResponse(
content=None,
tool_calls=[
ToolCallRequest(
id="call_1",
name="save_memory",
arguments=json.dumps({
"history_entry": "[2026-01-01] User discussed testing.",
"memory_update": "# Memory\nUser likes testing.",
}),
)
],
)
provider.chat = AsyncMock(return_value=response)
provider.chat_with_retry = provider.chat
messages = _make_messages(message_count=60)
result = await store.consolidate(messages, provider, "test-model")
assert result is True
assert "User discussed testing." in store.history_file.read_text()
@pytest.mark.asyncio
async def test_no_tool_call_returns_false(self, tmp_path: Path) -> None:
"""When LLM doesn't use the save_memory tool, return False."""
store = MemoryStore(tmp_path)
provider = AsyncMock()
provider.chat = AsyncMock(
return_value=LLMResponse(content="I summarized the conversation.", tool_calls=[])
)
provider.chat_with_retry = provider.chat
messages = _make_messages(message_count=60)
result = await store.consolidate(messages, provider, "test-model")
assert result is False
assert not store.history_file.exists()
@pytest.mark.asyncio
async def test_skips_when_message_chunk_is_empty(self, tmp_path: Path) -> None:
"""Consolidation should be a no-op when the selected chunk is empty."""
store = MemoryStore(tmp_path)
provider = AsyncMock()
provider.chat_with_retry = provider.chat
messages: list[dict] = []
result = await store.consolidate(messages, provider, "test-model")
assert result is True
provider.chat.assert_not_called()
@pytest.mark.asyncio
async def test_list_arguments_extracts_first_dict(self, tmp_path: Path) -> None:
"""Some providers return arguments as a list - extract first element if it's a dict."""
store = MemoryStore(tmp_path)
provider = AsyncMock()
response = LLMResponse(
content=None,
tool_calls=[
ToolCallRequest(
id="call_1",
name="save_memory",
arguments=[{
"history_entry": "[2026-01-01] User discussed testing.",
"memory_update": "# Memory\nUser likes testing.",
}],
)
],
)
provider.chat = AsyncMock(return_value=response)
provider.chat_with_retry = provider.chat
messages = _make_messages(message_count=60)
result = await store.consolidate(messages, provider, "test-model")
assert result is True
assert "User discussed testing." in store.history_file.read_text()
assert "User likes testing." in store.memory_file.read_text()
@pytest.mark.asyncio
async def test_list_arguments_empty_list_returns_false(self, tmp_path: Path) -> None:
"""Empty list arguments should return False."""
store = MemoryStore(tmp_path)
provider = AsyncMock()
response = LLMResponse(
content=None,
tool_calls=[
ToolCallRequest(
id="call_1",
name="save_memory",
arguments=[],
)
],
)
provider.chat = AsyncMock(return_value=response)
provider.chat_with_retry = provider.chat
messages = _make_messages(message_count=60)
result = await store.consolidate(messages, provider, "test-model")
assert result is False
@pytest.mark.asyncio
async def test_list_arguments_non_dict_content_returns_false(self, tmp_path: Path) -> None:
"""List with non-dict content should return False."""
store = MemoryStore(tmp_path)
provider = AsyncMock()
response = LLMResponse(
content=None,
tool_calls=[
ToolCallRequest(
id="call_1",
name="save_memory",
arguments=["string", "content"],
)
],
)
provider.chat = AsyncMock(return_value=response)
provider.chat_with_retry = provider.chat
messages = _make_messages(message_count=60)
result = await store.consolidate(messages, provider, "test-model")
assert result is False
@pytest.mark.asyncio
async def test_missing_history_entry_returns_false_without_writing(self, tmp_path: Path) -> None:
"""Do not persist partial results when required fields are missing."""
store = MemoryStore(tmp_path)
provider = AsyncMock()
provider.chat_with_retry = AsyncMock(
return_value=LLMResponse(
content=None,
tool_calls=[
ToolCallRequest(
id="call_1",
name="save_memory",
arguments={"memory_update": "# Memory\nOnly memory update"},
)
],
)
)
messages = _make_messages(message_count=60)
result = await store.consolidate(messages, provider, "test-model")
assert result is False
assert not store.history_file.exists()
assert not store.memory_file.exists()
@pytest.mark.asyncio
async def test_missing_memory_update_returns_false_without_writing(self, tmp_path: Path) -> None:
"""Do not append history if memory_update is missing."""
store = MemoryStore(tmp_path)
provider = AsyncMock()
provider.chat_with_retry = AsyncMock(
return_value=LLMResponse(
content=None,
tool_calls=[
ToolCallRequest(
id="call_1",
name="save_memory",
arguments={"history_entry": "[2026-01-01] Partial output."},
)
],
)
)
messages = _make_messages(message_count=60)
result = await store.consolidate(messages, provider, "test-model")
assert result is False
assert not store.history_file.exists()
assert not store.memory_file.exists()
@pytest.mark.asyncio
async def test_null_required_field_returns_false_without_writing(self, tmp_path: Path) -> None:
"""Null required fields should be rejected before persistence."""
store = MemoryStore(tmp_path)
provider = AsyncMock()
provider.chat_with_retry = AsyncMock(
return_value=_make_tool_response(
history_entry=None,
memory_update="# Memory\nUser likes testing.",
)
)
messages = _make_messages(message_count=60)
result = await store.consolidate(messages, provider, "test-model")
assert result is False
assert not store.history_file.exists()
assert not store.memory_file.exists()
@pytest.mark.asyncio
async def test_empty_history_entry_returns_false_without_writing(self, tmp_path: Path) -> None:
"""Empty history entries should be rejected to avoid blank archival records."""
store = MemoryStore(tmp_path)
provider = AsyncMock()
provider.chat_with_retry = AsyncMock(
return_value=_make_tool_response(
history_entry=" ",
memory_update="# Memory\nUser likes testing.",
)
)
messages = _make_messages(message_count=60)
result = await store.consolidate(messages, provider, "test-model")
assert result is False
assert not store.history_file.exists()
assert not store.memory_file.exists()
@pytest.mark.asyncio
async def test_retries_transient_error_then_succeeds(self, tmp_path: Path, monkeypatch) -> None:
store = MemoryStore(tmp_path)
provider = ScriptedProvider([
LLMResponse(content="503 server error", finish_reason="error"),
_make_tool_response(
history_entry="[2026-01-01] User discussed testing.",
memory_update="# Memory\nUser likes testing.",
),
])
messages = _make_messages(message_count=60)
delays: list[int] = []
async def _fake_sleep(delay: int) -> None:
delays.append(delay)
monkeypatch.setattr("nanobot.providers.base.asyncio.sleep", _fake_sleep)
result = await store.consolidate(messages, provider, "test-model")
assert result is True
assert provider.calls == 2
assert delays == [1]
@pytest.mark.asyncio
async def test_consolidation_delegates_to_provider_defaults(self, tmp_path: Path) -> None:
"""Consolidation no longer passes generation params — the provider owns them."""
store = MemoryStore(tmp_path)
provider = AsyncMock()
provider.chat_with_retry = AsyncMock(
return_value=_make_tool_response(
history_entry="[2026-01-01] User discussed testing.",
memory_update="# Memory\nUser likes testing.",
)
)
messages = _make_messages(message_count=60)
result = await store.consolidate(messages, provider, "test-model")
assert result is True
provider.chat_with_retry.assert_awaited_once()
_, kwargs = provider.chat_with_retry.await_args
assert kwargs["model"] == "test-model"
assert "temperature" not in kwargs
assert "max_tokens" not in kwargs
assert "reasoning_effort" not in kwargs
@pytest.mark.asyncio
async def test_tool_choice_fallback_on_unsupported_error(self, tmp_path: Path) -> None:
"""Forced tool_choice rejected by provider -> retry with auto and succeed."""
store = MemoryStore(tmp_path)
error_resp = LLMResponse(
content="Error calling LLM: BadRequestError: "
"The tool_choice parameter does not support being set to required or object",
finish_reason="error",
tool_calls=[],
)
ok_resp = _make_tool_response(
history_entry="[2026-01-01] Fallback worked.",
memory_update="# Memory\nFallback OK.",
)
call_log: list[dict] = []
async def _tracking_chat(**kwargs):
call_log.append(kwargs)
return error_resp if len(call_log) == 1 else ok_resp
provider = AsyncMock()
provider.chat_with_retry = AsyncMock(side_effect=_tracking_chat)
messages = _make_messages(message_count=60)
result = await store.consolidate(messages, provider, "test-model")
assert result is True
assert len(call_log) == 2
assert isinstance(call_log[0]["tool_choice"], dict)
assert call_log[1]["tool_choice"] == "auto"
assert "Fallback worked." in store.history_file.read_text()
@pytest.mark.asyncio
async def test_tool_choice_fallback_auto_no_tool_call(self, tmp_path: Path) -> None:
"""Forced rejected, auto retry also produces no tool call -> return False."""
store = MemoryStore(tmp_path)
error_resp = LLMResponse(
content="Error: tool_choice must be none or auto",
finish_reason="error",
tool_calls=[],
)
no_tool_resp = LLMResponse(
content="Here is a summary.",
finish_reason="stop",
tool_calls=[],
)
provider = AsyncMock()
provider.chat_with_retry = AsyncMock(side_effect=[error_resp, no_tool_resp])
messages = _make_messages(message_count=60)
result = await store.consolidate(messages, provider, "test-model")
assert result is False
assert not store.history_file.exists()
@pytest.mark.asyncio
async def test_raw_archive_after_consecutive_failures(self, tmp_path: Path) -> None:
"""After 3 consecutive failures, raw-archive messages and return True."""
store = MemoryStore(tmp_path)
no_tool = LLMResponse(content="No tool call.", finish_reason="stop", tool_calls=[])
provider = AsyncMock()
provider.chat_with_retry = AsyncMock(return_value=no_tool)
messages = _make_messages(message_count=10)
assert await store.consolidate(messages, provider, "m") is False
assert await store.consolidate(messages, provider, "m") is False
assert await store.consolidate(messages, provider, "m") is True
assert store.history_file.exists()
content = store.history_file.read_text()
assert "[RAW]" in content
assert "10 messages" in content
assert "msg0" in content
assert not store.memory_file.exists()
@pytest.mark.asyncio
async def test_raw_archive_counter_resets_on_success(self, tmp_path: Path) -> None:
"""A successful consolidation resets the failure counter."""
store = MemoryStore(tmp_path)
no_tool = LLMResponse(content="Nope.", finish_reason="stop", tool_calls=[])
ok_resp = _make_tool_response(
history_entry="[2026-01-01] OK.",
memory_update="# Memory\nOK.",
)
messages = _make_messages(message_count=10)
provider = AsyncMock()
provider.chat_with_retry = AsyncMock(return_value=no_tool)
assert await store.consolidate(messages, provider, "m") is False
assert await store.consolidate(messages, provider, "m") is False
assert store._consecutive_failures == 2
provider.chat_with_retry = AsyncMock(return_value=ok_resp)
assert await store.consolidate(messages, provider, "m") is True
assert store._consecutive_failures == 0
provider.chat_with_retry = AsyncMock(return_value=no_tool)
assert await store.consolidate(messages, provider, "m") is False
assert store._consecutive_failures == 1
@@ -0,0 +1,495 @@
"""Unit tests for onboard core logic functions.
These tests focus on the business logic behind the onboard wizard,
without testing the interactive UI components.
"""
import json
from pathlib import Path
from types import SimpleNamespace
from typing import Any, cast
import pytest
from pydantic import BaseModel, Field
from nanobot.cli import onboard as onboard_wizard
# Import functions to test
from nanobot.cli.commands import _merge_missing_defaults
from nanobot.cli.onboard import (
_BACK_PRESSED,
_configure_pydantic_model,
_format_value,
_get_field_display_name,
_get_field_type_info,
run_onboard,
)
from nanobot.config.schema import Config
from nanobot.utils.helpers import sync_workspace_templates
class TestMergeMissingDefaults:
"""Tests for _merge_missing_defaults recursive config merging."""
def test_adds_missing_top_level_keys(self):
existing = {"a": 1}
defaults = {"a": 1, "b": 2, "c": 3}
result = _merge_missing_defaults(existing, defaults)
assert result == {"a": 1, "b": 2, "c": 3}
def test_preserves_existing_values(self):
existing = {"a": "custom_value"}
defaults = {"a": "default_value"}
result = _merge_missing_defaults(existing, defaults)
assert result == {"a": "custom_value"}
def test_merges_nested_dicts_recursively(self):
existing = {
"level1": {
"level2": {
"existing": "kept",
}
}
}
defaults = {
"level1": {
"level2": {
"existing": "replaced",
"added": "new",
},
"level2b": "also_new",
}
}
result = _merge_missing_defaults(existing, defaults)
assert result == {
"level1": {
"level2": {
"existing": "kept",
"added": "new",
},
"level2b": "also_new",
}
}
def test_returns_existing_if_not_dict(self):
assert _merge_missing_defaults("string", {"a": 1}) == "string"
assert _merge_missing_defaults([1, 2, 3], {"a": 1}) == [1, 2, 3]
assert _merge_missing_defaults(None, {"a": 1}) is None
assert _merge_missing_defaults(42, {"a": 1}) == 42
def test_returns_existing_if_defaults_not_dict(self):
assert _merge_missing_defaults({"a": 1}, "string") == {"a": 1}
assert _merge_missing_defaults({"a": 1}, None) == {"a": 1}
def test_handles_empty_dicts(self):
assert _merge_missing_defaults({}, {"a": 1}) == {"a": 1}
assert _merge_missing_defaults({"a": 1}, {}) == {"a": 1}
assert _merge_missing_defaults({}, {}) == {}
def test_backfills_channel_config(self):
"""Real-world scenario: backfill missing channel fields."""
existing_channel = {
"enabled": False,
"appId": "",
"secret": "",
}
default_channel = {
"enabled": False,
"appId": "",
"secret": "",
"msgFormat": "plain",
"allowFrom": [],
}
result = _merge_missing_defaults(existing_channel, default_channel)
assert result["msgFormat"] == "plain"
assert result["allowFrom"] == []
class TestGetFieldTypeInfo:
"""Tests for _get_field_type_info type extraction."""
def test_extracts_str_type(self):
class Model(BaseModel):
field: str
type_name, inner = _get_field_type_info(Model.model_fields["field"])
assert type_name == "str"
assert inner is None
def test_extracts_int_type(self):
class Model(BaseModel):
count: int
type_name, inner = _get_field_type_info(Model.model_fields["count"])
assert type_name == "int"
assert inner is None
def test_extracts_bool_type(self):
class Model(BaseModel):
enabled: bool
type_name, inner = _get_field_type_info(Model.model_fields["enabled"])
assert type_name == "bool"
assert inner is None
def test_extracts_float_type(self):
class Model(BaseModel):
ratio: float
type_name, inner = _get_field_type_info(Model.model_fields["ratio"])
assert type_name == "float"
assert inner is None
def test_extracts_list_type_with_item_type(self):
class Model(BaseModel):
items: list[str]
type_name, inner = _get_field_type_info(Model.model_fields["items"])
assert type_name == "list"
assert inner is str
def test_extracts_list_type_without_item_type(self):
# Plain list without type param falls back to str
class Model(BaseModel):
items: list # type: ignore
# Plain list annotation doesn't match list check, returns str
type_name, inner = _get_field_type_info(Model.model_fields["items"])
assert type_name == "str" # Falls back to str for untyped list
assert inner is None
def test_extracts_dict_type(self):
# Plain dict without type param falls back to str
class Model(BaseModel):
data: dict # type: ignore
# Plain dict annotation doesn't match dict check, returns str
type_name, inner = _get_field_type_info(Model.model_fields["data"])
assert type_name == "str" # Falls back to str for untyped dict
assert inner is None
def test_extracts_optional_type(self):
class Model(BaseModel):
optional: str | None = None
type_name, inner = _get_field_type_info(Model.model_fields["optional"])
# Should unwrap Optional and get str
assert type_name == "str"
assert inner is None
def test_extracts_nested_model_type(self):
class Inner(BaseModel):
x: int
class Outer(BaseModel):
nested: Inner
type_name, inner = _get_field_type_info(Outer.model_fields["nested"])
assert type_name == "model"
assert inner is Inner
def test_handles_none_annotation(self):
"""Field with None annotation defaults to str."""
class Model(BaseModel):
field: Any = None
# Create a mock field_info with None annotation
field_info = SimpleNamespace(annotation=None)
type_name, inner = _get_field_type_info(field_info)
assert type_name == "str"
assert inner is None
class TestGetFieldDisplayName:
"""Tests for _get_field_display_name human-readable name generation."""
def test_uses_description_if_present(self):
class Model(BaseModel):
api_key: str = Field(description="API Key for authentication")
name = _get_field_display_name("api_key", Model.model_fields["api_key"])
assert name == "API Key for authentication"
def test_converts_snake_case_to_title(self):
field_info = SimpleNamespace(description=None)
name = _get_field_display_name("user_name", field_info)
assert name == "User Name"
def test_adds_url_suffix(self):
field_info = SimpleNamespace(description=None)
name = _get_field_display_name("api_url", field_info)
# Title case: "Api Url"
assert "Url" in name and "Api" in name
def test_adds_path_suffix(self):
field_info = SimpleNamespace(description=None)
name = _get_field_display_name("file_path", field_info)
assert "Path" in name and "File" in name
def test_adds_id_suffix(self):
field_info = SimpleNamespace(description=None)
name = _get_field_display_name("user_id", field_info)
# Title case: "User Id"
assert "Id" in name and "User" in name
def test_adds_key_suffix(self):
field_info = SimpleNamespace(description=None)
name = _get_field_display_name("api_key", field_info)
assert "Key" in name and "Api" in name
def test_adds_token_suffix(self):
field_info = SimpleNamespace(description=None)
name = _get_field_display_name("auth_token", field_info)
assert "Token" in name and "Auth" in name
def test_adds_seconds_suffix(self):
field_info = SimpleNamespace(description=None)
name = _get_field_display_name("timeout_s", field_info)
# Contains "(Seconds)" with title case
assert "(Seconds)" in name or "(seconds)" in name
def test_adds_ms_suffix(self):
field_info = SimpleNamespace(description=None)
name = _get_field_display_name("delay_ms", field_info)
# Contains "(Ms)" or "(ms)"
assert "(Ms)" in name or "(ms)" in name
class TestFormatValue:
"""Tests for _format_value display formatting."""
def test_formats_none_as_not_set(self):
assert "not set" in _format_value(None)
def test_formats_empty_string_as_not_set(self):
assert "not set" in _format_value("")
def test_formats_empty_dict_as_not_set(self):
assert "not set" in _format_value({})
def test_formats_empty_list_as_not_set(self):
assert "not set" in _format_value([])
def test_formats_string_value(self):
result = _format_value("hello")
assert "hello" in result
def test_formats_list_value(self):
result = _format_value(["a", "b"])
assert "a" in result or "b" in result
def test_formats_dict_value(self):
result = _format_value({"key": "value"})
assert "key" in result or "value" in result
def test_formats_int_value(self):
result = _format_value(42)
assert "42" in result
def test_formats_bool_true(self):
result = _format_value(True)
assert "true" in result.lower() or "" in result
def test_formats_bool_false(self):
result = _format_value(False)
assert "false" in result.lower() or "" in result
class TestSyncWorkspaceTemplates:
"""Tests for sync_workspace_templates file synchronization."""
def test_creates_missing_files(self, tmp_path):
"""Should create template files that don't exist."""
workspace = tmp_path / "workspace"
added = sync_workspace_templates(workspace, silent=True)
# Check that some files were created
assert isinstance(added, list)
# The actual files depend on the templates directory
def test_does_not_overwrite_existing_files(self, tmp_path):
"""Should not overwrite files that already exist."""
workspace = tmp_path / "workspace"
workspace.mkdir(parents=True)
(workspace / "AGENTS.md").write_text("existing content")
sync_workspace_templates(workspace, silent=True)
# Existing file should not be changed
content = (workspace / "AGENTS.md").read_text()
assert content == "existing content"
def test_creates_memory_directory(self, tmp_path):
"""Should create memory directory structure."""
workspace = tmp_path / "workspace"
sync_workspace_templates(workspace, silent=True)
assert (workspace / "memory").exists() or (workspace / "skills").exists()
def test_returns_list_of_added_files(self, tmp_path):
"""Should return list of relative paths for added files."""
workspace = tmp_path / "workspace"
added = sync_workspace_templates(workspace, silent=True)
assert isinstance(added, list)
# All paths should be relative to workspace
for path in added:
assert not Path(path).is_absolute()
class TestProviderChannelInfo:
"""Tests for provider and channel info retrieval."""
def test_get_provider_names_returns_dict(self):
from nanobot.cli.onboard import _get_provider_names
names = _get_provider_names()
assert isinstance(names, dict)
assert len(names) > 0
# Should include common providers
assert "openai" in names or "anthropic" in names
assert "openai_codex" not in names
assert "github_copilot" not in names
def test_get_channel_names_returns_dict(self):
from nanobot.cli.onboard import _get_channel_names
names = _get_channel_names()
assert isinstance(names, dict)
# Should include at least some channels
assert len(names) >= 0
def test_get_provider_info_returns_valid_structure(self):
from nanobot.cli.onboard import _get_provider_info
info = _get_provider_info()
assert isinstance(info, dict)
# Each value should be a tuple with expected structure
for provider_name, value in info.items():
assert isinstance(value, tuple)
assert len(value) == 4 # (display_name, needs_api_key, needs_api_base, env_var)
class _SimpleDraftModel(BaseModel):
api_key: str = ""
class _NestedDraftModel(BaseModel):
api_key: str = ""
class _OuterDraftModel(BaseModel):
nested: _NestedDraftModel = Field(default_factory=_NestedDraftModel)
class TestConfigurePydanticModelDrafts:
@staticmethod
def _patch_prompt_helpers(monkeypatch, tokens, text_value="secret"):
sequence = iter(tokens)
def fake_select(_prompt, choices, default=None):
token = next(sequence)
if token == "first":
return choices[0]
if token == "done":
return "[Done]"
if token == "back":
return _BACK_PRESSED
return token
monkeypatch.setattr(onboard_wizard, "_select_with_back", fake_select)
monkeypatch.setattr(onboard_wizard, "_show_config_panel", lambda *_args, **_kwargs: None)
monkeypatch.setattr(
onboard_wizard, "_input_with_existing", lambda *_args, **_kwargs: text_value
)
def test_discarding_section_keeps_original_model_unchanged(self, monkeypatch):
model = _SimpleDraftModel()
self._patch_prompt_helpers(monkeypatch, ["first", "back"])
result = _configure_pydantic_model(model, "Simple")
assert result is None
assert model.api_key == ""
def test_completing_section_returns_updated_draft(self, monkeypatch):
model = _SimpleDraftModel()
self._patch_prompt_helpers(monkeypatch, ["first", "done"])
result = _configure_pydantic_model(model, "Simple")
assert result is not None
updated = cast(_SimpleDraftModel, result)
assert updated.api_key == "secret"
assert model.api_key == ""
def test_nested_section_back_discards_nested_edits(self, monkeypatch):
model = _OuterDraftModel()
self._patch_prompt_helpers(monkeypatch, ["first", "first", "back", "done"])
result = _configure_pydantic_model(model, "Outer")
assert result is not None
updated = cast(_OuterDraftModel, result)
assert updated.nested.api_key == ""
assert model.nested.api_key == ""
def test_nested_section_done_commits_nested_edits(self, monkeypatch):
model = _OuterDraftModel()
self._patch_prompt_helpers(monkeypatch, ["first", "first", "done", "done"])
result = _configure_pydantic_model(model, "Outer")
assert result is not None
updated = cast(_OuterDraftModel, result)
assert updated.nested.api_key == "secret"
assert model.nested.api_key == ""
class TestRunOnboardExitBehavior:
def test_main_menu_interrupt_can_discard_unsaved_session_changes(self, monkeypatch):
initial_config = Config()
responses = iter(
[
"[A] Agent Settings",
KeyboardInterrupt(),
"[X] Exit Without Saving",
]
)
class FakePrompt:
def __init__(self, response):
self.response = response
def ask(self):
if isinstance(self.response, BaseException):
raise self.response
return self.response
def fake_select(*_args, **_kwargs):
return FakePrompt(next(responses))
def fake_configure_general_settings(config, section):
if section == "Agent Settings":
config.agents.defaults.model = "test/provider-model"
monkeypatch.setattr(onboard_wizard, "_show_main_menu_header", lambda: None)
monkeypatch.setattr(onboard_wizard, "questionary", SimpleNamespace(select=fake_select))
monkeypatch.setattr(onboard_wizard, "_configure_general_settings", fake_configure_general_settings)
result = run_onboard(initial_config=initial_config)
assert result.should_save is False
assert result.config.model_dump(by_alias=True) == initial_config.model_dump(by_alias=True)
+335
View File
@@ -0,0 +1,335 @@
"""Tests for the shared agent runner and its integration contracts."""
from __future__ import annotations
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from nanobot.providers.base import LLMResponse, ToolCallRequest
def _make_loop(tmp_path):
from nanobot.agent.loop import AgentLoop
from nanobot.bus.queue import MessageBus
bus = MessageBus()
provider = MagicMock()
provider.get_default_model.return_value = "test-model"
with patch("nanobot.agent.loop.ContextBuilder"), \
patch("nanobot.agent.loop.SessionManager"), \
patch("nanobot.agent.loop.SubagentManager") as MockSubMgr:
MockSubMgr.return_value.cancel_by_session = AsyncMock(return_value=0)
loop = AgentLoop(bus=bus, provider=provider, workspace=tmp_path)
return loop
@pytest.mark.asyncio
async def test_runner_preserves_reasoning_fields_and_tool_results():
from nanobot.agent.runner import AgentRunSpec, AgentRunner
provider = MagicMock()
captured_second_call: list[dict] = []
call_count = {"n": 0}
async def chat_with_retry(*, messages, **kwargs):
call_count["n"] += 1
if call_count["n"] == 1:
return LLMResponse(
content="thinking",
tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={"path": "."})],
reasoning_content="hidden reasoning",
thinking_blocks=[{"type": "thinking", "thinking": "step"}],
usage={"prompt_tokens": 5, "completion_tokens": 3},
)
captured_second_call[:] = messages
return LLMResponse(content="done", tool_calls=[], usage={})
provider.chat_with_retry = chat_with_retry
tools = MagicMock()
tools.get_definitions.return_value = []
tools.execute = AsyncMock(return_value="tool result")
runner = AgentRunner(provider)
result = await runner.run(AgentRunSpec(
initial_messages=[
{"role": "system", "content": "system"},
{"role": "user", "content": "do task"},
],
tools=tools,
model="test-model",
max_iterations=3,
))
assert result.final_content == "done"
assert result.tools_used == ["list_dir"]
assert result.tool_events == [
{"name": "list_dir", "status": "ok", "detail": "tool result"}
]
assistant_messages = [
msg for msg in captured_second_call
if msg.get("role") == "assistant" and msg.get("tool_calls")
]
assert len(assistant_messages) == 1
assert assistant_messages[0]["reasoning_content"] == "hidden reasoning"
assert assistant_messages[0]["thinking_blocks"] == [{"type": "thinking", "thinking": "step"}]
assert any(
msg.get("role") == "tool" and msg.get("content") == "tool result"
for msg in captured_second_call
)
@pytest.mark.asyncio
async def test_runner_calls_hooks_in_order():
from nanobot.agent.hook import AgentHook, AgentHookContext
from nanobot.agent.runner import AgentRunSpec, AgentRunner
provider = MagicMock()
call_count = {"n": 0}
events: list[tuple] = []
async def chat_with_retry(**kwargs):
call_count["n"] += 1
if call_count["n"] == 1:
return LLMResponse(
content="thinking",
tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={"path": "."})],
)
return LLMResponse(content="done", tool_calls=[], usage={})
provider.chat_with_retry = chat_with_retry
tools = MagicMock()
tools.get_definitions.return_value = []
tools.execute = AsyncMock(return_value="tool result")
class RecordingHook(AgentHook):
async def before_iteration(self, context: AgentHookContext) -> None:
events.append(("before_iteration", context.iteration))
async def before_execute_tools(self, context: AgentHookContext) -> None:
events.append((
"before_execute_tools",
context.iteration,
[tc.name for tc in context.tool_calls],
))
async def after_iteration(self, context: AgentHookContext) -> None:
events.append((
"after_iteration",
context.iteration,
context.final_content,
list(context.tool_results),
list(context.tool_events),
context.stop_reason,
))
def finalize_content(self, context: AgentHookContext, content: str | None) -> str | None:
events.append(("finalize_content", context.iteration, content))
return content.upper() if content else content
runner = AgentRunner(provider)
result = await runner.run(AgentRunSpec(
initial_messages=[],
tools=tools,
model="test-model",
max_iterations=3,
hook=RecordingHook(),
))
assert result.final_content == "DONE"
assert events == [
("before_iteration", 0),
("before_execute_tools", 0, ["list_dir"]),
(
"after_iteration",
0,
None,
["tool result"],
[{"name": "list_dir", "status": "ok", "detail": "tool result"}],
None,
),
("before_iteration", 1),
("finalize_content", 1, "done"),
("after_iteration", 1, "DONE", [], [], "completed"),
]
@pytest.mark.asyncio
async def test_runner_streaming_hook_receives_deltas_and_end_signal():
from nanobot.agent.hook import AgentHook, AgentHookContext
from nanobot.agent.runner import AgentRunSpec, AgentRunner
provider = MagicMock()
streamed: list[str] = []
endings: list[bool] = []
async def chat_stream_with_retry(*, on_content_delta, **kwargs):
await on_content_delta("he")
await on_content_delta("llo")
return LLMResponse(content="hello", tool_calls=[], usage={})
provider.chat_stream_with_retry = chat_stream_with_retry
provider.chat_with_retry = AsyncMock()
tools = MagicMock()
tools.get_definitions.return_value = []
class StreamingHook(AgentHook):
def wants_streaming(self) -> bool:
return True
async def on_stream(self, context: AgentHookContext, delta: str) -> None:
streamed.append(delta)
async def on_stream_end(self, context: AgentHookContext, *, resuming: bool) -> None:
endings.append(resuming)
runner = AgentRunner(provider)
result = await runner.run(AgentRunSpec(
initial_messages=[],
tools=tools,
model="test-model",
max_iterations=1,
hook=StreamingHook(),
))
assert result.final_content == "hello"
assert streamed == ["he", "llo"]
assert endings == [False]
provider.chat_with_retry.assert_not_awaited()
@pytest.mark.asyncio
async def test_runner_returns_max_iterations_fallback():
from nanobot.agent.runner import AgentRunSpec, AgentRunner
provider = MagicMock()
provider.chat_with_retry = AsyncMock(return_value=LLMResponse(
content="still working",
tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={"path": "."})],
))
tools = MagicMock()
tools.get_definitions.return_value = []
tools.execute = AsyncMock(return_value="tool result")
runner = AgentRunner(provider)
result = await runner.run(AgentRunSpec(
initial_messages=[],
tools=tools,
model="test-model",
max_iterations=2,
))
assert result.stop_reason == "max_iterations"
assert result.final_content == (
"I reached the maximum number of tool call iterations (2) "
"without completing the task. You can try breaking the task into smaller steps."
)
@pytest.mark.asyncio
async def test_runner_returns_structured_tool_error():
from nanobot.agent.runner import AgentRunSpec, AgentRunner
provider = MagicMock()
provider.chat_with_retry = AsyncMock(return_value=LLMResponse(
content="working",
tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={})],
))
tools = MagicMock()
tools.get_definitions.return_value = []
tools.execute = AsyncMock(side_effect=RuntimeError("boom"))
runner = AgentRunner(provider)
result = await runner.run(AgentRunSpec(
initial_messages=[],
tools=tools,
model="test-model",
max_iterations=2,
fail_on_tool_error=True,
))
assert result.stop_reason == "tool_error"
assert result.error == "Error: RuntimeError: boom"
assert result.tool_events == [
{"name": "list_dir", "status": "error", "detail": "boom"}
]
@pytest.mark.asyncio
async def test_loop_max_iterations_message_stays_stable(tmp_path):
loop = _make_loop(tmp_path)
loop.provider.chat_with_retry = AsyncMock(return_value=LLMResponse(
content="working",
tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={})],
))
loop.tools.get_definitions = MagicMock(return_value=[])
loop.tools.execute = AsyncMock(return_value="ok")
loop.max_iterations = 2
final_content, _, _ = await loop._run_agent_loop([])
assert final_content == (
"I reached the maximum number of tool call iterations (2) "
"without completing the task. You can try breaking the task into smaller steps."
)
@pytest.mark.asyncio
async def test_loop_stream_filter_handles_think_only_prefix_without_crashing(tmp_path):
loop = _make_loop(tmp_path)
deltas: list[str] = []
endings: list[bool] = []
async def chat_stream_with_retry(*, on_content_delta, **kwargs):
await on_content_delta("<think>hidden")
await on_content_delta("</think>Hello")
return LLMResponse(content="<think>hidden</think>Hello", tool_calls=[], usage={})
loop.provider.chat_stream_with_retry = chat_stream_with_retry
async def on_stream(delta: str) -> None:
deltas.append(delta)
async def on_stream_end(*, resuming: bool = False) -> None:
endings.append(resuming)
final_content, _, _ = await loop._run_agent_loop(
[],
on_stream=on_stream,
on_stream_end=on_stream_end,
)
assert final_content == "Hello"
assert deltas == ["Hello"]
assert endings == [False]
@pytest.mark.asyncio
async def test_subagent_max_iterations_announces_existing_fallback(tmp_path, monkeypatch):
from nanobot.agent.subagent import SubagentManager
from nanobot.bus.queue import MessageBus
bus = MessageBus()
provider = MagicMock()
provider.get_default_model.return_value = "test-model"
provider.chat_with_retry = AsyncMock(return_value=LLMResponse(
content="working",
tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={})],
))
mgr = SubagentManager(provider=provider, workspace=tmp_path, bus=bus)
mgr._announce_result = AsyncMock()
async def fake_execute(self, name, arguments):
return "tool result"
monkeypatch.setattr("nanobot.agent.tools.registry.ToolRegistry.execute", fake_execute)
await mgr._run_subagent("sub-1", "do task", "label", {"channel": "test", "chat_id": "c1"})
mgr._announce_result.assert_awaited_once()
args = mgr._announce_result.await_args.args
assert args[3] == "Task completed but no final response was generated."
assert args[5] == "ok"
@@ -0,0 +1,198 @@
from nanobot.session.manager import Session
def _assert_no_orphans(history: list[dict]) -> None:
"""Assert every tool result in history has a matching assistant tool_call."""
declared = {
tc["id"]
for m in history if m.get("role") == "assistant"
for tc in (m.get("tool_calls") or [])
}
orphans = [
m.get("tool_call_id") for m in history
if m.get("role") == "tool" and m.get("tool_call_id") not in declared
]
assert orphans == [], f"orphan tool_call_ids: {orphans}"
def _tool_turn(prefix: str, idx: int) -> list[dict]:
"""Helper: one assistant with 2 tool_calls + 2 tool results."""
return [
{
"role": "assistant",
"content": None,
"tool_calls": [
{"id": f"{prefix}_{idx}_a", "type": "function", "function": {"name": "x", "arguments": "{}"}},
{"id": f"{prefix}_{idx}_b", "type": "function", "function": {"name": "y", "arguments": "{}"}},
],
},
{"role": "tool", "tool_call_id": f"{prefix}_{idx}_a", "name": "x", "content": "ok"},
{"role": "tool", "tool_call_id": f"{prefix}_{idx}_b", "name": "y", "content": "ok"},
]
# --- Original regression test (from PR 2075) ---
def test_get_history_drops_orphan_tool_results_when_window_cuts_tool_calls():
session = Session(key="telegram:test")
session.messages.append({"role": "user", "content": "old turn"})
for i in range(20):
session.messages.extend(_tool_turn("old", i))
session.messages.append({"role": "user", "content": "problem turn"})
for i in range(25):
session.messages.extend(_tool_turn("cur", i))
session.messages.append({"role": "user", "content": "new telegram question"})
history = session.get_history(max_messages=100)
_assert_no_orphans(history)
# --- Positive test: legitimate pairs survive trimming ---
def test_legitimate_tool_pairs_preserved_after_trim():
"""Complete tool-call groups within the window must not be dropped."""
session = Session(key="test:positive")
session.messages.append({"role": "user", "content": "hello"})
for i in range(5):
session.messages.extend(_tool_turn("ok", i))
session.messages.append({"role": "assistant", "content": "done"})
history = session.get_history(max_messages=500)
_assert_no_orphans(history)
tool_ids = [m["tool_call_id"] for m in history if m.get("role") == "tool"]
assert len(tool_ids) == 10
assert history[0]["role"] == "user"
def test_retain_recent_legal_suffix_keeps_recent_messages():
session = Session(key="test:trim")
for i in range(10):
session.messages.append({"role": "user", "content": f"msg{i}"})
session.retain_recent_legal_suffix(4)
assert len(session.messages) == 4
assert session.messages[0]["content"] == "msg6"
assert session.messages[-1]["content"] == "msg9"
def test_retain_recent_legal_suffix_adjusts_last_consolidated():
session = Session(key="test:trim-cons")
for i in range(10):
session.messages.append({"role": "user", "content": f"msg{i}"})
session.last_consolidated = 7
session.retain_recent_legal_suffix(4)
assert len(session.messages) == 4
assert session.last_consolidated == 1
def test_retain_recent_legal_suffix_zero_clears_session():
session = Session(key="test:trim-zero")
for i in range(10):
session.messages.append({"role": "user", "content": f"msg{i}"})
session.last_consolidated = 5
session.retain_recent_legal_suffix(0)
assert session.messages == []
assert session.last_consolidated == 0
def test_retain_recent_legal_suffix_keeps_legal_tool_boundary():
session = Session(key="test:trim-tools")
session.messages.append({"role": "user", "content": "old"})
session.messages.extend(_tool_turn("old", 0))
session.messages.append({"role": "user", "content": "keep"})
session.messages.extend(_tool_turn("keep", 0))
session.messages.append({"role": "assistant", "content": "done"})
session.retain_recent_legal_suffix(4)
history = session.get_history(max_messages=500)
_assert_no_orphans(history)
assert history[0]["role"] == "user"
assert history[0]["content"] == "keep"
# --- last_consolidated > 0 ---
def test_orphan_trim_with_last_consolidated():
"""Orphan trimming works correctly when session is partially consolidated."""
session = Session(key="test:consolidated")
for i in range(10):
session.messages.append({"role": "user", "content": f"old {i}"})
session.messages.extend(_tool_turn("cons", i))
session.last_consolidated = 30
session.messages.append({"role": "user", "content": "recent"})
for i in range(15):
session.messages.extend(_tool_turn("new", i))
session.messages.append({"role": "user", "content": "latest"})
history = session.get_history(max_messages=20)
_assert_no_orphans(history)
assert all(m.get("role") != "tool" or m["tool_call_id"].startswith("new_") for m in history)
# --- Edge: no tool messages at all ---
def test_no_tool_messages_unchanged():
session = Session(key="test:plain")
for i in range(5):
session.messages.append({"role": "user", "content": f"q{i}"})
session.messages.append({"role": "assistant", "content": f"a{i}"})
history = session.get_history(max_messages=6)
assert len(history) == 6
_assert_no_orphans(history)
# --- Edge: all leading messages are orphan tool results ---
def test_all_orphan_prefix_stripped():
"""If the window starts with orphan tool results and nothing else, they're all dropped."""
session = Session(key="test:all-orphan")
session.messages.append({"role": "tool", "tool_call_id": "gone_1", "name": "x", "content": "ok"})
session.messages.append({"role": "tool", "tool_call_id": "gone_2", "name": "y", "content": "ok"})
session.messages.append({"role": "user", "content": "fresh start"})
session.messages.append({"role": "assistant", "content": "hi"})
history = session.get_history(max_messages=500)
_assert_no_orphans(history)
assert history[0]["role"] == "user"
assert len(history) == 2
# --- Edge: empty session ---
def test_empty_session_history():
session = Session(key="test:empty")
history = session.get_history(max_messages=500)
assert history == []
# --- Window cuts mid-group: assistant present but some tool results orphaned ---
def test_window_cuts_mid_tool_group():
"""If the window starts between an assistant's tool results, the partial group is trimmed."""
session = Session(key="test:mid-cut")
session.messages.append({"role": "user", "content": "setup"})
session.messages.append({
"role": "assistant", "content": None,
"tool_calls": [
{"id": "split_a", "type": "function", "function": {"name": "x", "arguments": "{}"}},
{"id": "split_b", "type": "function", "function": {"name": "y", "arguments": "{}"}},
],
})
session.messages.append({"role": "tool", "tool_call_id": "split_a", "name": "x", "content": "ok"})
session.messages.append({"role": "tool", "tool_call_id": "split_b", "name": "y", "content": "ok"})
session.messages.append({"role": "user", "content": "next"})
session.messages.extend(_tool_turn("intact", 0))
session.messages.append({"role": "assistant", "content": "final"})
# Window of 6 should cut off the "setup" user msg and the assistant with split_a/split_b,
# leaving orphan tool results for split_a at the front.
history = session.get_history(max_messages=6)
_assert_no_orphans(history)
@@ -0,0 +1,127 @@
import importlib
import shutil
import sys
import zipfile
from pathlib import Path
SCRIPT_DIR = Path("nanobot/skills/skill-creator/scripts").resolve()
if str(SCRIPT_DIR) not in sys.path:
sys.path.insert(0, str(SCRIPT_DIR))
init_skill = importlib.import_module("init_skill")
package_skill = importlib.import_module("package_skill")
quick_validate = importlib.import_module("quick_validate")
def test_init_skill_creates_expected_files(tmp_path: Path) -> None:
skill_dir = init_skill.init_skill(
"demo-skill",
tmp_path,
["scripts", "references", "assets"],
include_examples=True,
)
assert skill_dir == tmp_path / "demo-skill"
assert (skill_dir / "SKILL.md").exists()
assert (skill_dir / "scripts" / "example.py").exists()
assert (skill_dir / "references" / "api_reference.md").exists()
assert (skill_dir / "assets" / "example_asset.txt").exists()
def test_validate_skill_accepts_existing_skill_creator() -> None:
valid, message = quick_validate.validate_skill(
Path("nanobot/skills/skill-creator").resolve()
)
assert valid, message
def test_validate_skill_rejects_placeholder_description(tmp_path: Path) -> None:
skill_dir = tmp_path / "placeholder-skill"
skill_dir.mkdir()
(skill_dir / "SKILL.md").write_text(
"---\n"
"name: placeholder-skill\n"
'description: "[TODO: fill me in]"\n'
"---\n"
"# Placeholder\n",
encoding="utf-8",
)
valid, message = quick_validate.validate_skill(skill_dir)
assert not valid
assert "TODO placeholder" in message
def test_validate_skill_rejects_root_files_outside_allowed_dirs(tmp_path: Path) -> None:
skill_dir = tmp_path / "bad-root-skill"
skill_dir.mkdir()
(skill_dir / "SKILL.md").write_text(
"---\n"
"name: bad-root-skill\n"
"description: Valid description\n"
"---\n"
"# Skill\n",
encoding="utf-8",
)
(skill_dir / "README.md").write_text("extra\n", encoding="utf-8")
valid, message = quick_validate.validate_skill(skill_dir)
assert not valid
assert "Unexpected file or directory in skill root" in message
def test_package_skill_creates_archive(tmp_path: Path) -> None:
skill_dir = tmp_path / "package-me"
skill_dir.mkdir()
(skill_dir / "SKILL.md").write_text(
"---\n"
"name: package-me\n"
"description: Package this skill.\n"
"---\n"
"# Skill\n",
encoding="utf-8",
)
scripts_dir = skill_dir / "scripts"
scripts_dir.mkdir()
(scripts_dir / "helper.py").write_text("print('ok')\n", encoding="utf-8")
archive_path = package_skill.package_skill(skill_dir, tmp_path / "dist")
assert archive_path == (tmp_path / "dist" / "package-me.skill")
assert archive_path.exists()
with zipfile.ZipFile(archive_path, "r") as archive:
names = set(archive.namelist())
assert "package-me/SKILL.md" in names
assert "package-me/scripts/helper.py" in names
def test_package_skill_rejects_symlink(tmp_path: Path) -> None:
skill_dir = tmp_path / "symlink-skill"
skill_dir.mkdir()
(skill_dir / "SKILL.md").write_text(
"---\n"
"name: symlink-skill\n"
"description: Reject symlinks during packaging.\n"
"---\n"
"# Skill\n",
encoding="utf-8",
)
scripts_dir = skill_dir / "scripts"
scripts_dir.mkdir()
target = tmp_path / "outside.txt"
target.write_text("secret\n", encoding="utf-8")
link = scripts_dir / "outside.txt"
try:
link.symlink_to(target)
except (OSError, NotImplementedError):
return
archive_path = package_skill.package_skill(skill_dir, tmp_path / "dist")
assert archive_path is None
assert not (tmp_path / "dist" / "symlink-skill.skill").exists()
+303
View File
@@ -0,0 +1,303 @@
"""Tests for /stop task cancellation."""
from __future__ import annotations
import asyncio
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
def _make_loop(*, exec_config=None):
"""Create a minimal AgentLoop with mocked dependencies."""
from nanobot.agent.loop import AgentLoop
from nanobot.bus.queue import MessageBus
bus = MessageBus()
provider = MagicMock()
provider.get_default_model.return_value = "test-model"
workspace = MagicMock()
workspace.__truediv__ = MagicMock(return_value=MagicMock())
with patch("nanobot.agent.loop.ContextBuilder"), \
patch("nanobot.agent.loop.SessionManager"), \
patch("nanobot.agent.loop.SubagentManager") as MockSubMgr:
MockSubMgr.return_value.cancel_by_session = AsyncMock(return_value=0)
loop = AgentLoop(bus=bus, provider=provider, workspace=workspace, exec_config=exec_config)
return loop, bus
class TestHandleStop:
@pytest.mark.asyncio
async def test_stop_no_active_task(self):
from nanobot.bus.events import InboundMessage
from nanobot.command.builtin import cmd_stop
from nanobot.command.router import CommandContext
loop, bus = _make_loop()
msg = InboundMessage(channel="test", sender_id="u1", chat_id="c1", content="/stop")
ctx = CommandContext(msg=msg, session=None, key=msg.session_key, raw="/stop", loop=loop)
out = await cmd_stop(ctx)
assert "No active task" in out.content
@pytest.mark.asyncio
async def test_stop_cancels_active_task(self):
from nanobot.bus.events import InboundMessage
from nanobot.command.builtin import cmd_stop
from nanobot.command.router import CommandContext
loop, bus = _make_loop()
cancelled = asyncio.Event()
async def slow_task():
try:
await asyncio.sleep(60)
except asyncio.CancelledError:
cancelled.set()
raise
task = asyncio.create_task(slow_task())
await asyncio.sleep(0)
loop._active_tasks["test:c1"] = [task]
msg = InboundMessage(channel="test", sender_id="u1", chat_id="c1", content="/stop")
ctx = CommandContext(msg=msg, session=None, key=msg.session_key, raw="/stop", loop=loop)
out = await cmd_stop(ctx)
assert cancelled.is_set()
assert "stopped" in out.content.lower()
@pytest.mark.asyncio
async def test_stop_cancels_multiple_tasks(self):
from nanobot.bus.events import InboundMessage
from nanobot.command.builtin import cmd_stop
from nanobot.command.router import CommandContext
loop, bus = _make_loop()
events = [asyncio.Event(), asyncio.Event()]
async def slow(idx):
try:
await asyncio.sleep(60)
except asyncio.CancelledError:
events[idx].set()
raise
tasks = [asyncio.create_task(slow(i)) for i in range(2)]
await asyncio.sleep(0)
loop._active_tasks["test:c1"] = tasks
msg = InboundMessage(channel="test", sender_id="u1", chat_id="c1", content="/stop")
ctx = CommandContext(msg=msg, session=None, key=msg.session_key, raw="/stop", loop=loop)
out = await cmd_stop(ctx)
assert all(e.is_set() for e in events)
assert "2 task" in out.content
class TestDispatch:
def test_exec_tool_not_registered_when_disabled(self):
from nanobot.config.schema import ExecToolConfig
loop, _bus = _make_loop(exec_config=ExecToolConfig(enable=False))
assert loop.tools.get("exec") is None
@pytest.mark.asyncio
async def test_dispatch_processes_and_publishes(self):
from nanobot.bus.events import InboundMessage, OutboundMessage
loop, bus = _make_loop()
msg = InboundMessage(channel="test", sender_id="u1", chat_id="c1", content="hello")
loop._process_message = AsyncMock(
return_value=OutboundMessage(channel="test", chat_id="c1", content="hi")
)
await loop._dispatch(msg)
out = await asyncio.wait_for(bus.consume_outbound(), timeout=1.0)
assert out.content == "hi"
@pytest.mark.asyncio
async def test_processing_lock_serializes(self):
from nanobot.bus.events import InboundMessage, OutboundMessage
loop, bus = _make_loop()
order = []
async def mock_process(m, **kwargs):
order.append(f"start-{m.content}")
await asyncio.sleep(0.05)
order.append(f"end-{m.content}")
return OutboundMessage(channel="test", chat_id="c1", content=m.content)
loop._process_message = mock_process
msg1 = InboundMessage(channel="test", sender_id="u1", chat_id="c1", content="a")
msg2 = InboundMessage(channel="test", sender_id="u1", chat_id="c1", content="b")
t1 = asyncio.create_task(loop._dispatch(msg1))
t2 = asyncio.create_task(loop._dispatch(msg2))
await asyncio.gather(t1, t2)
assert order == ["start-a", "end-a", "start-b", "end-b"]
class TestSubagentCancellation:
@pytest.mark.asyncio
async def test_cancel_by_session(self):
from nanobot.agent.subagent import SubagentManager
from nanobot.bus.queue import MessageBus
bus = MessageBus()
provider = MagicMock()
provider.get_default_model.return_value = "test-model"
mgr = SubagentManager(provider=provider, workspace=MagicMock(), bus=bus)
cancelled = asyncio.Event()
async def slow():
try:
await asyncio.sleep(60)
except asyncio.CancelledError:
cancelled.set()
raise
task = asyncio.create_task(slow())
await asyncio.sleep(0)
mgr._running_tasks["sub-1"] = task
mgr._session_tasks["test:c1"] = {"sub-1"}
count = await mgr.cancel_by_session("test:c1")
assert count == 1
assert cancelled.is_set()
@pytest.mark.asyncio
async def test_cancel_by_session_no_tasks(self):
from nanobot.agent.subagent import SubagentManager
from nanobot.bus.queue import MessageBus
bus = MessageBus()
provider = MagicMock()
provider.get_default_model.return_value = "test-model"
mgr = SubagentManager(provider=provider, workspace=MagicMock(), bus=bus)
assert await mgr.cancel_by_session("nonexistent") == 0
@pytest.mark.asyncio
async def test_subagent_preserves_reasoning_fields_in_tool_turn(self, monkeypatch, tmp_path):
from nanobot.agent.subagent import SubagentManager
from nanobot.bus.queue import MessageBus
from nanobot.providers.base import LLMResponse, ToolCallRequest
bus = MessageBus()
provider = MagicMock()
provider.get_default_model.return_value = "test-model"
captured_second_call: list[dict] = []
call_count = {"n": 0}
async def scripted_chat_with_retry(*, messages, **kwargs):
call_count["n"] += 1
if call_count["n"] == 1:
return LLMResponse(
content="thinking",
tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={})],
reasoning_content="hidden reasoning",
thinking_blocks=[{"type": "thinking", "thinking": "step"}],
)
captured_second_call[:] = messages
return LLMResponse(content="done", tool_calls=[])
provider.chat_with_retry = scripted_chat_with_retry
mgr = SubagentManager(provider=provider, workspace=tmp_path, bus=bus)
async def fake_execute(self, name, arguments):
return "tool result"
monkeypatch.setattr("nanobot.agent.tools.registry.ToolRegistry.execute", fake_execute)
await mgr._run_subagent("sub-1", "do task", "label", {"channel": "test", "chat_id": "c1"})
assistant_messages = [
msg for msg in captured_second_call
if msg.get("role") == "assistant" and msg.get("tool_calls")
]
assert len(assistant_messages) == 1
assert assistant_messages[0]["reasoning_content"] == "hidden reasoning"
assert assistant_messages[0]["thinking_blocks"] == [{"type": "thinking", "thinking": "step"}]
@pytest.mark.asyncio
async def test_subagent_announces_error_when_tool_execution_fails(self, monkeypatch, tmp_path):
from nanobot.agent.subagent import SubagentManager
from nanobot.bus.queue import MessageBus
from nanobot.providers.base import LLMResponse, ToolCallRequest
bus = MessageBus()
provider = MagicMock()
provider.get_default_model.return_value = "test-model"
provider.chat_with_retry = AsyncMock(return_value=LLMResponse(
content="thinking",
tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={})],
))
mgr = SubagentManager(provider=provider, workspace=tmp_path, bus=bus)
mgr._announce_result = AsyncMock()
calls = {"n": 0}
async def fake_execute(self, name, arguments):
calls["n"] += 1
if calls["n"] == 1:
return "first result"
raise RuntimeError("boom")
monkeypatch.setattr("nanobot.agent.tools.registry.ToolRegistry.execute", fake_execute)
await mgr._run_subagent("sub-1", "do task", "label", {"channel": "test", "chat_id": "c1"})
mgr._announce_result.assert_awaited_once()
args = mgr._announce_result.await_args.args
assert "Completed steps:" in args[3]
assert "- list_dir: first result" in args[3]
assert "Failure:" in args[3]
assert "- list_dir: boom" in args[3]
assert args[5] == "error"
@pytest.mark.asyncio
async def test_cancel_by_session_cancels_running_subagent_tool(self, monkeypatch, tmp_path):
from nanobot.agent.subagent import SubagentManager
from nanobot.bus.queue import MessageBus
from nanobot.providers.base import LLMResponse, ToolCallRequest
bus = MessageBus()
provider = MagicMock()
provider.get_default_model.return_value = "test-model"
provider.chat_with_retry = AsyncMock(return_value=LLMResponse(
content="thinking",
tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={})],
))
mgr = SubagentManager(provider=provider, workspace=tmp_path, bus=bus)
mgr._announce_result = AsyncMock()
started = asyncio.Event()
cancelled = asyncio.Event()
async def fake_execute(self, name, arguments):
started.set()
try:
await asyncio.sleep(60)
except asyncio.CancelledError:
cancelled.set()
raise
monkeypatch.setattr("nanobot.agent.tools.registry.ToolRegistry.execute", fake_execute)
task = asyncio.create_task(
mgr._run_subagent("sub-1", "do task", "label", {"channel": "test", "chat_id": "c1"})
)
mgr._running_tasks["sub-1"] = task
mgr._session_tasks["test:c1"] = {"sub-1"}
await started.wait()
count = await mgr.cancel_by_session("test:c1")
assert count == 1
assert cancelled.is_set()
assert task.cancelled()
mgr._announce_result.assert_not_awaited()
@@ -0,0 +1,25 @@
from types import SimpleNamespace
from nanobot.bus.events import OutboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.channels.base import BaseChannel
class _DummyChannel(BaseChannel):
name = "dummy"
async def start(self) -> None:
return None
async def stop(self) -> None:
return None
async def send(self, msg: OutboundMessage) -> None:
return None
def test_is_allowed_requires_exact_match() -> None:
channel = _DummyChannel(SimpleNamespace(allow_from=["allow@email.com"]), MessageBus())
assert channel.is_allowed("allow@email.com") is True
assert channel.is_allowed("attacker|allow@email.com") is False
@@ -0,0 +1,298 @@
"""Tests for ChannelManager delta coalescing to reduce streaming latency."""
import asyncio
from unittest.mock import AsyncMock, MagicMock
import pytest
from nanobot.bus.events import OutboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.channels.base import BaseChannel
from nanobot.channels.manager import ChannelManager
from nanobot.config.schema import Config
class MockChannel(BaseChannel):
"""Mock channel for testing."""
name = "mock"
display_name = "Mock"
def __init__(self, config, bus):
super().__init__(config, bus)
self._send_delta_mock = AsyncMock()
self._send_mock = AsyncMock()
async def start(self):
pass
async def stop(self):
pass
async def send(self, msg):
"""Implement abstract method."""
return await self._send_mock(msg)
async def send_delta(self, chat_id, delta, metadata=None):
"""Override send_delta for testing."""
return await self._send_delta_mock(chat_id, delta, metadata)
@pytest.fixture
def config():
"""Create a minimal config for testing."""
return Config()
@pytest.fixture
def bus():
"""Create a message bus for testing."""
return MessageBus()
@pytest.fixture
def manager(config, bus):
"""Create a channel manager with a mock channel."""
manager = ChannelManager(config, bus)
manager.channels["mock"] = MockChannel({}, bus)
return manager
class TestDeltaCoalescing:
"""Tests for _stream_delta message coalescing."""
@pytest.mark.asyncio
async def test_single_delta_not_coalesced(self, manager, bus):
"""A single delta should be sent as-is."""
msg = OutboundMessage(
channel="mock",
chat_id="chat1",
content="Hello",
metadata={"_stream_delta": True},
)
await bus.publish_outbound(msg)
# Process one message
async def process_one():
try:
m = await asyncio.wait_for(bus.consume_outbound(), timeout=0.1)
if m.metadata.get("_stream_delta"):
m, pending = manager._coalesce_stream_deltas(m)
# Put pending back (none expected)
for p in pending:
await bus.publish_outbound(p)
channel = manager.channels.get(m.channel)
if channel:
await channel.send_delta(m.chat_id, m.content, m.metadata)
except asyncio.TimeoutError:
pass
await process_one()
manager.channels["mock"]._send_delta_mock.assert_called_once_with(
"chat1", "Hello", {"_stream_delta": True}
)
@pytest.mark.asyncio
async def test_multiple_deltas_coalesced(self, manager, bus):
"""Multiple consecutive deltas for same chat should be merged."""
# Put multiple deltas in queue
for text in ["Hello", " ", "world", "!"]:
await bus.publish_outbound(OutboundMessage(
channel="mock",
chat_id="chat1",
content=text,
metadata={"_stream_delta": True},
))
# Process using coalescing logic
first_msg = await bus.consume_outbound()
merged, pending = manager._coalesce_stream_deltas(first_msg)
# Should have merged all deltas
assert merged.content == "Hello world!"
assert merged.metadata.get("_stream_delta") is True
# No pending messages (all were coalesced)
assert len(pending) == 0
@pytest.mark.asyncio
async def test_deltas_different_chats_not_coalesced(self, manager, bus):
"""Deltas for different chats should not be merged."""
# Put deltas for different chats
await bus.publish_outbound(OutboundMessage(
channel="mock",
chat_id="chat1",
content="Hello",
metadata={"_stream_delta": True},
))
await bus.publish_outbound(OutboundMessage(
channel="mock",
chat_id="chat2",
content="World",
metadata={"_stream_delta": True},
))
first_msg = await bus.consume_outbound()
merged, pending = manager._coalesce_stream_deltas(first_msg)
# First chat should not include second chat's content
assert merged.content == "Hello"
assert merged.chat_id == "chat1"
# Second chat should be in pending
assert len(pending) == 1
assert pending[0].chat_id == "chat2"
assert pending[0].content == "World"
@pytest.mark.asyncio
async def test_stream_end_terminates_coalescing(self, manager, bus):
"""_stream_end should stop coalescing and be included in final message."""
# Put deltas with stream_end at the end
await bus.publish_outbound(OutboundMessage(
channel="mock",
chat_id="chat1",
content="Hello",
metadata={"_stream_delta": True},
))
await bus.publish_outbound(OutboundMessage(
channel="mock",
chat_id="chat1",
content=" world",
metadata={"_stream_delta": True, "_stream_end": True},
))
first_msg = await bus.consume_outbound()
merged, pending = manager._coalesce_stream_deltas(first_msg)
# Should have merged content
assert merged.content == "Hello world"
# Should have stream_end flag
assert merged.metadata.get("_stream_end") is True
# No pending
assert len(pending) == 0
@pytest.mark.asyncio
async def test_coalescing_stops_at_first_non_matching_boundary(self, manager, bus):
"""Only consecutive deltas should be merged; later deltas stay queued."""
await bus.publish_outbound(OutboundMessage(
channel="mock",
chat_id="chat1",
content="Hello",
metadata={"_stream_delta": True, "_stream_id": "seg-1"},
))
await bus.publish_outbound(OutboundMessage(
channel="mock",
chat_id="chat1",
content="",
metadata={"_stream_end": True, "_stream_id": "seg-1"},
))
await bus.publish_outbound(OutboundMessage(
channel="mock",
chat_id="chat1",
content="world",
metadata={"_stream_delta": True, "_stream_id": "seg-2"},
))
first_msg = await bus.consume_outbound()
merged, pending = manager._coalesce_stream_deltas(first_msg)
assert merged.content == "Hello"
assert merged.metadata.get("_stream_end") is None
assert len(pending) == 1
assert pending[0].metadata.get("_stream_end") is True
assert pending[0].metadata.get("_stream_id") == "seg-1"
# The next stream segment must remain in queue order for later dispatch.
remaining = await bus.consume_outbound()
assert remaining.content == "world"
assert remaining.metadata.get("_stream_id") == "seg-2"
@pytest.mark.asyncio
async def test_non_delta_message_preserved(self, manager, bus):
"""Non-delta messages should be preserved in pending list."""
await bus.publish_outbound(OutboundMessage(
channel="mock",
chat_id="chat1",
content="Delta",
metadata={"_stream_delta": True},
))
await bus.publish_outbound(OutboundMessage(
channel="mock",
chat_id="chat1",
content="Final message",
metadata={}, # Not a delta
))
first_msg = await bus.consume_outbound()
merged, pending = manager._coalesce_stream_deltas(first_msg)
assert merged.content == "Delta"
assert len(pending) == 1
assert pending[0].content == "Final message"
assert pending[0].metadata.get("_stream_delta") is None
@pytest.mark.asyncio
async def test_empty_queue_stops_coalescing(self, manager, bus):
"""Coalescing should stop when queue is empty."""
await bus.publish_outbound(OutboundMessage(
channel="mock",
chat_id="chat1",
content="Only message",
metadata={"_stream_delta": True},
))
first_msg = await bus.consume_outbound()
merged, pending = manager._coalesce_stream_deltas(first_msg)
assert merged.content == "Only message"
assert len(pending) == 0
class TestDispatchOutboundWithCoalescing:
"""Tests for the full _dispatch_outbound flow with coalescing."""
@pytest.mark.asyncio
async def test_dispatch_coalesces_and_processes_pending(self, manager, bus):
"""_dispatch_outbound should coalesce deltas and process pending messages."""
# Put multiple deltas followed by a regular message
await bus.publish_outbound(OutboundMessage(
channel="mock",
chat_id="chat1",
content="A",
metadata={"_stream_delta": True},
))
await bus.publish_outbound(OutboundMessage(
channel="mock",
chat_id="chat1",
content="B",
metadata={"_stream_delta": True},
))
await bus.publish_outbound(OutboundMessage(
channel="mock",
chat_id="chat1",
content="Final",
metadata={}, # Regular message
))
# Run one iteration of dispatch logic manually
pending = []
processed = []
# First iteration: should coalesce A+B
if pending:
msg = pending.pop(0)
else:
msg = await bus.consume_outbound()
if msg.metadata.get("_stream_delta") and not msg.metadata.get("_stream_end"):
msg, extra_pending = manager._coalesce_stream_deltas(msg)
pending.extend(extra_pending)
channel = manager.channels.get(msg.channel)
if channel:
await channel.send_delta(msg.chat_id, msg.content, msg.metadata)
processed.append(("delta", msg.content))
# Should have sent coalesced delta
assert processed == [("delta", "AB")]
# Should have pending regular message
assert len(pending) == 1
assert pending[0].content == "Final"
@@ -0,0 +1,880 @@
"""Tests for channel plugin discovery, merging, and config compatibility."""
from __future__ import annotations
import asyncio
from types import SimpleNamespace
from unittest.mock import AsyncMock, patch
import pytest
from nanobot.bus.events import OutboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.channels.base import BaseChannel
from nanobot.channels.manager import ChannelManager
from nanobot.config.schema import ChannelsConfig
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
class _FakePlugin(BaseChannel):
name = "fakeplugin"
display_name = "Fake Plugin"
def __init__(self, config, bus):
super().__init__(config, bus)
self.login_calls: list[bool] = []
async def start(self) -> None:
pass
async def stop(self) -> None:
pass
async def send(self, msg: OutboundMessage) -> None:
pass
async def login(self, force: bool = False) -> bool:
self.login_calls.append(force)
return True
class _FakeTelegram(BaseChannel):
"""Plugin that tries to shadow built-in telegram."""
name = "telegram"
display_name = "Fake Telegram"
async def start(self) -> None:
pass
async def stop(self) -> None:
pass
async def send(self, msg: OutboundMessage) -> None:
pass
def _make_entry_point(name: str, cls: type):
"""Create a mock entry point that returns *cls* on load()."""
ep = SimpleNamespace(name=name, load=lambda _cls=cls: _cls)
return ep
# ---------------------------------------------------------------------------
# ChannelsConfig extra="allow"
# ---------------------------------------------------------------------------
def test_channels_config_accepts_unknown_keys():
cfg = ChannelsConfig.model_validate({
"myplugin": {"enabled": True, "token": "abc"},
})
extra = cfg.model_extra
assert extra is not None
assert extra["myplugin"]["enabled"] is True
assert extra["myplugin"]["token"] == "abc"
def test_channels_config_getattr_returns_extra():
cfg = ChannelsConfig.model_validate({"myplugin": {"enabled": True}})
section = getattr(cfg, "myplugin", None)
assert isinstance(section, dict)
assert section["enabled"] is True
def test_channels_config_builtin_fields_removed():
"""After decoupling, ChannelsConfig has no explicit channel fields."""
cfg = ChannelsConfig()
assert not hasattr(cfg, "telegram")
assert cfg.send_progress is True
assert cfg.send_tool_hints is False
# ---------------------------------------------------------------------------
# discover_plugins
# ---------------------------------------------------------------------------
_EP_TARGET = "importlib.metadata.entry_points"
def test_discover_plugins_loads_entry_points():
from nanobot.channels.registry import discover_plugins
ep = _make_entry_point("line", _FakePlugin)
with patch(_EP_TARGET, return_value=[ep]):
result = discover_plugins()
assert "line" in result
assert result["line"] is _FakePlugin
def test_discover_plugins_handles_load_error():
from nanobot.channels.registry import discover_plugins
def _boom():
raise RuntimeError("broken")
ep = SimpleNamespace(name="broken", load=_boom)
with patch(_EP_TARGET, return_value=[ep]):
result = discover_plugins()
assert "broken" not in result
# ---------------------------------------------------------------------------
# discover_all — merge & priority
# ---------------------------------------------------------------------------
def test_discover_all_includes_builtins():
from nanobot.channels.registry import discover_all, discover_channel_names
with patch(_EP_TARGET, return_value=[]):
result = discover_all()
# discover_all() only returns channels that are actually available (dependencies installed)
# discover_channel_names() returns all built-in channel names
# So we check that all actually loaded channels are in the result
for name in result:
assert name in discover_channel_names()
def test_discover_all_includes_external_plugin():
from nanobot.channels.registry import discover_all
ep = _make_entry_point("line", _FakePlugin)
with patch(_EP_TARGET, return_value=[ep]):
result = discover_all()
assert "line" in result
assert result["line"] is _FakePlugin
def test_discover_all_builtin_shadows_plugin():
from nanobot.channels.registry import discover_all
ep = _make_entry_point("telegram", _FakeTelegram)
with patch(_EP_TARGET, return_value=[ep]):
result = discover_all()
assert "telegram" in result
assert result["telegram"] is not _FakeTelegram
# ---------------------------------------------------------------------------
# Manager _init_channels with dict config (plugin scenario)
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_manager_loads_plugin_from_dict_config():
"""ChannelManager should instantiate a plugin channel from a raw dict config."""
from nanobot.channels.manager import ChannelManager
fake_config = SimpleNamespace(
channels=ChannelsConfig.model_validate({
"fakeplugin": {"enabled": True, "allowFrom": ["*"]},
}),
providers=SimpleNamespace(groq=SimpleNamespace(api_key="")),
)
with patch(
"nanobot.channels.registry.discover_all",
return_value={"fakeplugin": _FakePlugin},
):
mgr = ChannelManager.__new__(ChannelManager)
mgr.config = fake_config
mgr.bus = MessageBus()
mgr.channels = {}
mgr._dispatch_task = None
mgr._init_channels()
assert "fakeplugin" in mgr.channels
assert isinstance(mgr.channels["fakeplugin"], _FakePlugin)
def test_channels_login_uses_discovered_plugin_class(monkeypatch):
from nanobot.cli.commands import app
from nanobot.config.schema import Config
from typer.testing import CliRunner
runner = CliRunner()
seen: dict[str, object] = {}
class _LoginPlugin(_FakePlugin):
display_name = "Login Plugin"
async def login(self, force: bool = False) -> bool:
seen["force"] = force
seen["config"] = self.config
return True
monkeypatch.setattr("nanobot.config.loader.load_config", lambda: Config())
monkeypatch.setattr(
"nanobot.channels.registry.discover_all",
lambda: {"fakeplugin": _LoginPlugin},
)
result = runner.invoke(app, ["channels", "login", "fakeplugin", "--force"])
assert result.exit_code == 0
assert seen["force"] is True
@pytest.mark.asyncio
async def test_manager_skips_disabled_plugin():
fake_config = SimpleNamespace(
channels=ChannelsConfig.model_validate({
"fakeplugin": {"enabled": False},
}),
providers=SimpleNamespace(groq=SimpleNamespace(api_key="")),
)
with patch(
"nanobot.channels.registry.discover_all",
return_value={"fakeplugin": _FakePlugin},
):
mgr = ChannelManager.__new__(ChannelManager)
mgr.config = fake_config
mgr.bus = MessageBus()
mgr.channels = {}
mgr._dispatch_task = None
mgr._init_channels()
assert "fakeplugin" not in mgr.channels
# ---------------------------------------------------------------------------
# Built-in channel default_config() and dict->Pydantic conversion
# ---------------------------------------------------------------------------
def test_builtin_channel_default_config():
"""Built-in channels expose default_config() returning a dict with 'enabled': False."""
from nanobot.channels.telegram import TelegramChannel
cfg = TelegramChannel.default_config()
assert isinstance(cfg, dict)
assert cfg["enabled"] is False
assert "token" in cfg
def test_builtin_channel_init_from_dict():
"""Built-in channels accept a raw dict and convert to Pydantic internally."""
from nanobot.channels.telegram import TelegramChannel
bus = MessageBus()
ch = TelegramChannel({"enabled": False, "token": "test-tok", "allowFrom": ["*"]}, bus)
assert ch.config.token == "test-tok"
assert ch.config.allow_from == ["*"]
def test_channels_config_send_max_retries_default():
"""ChannelsConfig should have send_max_retries with default value of 3."""
cfg = ChannelsConfig()
assert hasattr(cfg, 'send_max_retries')
assert cfg.send_max_retries == 3
def test_channels_config_send_max_retries_upper_bound():
"""send_max_retries should be bounded to prevent resource exhaustion."""
from pydantic import ValidationError
# Value too high should be rejected
with pytest.raises(ValidationError):
ChannelsConfig(send_max_retries=100)
# Negative should be rejected
with pytest.raises(ValidationError):
ChannelsConfig(send_max_retries=-1)
# Boundary values should be allowed
cfg_min = ChannelsConfig(send_max_retries=0)
assert cfg_min.send_max_retries == 0
cfg_max = ChannelsConfig(send_max_retries=10)
assert cfg_max.send_max_retries == 10
# Value above upper bound should be rejected
with pytest.raises(ValidationError):
ChannelsConfig(send_max_retries=11)
# ---------------------------------------------------------------------------
# _send_with_retry
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_send_with_retry_succeeds_first_try():
"""_send_with_retry should succeed on first try and not retry."""
call_count = 0
class _FailingChannel(BaseChannel):
name = "failing"
display_name = "Failing"
async def start(self) -> None:
pass
async def stop(self) -> None:
pass
async def send(self, msg: OutboundMessage) -> None:
nonlocal call_count
call_count += 1
# Succeeds on first try
fake_config = SimpleNamespace(
channels=ChannelsConfig(send_max_retries=3),
providers=SimpleNamespace(groq=SimpleNamespace(api_key="")),
)
mgr = ChannelManager.__new__(ChannelManager)
mgr.config = fake_config
mgr.bus = MessageBus()
mgr.channels = {"failing": _FailingChannel(fake_config, mgr.bus)}
mgr._dispatch_task = None
msg = OutboundMessage(channel="failing", chat_id="123", content="test")
await mgr._send_with_retry(mgr.channels["failing"], msg)
assert call_count == 1
@pytest.mark.asyncio
async def test_send_with_retry_retries_on_failure():
"""_send_with_retry should retry on failure up to max_retries times."""
call_count = 0
class _FailingChannel(BaseChannel):
name = "failing"
display_name = "Failing"
async def start(self) -> None:
pass
async def stop(self) -> None:
pass
async def send(self, msg: OutboundMessage) -> None:
nonlocal call_count
call_count += 1
raise RuntimeError("simulated failure")
fake_config = SimpleNamespace(
channels=ChannelsConfig(send_max_retries=3),
providers=SimpleNamespace(groq=SimpleNamespace(api_key="")),
)
mgr = ChannelManager.__new__(ChannelManager)
mgr.config = fake_config
mgr.bus = MessageBus()
mgr.channels = {"failing": _FailingChannel(fake_config, mgr.bus)}
mgr._dispatch_task = None
msg = OutboundMessage(channel="failing", chat_id="123", content="test")
# Patch asyncio.sleep to avoid actual delays
with patch("nanobot.channels.manager.asyncio.sleep", new_callable=AsyncMock) as mock_sleep:
await mgr._send_with_retry(mgr.channels["failing"], msg)
assert call_count == 3 # 3 total attempts (initial + 2 retries)
assert mock_sleep.call_count == 2 # 2 sleeps between retries
@pytest.mark.asyncio
async def test_send_with_retry_no_retry_when_max_is_zero():
"""_send_with_retry should not retry when send_max_retries is 0."""
call_count = 0
class _FailingChannel(BaseChannel):
name = "failing"
display_name = "Failing"
async def start(self) -> None:
pass
async def stop(self) -> None:
pass
async def send(self, msg: OutboundMessage) -> None:
nonlocal call_count
call_count += 1
raise RuntimeError("simulated failure")
fake_config = SimpleNamespace(
channels=ChannelsConfig(send_max_retries=0),
providers=SimpleNamespace(groq=SimpleNamespace(api_key="")),
)
mgr = ChannelManager.__new__(ChannelManager)
mgr.config = fake_config
mgr.bus = MessageBus()
mgr.channels = {"failing": _FailingChannel(fake_config, mgr.bus)}
mgr._dispatch_task = None
msg = OutboundMessage(channel="failing", chat_id="123", content="test")
with patch("nanobot.channels.manager.asyncio.sleep", new_callable=AsyncMock):
await mgr._send_with_retry(mgr.channels["failing"], msg)
assert call_count == 1 # Called once but no retry (max(0, 1) = 1)
@pytest.mark.asyncio
async def test_send_with_retry_calls_send_delta():
"""_send_with_retry should call send_delta when metadata has _stream_delta."""
send_delta_called = False
class _StreamingChannel(BaseChannel):
name = "streaming"
display_name = "Streaming"
async def start(self) -> None:
pass
async def stop(self) -> None:
pass
async def send(self, msg: OutboundMessage) -> None:
pass # Should not be called
async def send_delta(self, chat_id: str, delta: str, metadata: dict | None = None) -> None:
nonlocal send_delta_called
send_delta_called = True
fake_config = SimpleNamespace(
channels=ChannelsConfig(send_max_retries=3),
providers=SimpleNamespace(groq=SimpleNamespace(api_key="")),
)
mgr = ChannelManager.__new__(ChannelManager)
mgr.config = fake_config
mgr.bus = MessageBus()
mgr.channels = {"streaming": _StreamingChannel(fake_config, mgr.bus)}
mgr._dispatch_task = None
msg = OutboundMessage(
channel="streaming", chat_id="123", content="test delta",
metadata={"_stream_delta": True}
)
await mgr._send_with_retry(mgr.channels["streaming"], msg)
assert send_delta_called is True
@pytest.mark.asyncio
async def test_send_with_retry_skips_send_when_streamed():
"""_send_with_retry should not call send when metadata has _streamed flag."""
send_called = False
send_delta_called = False
class _StreamedChannel(BaseChannel):
name = "streamed"
display_name = "Streamed"
async def start(self) -> None:
pass
async def stop(self) -> None:
pass
async def send(self, msg: OutboundMessage) -> None:
nonlocal send_called
send_called = True
async def send_delta(self, chat_id: str, delta: str, metadata: dict | None = None) -> None:
nonlocal send_delta_called
send_delta_called = True
fake_config = SimpleNamespace(
channels=ChannelsConfig(send_max_retries=3),
providers=SimpleNamespace(groq=SimpleNamespace(api_key="")),
)
mgr = ChannelManager.__new__(ChannelManager)
mgr.config = fake_config
mgr.bus = MessageBus()
mgr.channels = {"streamed": _StreamedChannel(fake_config, mgr.bus)}
mgr._dispatch_task = None
# _streamed means message was already sent via send_delta, so skip send
msg = OutboundMessage(
channel="streamed", chat_id="123", content="test",
metadata={"_streamed": True}
)
await mgr._send_with_retry(mgr.channels["streamed"], msg)
assert send_called is False
assert send_delta_called is False
@pytest.mark.asyncio
async def test_send_with_retry_propagates_cancelled_error():
"""_send_with_retry should re-raise CancelledError for graceful shutdown."""
class _CancellingChannel(BaseChannel):
name = "cancelling"
display_name = "Cancelling"
async def start(self) -> None:
pass
async def stop(self) -> None:
pass
async def send(self, msg: OutboundMessage) -> None:
raise asyncio.CancelledError("simulated cancellation")
fake_config = SimpleNamespace(
channels=ChannelsConfig(send_max_retries=3),
providers=SimpleNamespace(groq=SimpleNamespace(api_key="")),
)
mgr = ChannelManager.__new__(ChannelManager)
mgr.config = fake_config
mgr.bus = MessageBus()
mgr.channels = {"cancelling": _CancellingChannel(fake_config, mgr.bus)}
mgr._dispatch_task = None
msg = OutboundMessage(channel="cancelling", chat_id="123", content="test")
with pytest.raises(asyncio.CancelledError):
await mgr._send_with_retry(mgr.channels["cancelling"], msg)
@pytest.mark.asyncio
async def test_send_with_retry_propagates_cancelled_error_during_sleep():
"""_send_with_retry should re-raise CancelledError during sleep."""
call_count = 0
class _FailingChannel(BaseChannel):
name = "failing"
display_name = "Failing"
async def start(self) -> None:
pass
async def stop(self) -> None:
pass
async def send(self, msg: OutboundMessage) -> None:
nonlocal call_count
call_count += 1
raise RuntimeError("simulated failure")
fake_config = SimpleNamespace(
channels=ChannelsConfig(send_max_retries=3),
providers=SimpleNamespace(groq=SimpleNamespace(api_key="")),
)
mgr = ChannelManager.__new__(ChannelManager)
mgr.config = fake_config
mgr.bus = MessageBus()
mgr.channels = {"failing": _FailingChannel(fake_config, mgr.bus)}
mgr._dispatch_task = None
msg = OutboundMessage(channel="failing", chat_id="123", content="test")
# Mock sleep to raise CancelledError
async def cancel_during_sleep(_):
raise asyncio.CancelledError("cancelled during sleep")
with patch("nanobot.channels.manager.asyncio.sleep", side_effect=cancel_during_sleep):
with pytest.raises(asyncio.CancelledError):
await mgr._send_with_retry(mgr.channels["failing"], msg)
# Should have attempted once before sleep was cancelled
assert call_count == 1
# ---------------------------------------------------------------------------
# ChannelManager - lifecycle and getters
# ---------------------------------------------------------------------------
class _ChannelWithAllowFrom(BaseChannel):
"""Channel with configurable allow_from."""
name = "withallow"
display_name = "With Allow"
def __init__(self, config, bus, allow_from):
super().__init__(config, bus)
self.config.allow_from = allow_from
async def start(self) -> None:
pass
async def stop(self) -> None:
pass
async def send(self, msg: OutboundMessage) -> None:
pass
class _StartableChannel(BaseChannel):
"""Channel that tracks start/stop calls."""
name = "startable"
display_name = "Startable"
def __init__(self, config, bus):
super().__init__(config, bus)
self.started = False
self.stopped = False
async def start(self) -> None:
self.started = True
async def stop(self) -> None:
self.stopped = True
async def send(self, msg: OutboundMessage) -> None:
pass
@pytest.mark.asyncio
async def test_validate_allow_from_raises_on_empty_list():
"""_validate_allow_from should raise SystemExit when allow_from is empty list."""
fake_config = SimpleNamespace(
channels=ChannelsConfig(),
providers=SimpleNamespace(groq=SimpleNamespace(api_key="")),
)
mgr = ChannelManager.__new__(ChannelManager)
mgr.config = fake_config
mgr.channels = {"test": _ChannelWithAllowFrom(fake_config, None, [])}
mgr._dispatch_task = None
with pytest.raises(SystemExit) as exc_info:
mgr._validate_allow_from()
assert "empty allowFrom" in str(exc_info.value)
@pytest.mark.asyncio
async def test_validate_allow_from_passes_with_asterisk():
"""_validate_allow_from should not raise when allow_from contains '*'."""
fake_config = SimpleNamespace(
channels=ChannelsConfig(),
providers=SimpleNamespace(groq=SimpleNamespace(api_key="")),
)
mgr = ChannelManager.__new__(ChannelManager)
mgr.config = fake_config
mgr.channels = {"test": _ChannelWithAllowFrom(fake_config, None, ["*"])}
mgr._dispatch_task = None
# Should not raise
mgr._validate_allow_from()
@pytest.mark.asyncio
async def test_get_channel_returns_channel_if_exists():
"""get_channel should return the channel if it exists."""
fake_config = SimpleNamespace(
channels=ChannelsConfig(),
providers=SimpleNamespace(groq=SimpleNamespace(api_key="")),
)
mgr = ChannelManager.__new__(ChannelManager)
mgr.config = fake_config
mgr.bus = MessageBus()
mgr.channels = {"telegram": _StartableChannel(fake_config, mgr.bus)}
mgr._dispatch_task = None
assert mgr.get_channel("telegram") is not None
assert mgr.get_channel("nonexistent") is None
@pytest.mark.asyncio
async def test_get_status_returns_running_state():
"""get_status should return enabled and running state for each channel."""
fake_config = SimpleNamespace(
channels=ChannelsConfig(),
providers=SimpleNamespace(groq=SimpleNamespace(api_key="")),
)
mgr = ChannelManager.__new__(ChannelManager)
mgr.config = fake_config
mgr.bus = MessageBus()
ch = _StartableChannel(fake_config, mgr.bus)
mgr.channels = {"startable": ch}
mgr._dispatch_task = None
status = mgr.get_status()
assert status["startable"]["enabled"] is True
assert status["startable"]["running"] is False # Not started yet
@pytest.mark.asyncio
async def test_enabled_channels_returns_channel_names():
"""enabled_channels should return list of enabled channel names."""
fake_config = SimpleNamespace(
channels=ChannelsConfig(),
providers=SimpleNamespace(groq=SimpleNamespace(api_key="")),
)
mgr = ChannelManager.__new__(ChannelManager)
mgr.config = fake_config
mgr.bus = MessageBus()
mgr.channels = {
"telegram": _StartableChannel(fake_config, mgr.bus),
"slack": _StartableChannel(fake_config, mgr.bus),
}
mgr._dispatch_task = None
enabled = mgr.enabled_channels
assert "telegram" in enabled
assert "slack" in enabled
assert len(enabled) == 2
@pytest.mark.asyncio
async def test_stop_all_cancels_dispatcher_and_stops_channels():
"""stop_all should cancel the dispatch task and stop all channels."""
fake_config = SimpleNamespace(
channels=ChannelsConfig(),
providers=SimpleNamespace(groq=SimpleNamespace(api_key="")),
)
mgr = ChannelManager.__new__(ChannelManager)
mgr.config = fake_config
mgr.bus = MessageBus()
ch = _StartableChannel(fake_config, mgr.bus)
mgr.channels = {"startable": ch}
# Create a real cancelled task
async def dummy_task():
while True:
await asyncio.sleep(1)
dispatch_task = asyncio.create_task(dummy_task())
mgr._dispatch_task = dispatch_task
await mgr.stop_all()
# Task should be cancelled
assert dispatch_task.cancelled()
# Channel should be stopped
assert ch.stopped is True
@pytest.mark.asyncio
async def test_start_channel_logs_error_on_failure():
"""_start_channel should log error when channel start fails."""
class _FailingChannel(BaseChannel):
name = "failing"
display_name = "Failing"
async def start(self) -> None:
raise RuntimeError("connection failed")
async def stop(self) -> None:
pass
async def send(self, msg: OutboundMessage) -> None:
pass
fake_config = SimpleNamespace(
channels=ChannelsConfig(),
providers=SimpleNamespace(groq=SimpleNamespace(api_key="")),
)
mgr = ChannelManager.__new__(ChannelManager)
mgr.config = fake_config
mgr.bus = MessageBus()
mgr.channels = {}
mgr._dispatch_task = None
ch = _FailingChannel(fake_config, mgr.bus)
# Should not raise, just log error
await mgr._start_channel("failing", ch)
@pytest.mark.asyncio
async def test_stop_all_handles_channel_exception():
"""stop_all should handle exceptions when stopping channels gracefully."""
class _StopFailingChannel(BaseChannel):
name = "stopfailing"
display_name = "Stop Failing"
async def start(self) -> None:
pass
async def stop(self) -> None:
raise RuntimeError("stop failed")
async def send(self, msg: OutboundMessage) -> None:
pass
fake_config = SimpleNamespace(
channels=ChannelsConfig(),
providers=SimpleNamespace(groq=SimpleNamespace(api_key="")),
)
mgr = ChannelManager.__new__(ChannelManager)
mgr.config = fake_config
mgr.bus = MessageBus()
mgr.channels = {"stopfailing": _StopFailingChannel(fake_config, mgr.bus)}
mgr._dispatch_task = None
# Should not raise even if channel.stop() raises
await mgr.stop_all()
@pytest.mark.asyncio
async def test_start_all_no_channels_logs_warning():
"""start_all should log warning when no channels are enabled."""
fake_config = SimpleNamespace(
channels=ChannelsConfig(),
providers=SimpleNamespace(groq=SimpleNamespace(api_key="")),
)
mgr = ChannelManager.__new__(ChannelManager)
mgr.config = fake_config
mgr.bus = MessageBus()
mgr.channels = {} # No channels
mgr._dispatch_task = None
# Should return early without creating dispatch task
await mgr.start_all()
assert mgr._dispatch_task is None
@pytest.mark.asyncio
async def test_start_all_creates_dispatch_task():
"""start_all should create the dispatch task when channels exist."""
fake_config = SimpleNamespace(
channels=ChannelsConfig(),
providers=SimpleNamespace(groq=SimpleNamespace(api_key="")),
)
mgr = ChannelManager.__new__(ChannelManager)
mgr.config = fake_config
mgr.bus = MessageBus()
ch = _StartableChannel(fake_config, mgr.bus)
mgr.channels = {"startable": ch}
mgr._dispatch_task = None
# Cancel immediately after start to avoid running forever
async def cancel_after_start():
await asyncio.sleep(0.01)
if mgr._dispatch_task:
mgr._dispatch_task.cancel()
cancel_task = asyncio.create_task(cancel_after_start())
try:
await mgr.start_all()
except asyncio.CancelledError:
pass
finally:
cancel_task.cancel()
try:
await cancel_task
except asyncio.CancelledError:
pass
# Dispatch task should have been created
assert mgr._dispatch_task is not None
@@ -0,0 +1,223 @@
import asyncio
from types import SimpleNamespace
import pytest
# Check optional dingtalk dependencies before running tests
try:
from nanobot.channels import dingtalk
DINGTALK_AVAILABLE = getattr(dingtalk, "DINGTALK_AVAILABLE", False)
except ImportError:
DINGTALK_AVAILABLE = False
if not DINGTALK_AVAILABLE:
pytest.skip("DingTalk dependencies not installed (dingtalk-stream)", allow_module_level=True)
from nanobot.bus.queue import MessageBus
import nanobot.channels.dingtalk as dingtalk_module
from nanobot.channels.dingtalk import DingTalkChannel, NanobotDingTalkHandler
from nanobot.channels.dingtalk import DingTalkConfig
class _FakeResponse:
def __init__(self, status_code: int = 200, json_body: dict | None = None) -> None:
self.status_code = status_code
self._json_body = json_body or {}
self.text = "{}"
self.content = b""
self.headers = {"content-type": "application/json"}
def json(self) -> dict:
return self._json_body
class _FakeHttp:
def __init__(self, responses: list[_FakeResponse] | None = None) -> None:
self.calls: list[dict] = []
self._responses = list(responses) if responses else []
def _next_response(self) -> _FakeResponse:
if self._responses:
return self._responses.pop(0)
return _FakeResponse()
async def post(self, url: str, json=None, headers=None, **kwargs):
self.calls.append({"method": "POST", "url": url, "json": json, "headers": headers})
return self._next_response()
async def get(self, url: str, **kwargs):
self.calls.append({"method": "GET", "url": url})
return self._next_response()
@pytest.mark.asyncio
async def test_group_message_keeps_sender_id_and_routes_chat_id() -> None:
config = DingTalkConfig(client_id="app", client_secret="secret", allow_from=["user1"])
bus = MessageBus()
channel = DingTalkChannel(config, bus)
await channel._on_message(
"hello",
sender_id="user1",
sender_name="Alice",
conversation_type="2",
conversation_id="conv123",
)
msg = await bus.consume_inbound()
assert msg.sender_id == "user1"
assert msg.chat_id == "group:conv123"
assert msg.metadata["conversation_type"] == "2"
@pytest.mark.asyncio
async def test_group_send_uses_group_messages_api() -> None:
config = DingTalkConfig(client_id="app", client_secret="secret", allow_from=["*"])
channel = DingTalkChannel(config, MessageBus())
channel._http = _FakeHttp()
ok = await channel._send_batch_message(
"token",
"group:conv123",
"sampleMarkdown",
{"text": "hello", "title": "Nanobot Reply"},
)
assert ok is True
call = channel._http.calls[0]
assert call["url"] == "https://api.dingtalk.com/v1.0/robot/groupMessages/send"
assert call["json"]["openConversationId"] == "conv123"
assert call["json"]["msgKey"] == "sampleMarkdown"
@pytest.mark.asyncio
async def test_handler_uses_voice_recognition_text_when_text_is_empty(monkeypatch) -> None:
bus = MessageBus()
channel = DingTalkChannel(
DingTalkConfig(client_id="app", client_secret="secret", allow_from=["user1"]),
bus,
)
handler = NanobotDingTalkHandler(channel)
class _FakeChatbotMessage:
text = None
extensions = {"content": {"recognition": "voice transcript"}}
sender_staff_id = "user1"
sender_id = "fallback-user"
sender_nick = "Alice"
message_type = "audio"
@staticmethod
def from_dict(_data):
return _FakeChatbotMessage()
monkeypatch.setattr(dingtalk_module, "ChatbotMessage", _FakeChatbotMessage)
monkeypatch.setattr(dingtalk_module, "AckMessage", SimpleNamespace(STATUS_OK="OK"))
status, body = await handler.process(
SimpleNamespace(
data={
"conversationType": "2",
"conversationId": "conv123",
"text": {"content": ""},
}
)
)
await asyncio.gather(*list(channel._background_tasks))
msg = await bus.consume_inbound()
assert (status, body) == ("OK", "OK")
assert msg.content == "voice transcript"
assert msg.sender_id == "user1"
assert msg.chat_id == "group:conv123"
@pytest.mark.asyncio
async def test_handler_processes_file_message(monkeypatch) -> None:
"""Test that file messages are handled and forwarded with downloaded path."""
bus = MessageBus()
channel = DingTalkChannel(
DingTalkConfig(client_id="app", client_secret="secret", allow_from=["user1"]),
bus,
)
handler = NanobotDingTalkHandler(channel)
class _FakeFileChatbotMessage:
text = None
extensions = {}
image_content = None
rich_text_content = None
sender_staff_id = "user1"
sender_id = "fallback-user"
sender_nick = "Alice"
message_type = "file"
@staticmethod
def from_dict(_data):
return _FakeFileChatbotMessage()
async def fake_download(download_code, filename, sender_id):
return f"/tmp/nanobot_dingtalk/{sender_id}/{filename}"
monkeypatch.setattr(dingtalk_module, "ChatbotMessage", _FakeFileChatbotMessage)
monkeypatch.setattr(dingtalk_module, "AckMessage", SimpleNamespace(STATUS_OK="OK"))
monkeypatch.setattr(channel, "_download_dingtalk_file", fake_download)
status, body = await handler.process(
SimpleNamespace(
data={
"conversationType": "1",
"content": {"downloadCode": "abc123", "fileName": "report.xlsx"},
"text": {"content": ""},
}
)
)
await asyncio.gather(*list(channel._background_tasks))
msg = await bus.consume_inbound()
assert (status, body) == ("OK", "OK")
assert "[File]" in msg.content
assert "/tmp/nanobot_dingtalk/user1/report.xlsx" in msg.content
@pytest.mark.asyncio
async def test_download_dingtalk_file(tmp_path, monkeypatch) -> None:
"""Test the two-step file download flow (get URL then download content)."""
channel = DingTalkChannel(
DingTalkConfig(client_id="app", client_secret="secret", allow_from=["*"]),
MessageBus(),
)
# Mock access token
async def fake_get_token():
return "test-token"
monkeypatch.setattr(channel, "_get_access_token", fake_get_token)
# Mock HTTP: first POST returns downloadUrl, then GET returns file bytes
file_content = b"fake file content"
channel._http = _FakeHttp(responses=[
_FakeResponse(200, {"downloadUrl": "https://example.com/tmpfile"}),
_FakeResponse(200),
])
channel._http._responses[1].content = file_content
# Redirect media dir to tmp_path
monkeypatch.setattr(
"nanobot.config.paths.get_media_dir",
lambda channel_name=None: tmp_path / channel_name if channel_name else tmp_path,
)
result = await channel._download_dingtalk_file("code123", "test.xlsx", "user1")
assert result is not None
assert result.endswith("test.xlsx")
assert (tmp_path / "dingtalk" / "user1" / "test.xlsx").read_bytes() == file_content
# Verify API calls
assert channel._http.calls[0]["method"] == "POST"
assert "messageFiles/download" in channel._http.calls[0]["url"]
assert channel._http.calls[0]["json"]["downloadCode"] == "code123"
assert channel._http.calls[1]["method"] == "GET"
@@ -0,0 +1,652 @@
from email.message import EmailMessage
from datetime import date
import imaplib
import pytest
from nanobot.bus.events import OutboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.channels.email import EmailChannel
from nanobot.channels.email import EmailConfig
def _make_config(**overrides) -> EmailConfig:
defaults = dict(
enabled=True,
consent_granted=True,
imap_host="imap.example.com",
imap_port=993,
imap_username="bot@example.com",
imap_password="secret",
smtp_host="smtp.example.com",
smtp_port=587,
smtp_username="bot@example.com",
smtp_password="secret",
mark_seen=True,
# Disable auth verification by default so existing tests are unaffected
verify_dkim=False,
verify_spf=False,
)
defaults.update(overrides)
return EmailConfig(**defaults)
def _make_raw_email(
from_addr: str = "alice@example.com",
subject: str = "Hello",
body: str = "This is the body.",
auth_results: str | None = None,
) -> bytes:
msg = EmailMessage()
msg["From"] = from_addr
msg["To"] = "bot@example.com"
msg["Subject"] = subject
msg["Message-ID"] = "<m1@example.com>"
if auth_results:
msg["Authentication-Results"] = auth_results
msg.set_content(body)
return msg.as_bytes()
def test_fetch_new_messages_parses_unseen_and_marks_seen(monkeypatch) -> None:
raw = _make_raw_email(subject="Invoice", body="Please pay")
class FakeIMAP:
def __init__(self) -> None:
self.store_calls: list[tuple[bytes, str, str]] = []
def login(self, _user: str, _pw: str):
return "OK", [b"logged in"]
def select(self, _mailbox: str):
return "OK", [b"1"]
def search(self, *_args):
return "OK", [b"1"]
def fetch(self, _imap_id: bytes, _parts: str):
return "OK", [(b"1 (UID 123 BODY[] {200})", raw), b")"]
def store(self, imap_id: bytes, op: str, flags: str):
self.store_calls.append((imap_id, op, flags))
return "OK", [b""]
def logout(self):
return "BYE", [b""]
fake = FakeIMAP()
monkeypatch.setattr("nanobot.channels.email.imaplib.IMAP4_SSL", lambda _h, _p: fake)
channel = EmailChannel(_make_config(), MessageBus())
items = channel._fetch_new_messages()
assert len(items) == 1
assert items[0]["sender"] == "alice@example.com"
assert items[0]["subject"] == "Invoice"
assert "Please pay" in items[0]["content"]
assert fake.store_calls == [(b"1", "+FLAGS", "\\Seen")]
# Same UID should be deduped in-process.
items_again = channel._fetch_new_messages()
assert items_again == []
def test_fetch_new_messages_retries_once_when_imap_connection_goes_stale(monkeypatch) -> None:
raw = _make_raw_email(subject="Invoice", body="Please pay")
fail_once = {"pending": True}
class FlakyIMAP:
def __init__(self) -> None:
self.store_calls: list[tuple[bytes, str, str]] = []
self.search_calls = 0
def login(self, _user: str, _pw: str):
return "OK", [b"logged in"]
def select(self, _mailbox: str):
return "OK", [b"1"]
def search(self, *_args):
self.search_calls += 1
if fail_once["pending"]:
fail_once["pending"] = False
raise imaplib.IMAP4.abort("socket error")
return "OK", [b"1"]
def fetch(self, _imap_id: bytes, _parts: str):
return "OK", [(b"1 (UID 123 BODY[] {200})", raw), b")"]
def store(self, imap_id: bytes, op: str, flags: str):
self.store_calls.append((imap_id, op, flags))
return "OK", [b""]
def logout(self):
return "BYE", [b""]
fake_instances: list[FlakyIMAP] = []
def _factory(_host: str, _port: int):
instance = FlakyIMAP()
fake_instances.append(instance)
return instance
monkeypatch.setattr("nanobot.channels.email.imaplib.IMAP4_SSL", _factory)
channel = EmailChannel(_make_config(), MessageBus())
items = channel._fetch_new_messages()
assert len(items) == 1
assert len(fake_instances) == 2
assert fake_instances[0].search_calls == 1
assert fake_instances[1].search_calls == 1
def test_fetch_new_messages_keeps_messages_collected_before_stale_retry(monkeypatch) -> None:
raw_first = _make_raw_email(subject="First", body="First body")
raw_second = _make_raw_email(subject="Second", body="Second body")
mailbox_state = {
b"1": {"uid": b"123", "raw": raw_first, "seen": False},
b"2": {"uid": b"124", "raw": raw_second, "seen": False},
}
fail_once = {"pending": True}
class FlakyIMAP:
def login(self, _user: str, _pw: str):
return "OK", [b"logged in"]
def select(self, _mailbox: str):
return "OK", [b"2"]
def search(self, *_args):
unseen_ids = [imap_id for imap_id, item in mailbox_state.items() if not item["seen"]]
return "OK", [b" ".join(unseen_ids)]
def fetch(self, imap_id: bytes, _parts: str):
if imap_id == b"2" and fail_once["pending"]:
fail_once["pending"] = False
raise imaplib.IMAP4.abort("socket error")
item = mailbox_state[imap_id]
header = b"%s (UID %s BODY[] {200})" % (imap_id, item["uid"])
return "OK", [(header, item["raw"]), b")"]
def store(self, imap_id: bytes, _op: str, _flags: str):
mailbox_state[imap_id]["seen"] = True
return "OK", [b""]
def logout(self):
return "BYE", [b""]
monkeypatch.setattr("nanobot.channels.email.imaplib.IMAP4_SSL", lambda _h, _p: FlakyIMAP())
channel = EmailChannel(_make_config(), MessageBus())
items = channel._fetch_new_messages()
assert [item["subject"] for item in items] == ["First", "Second"]
def test_fetch_new_messages_skips_missing_mailbox(monkeypatch) -> None:
class MissingMailboxIMAP:
def login(self, _user: str, _pw: str):
return "OK", [b"logged in"]
def select(self, _mailbox: str):
raise imaplib.IMAP4.error("Mailbox doesn't exist")
def logout(self):
return "BYE", [b""]
monkeypatch.setattr(
"nanobot.channels.email.imaplib.IMAP4_SSL",
lambda _h, _p: MissingMailboxIMAP(),
)
channel = EmailChannel(_make_config(), MessageBus())
assert channel._fetch_new_messages() == []
def test_extract_text_body_falls_back_to_html() -> None:
msg = EmailMessage()
msg["From"] = "alice@example.com"
msg["To"] = "bot@example.com"
msg["Subject"] = "HTML only"
msg.add_alternative("<p>Hello<br>world</p>", subtype="html")
text = EmailChannel._extract_text_body(msg)
assert "Hello" in text
assert "world" in text
@pytest.mark.asyncio
async def test_start_returns_immediately_without_consent(monkeypatch) -> None:
cfg = _make_config()
cfg.consent_granted = False
channel = EmailChannel(cfg, MessageBus())
called = {"fetch": False}
def _fake_fetch():
called["fetch"] = True
return []
monkeypatch.setattr(channel, "_fetch_new_messages", _fake_fetch)
await channel.start()
assert channel.is_running is False
assert called["fetch"] is False
@pytest.mark.asyncio
async def test_send_uses_smtp_and_reply_subject(monkeypatch) -> None:
class FakeSMTP:
def __init__(self, _host: str, _port: int, timeout: int = 30) -> None:
self.timeout = timeout
self.started_tls = False
self.logged_in = False
self.sent_messages: list[EmailMessage] = []
def __enter__(self):
return self
def __exit__(self, exc_type, exc, tb):
return False
def starttls(self, context=None):
self.started_tls = True
def login(self, _user: str, _pw: str):
self.logged_in = True
def send_message(self, msg: EmailMessage):
self.sent_messages.append(msg)
fake_instances: list[FakeSMTP] = []
def _smtp_factory(host: str, port: int, timeout: int = 30):
instance = FakeSMTP(host, port, timeout=timeout)
fake_instances.append(instance)
return instance
monkeypatch.setattr("nanobot.channels.email.smtplib.SMTP", _smtp_factory)
channel = EmailChannel(_make_config(), MessageBus())
channel._last_subject_by_chat["alice@example.com"] = "Invoice #42"
channel._last_message_id_by_chat["alice@example.com"] = "<m1@example.com>"
await channel.send(
OutboundMessage(
channel="email",
chat_id="alice@example.com",
content="Acknowledged.",
)
)
assert len(fake_instances) == 1
smtp = fake_instances[0]
assert smtp.started_tls is True
assert smtp.logged_in is True
assert len(smtp.sent_messages) == 1
sent = smtp.sent_messages[0]
assert sent["Subject"] == "Re: Invoice #42"
assert sent["To"] == "alice@example.com"
assert sent["In-Reply-To"] == "<m1@example.com>"
@pytest.mark.asyncio
async def test_send_skips_reply_when_auto_reply_disabled(monkeypatch) -> None:
"""When auto_reply_enabled=False, replies should be skipped but proactive sends allowed."""
class FakeSMTP:
def __init__(self, _host: str, _port: int, timeout: int = 30) -> None:
self.sent_messages: list[EmailMessage] = []
def __enter__(self):
return self
def __exit__(self, exc_type, exc, tb):
return False
def starttls(self, context=None):
return None
def login(self, _user: str, _pw: str):
return None
def send_message(self, msg: EmailMessage):
self.sent_messages.append(msg)
fake_instances: list[FakeSMTP] = []
def _smtp_factory(host: str, port: int, timeout: int = 30):
instance = FakeSMTP(host, port, timeout=timeout)
fake_instances.append(instance)
return instance
monkeypatch.setattr("nanobot.channels.email.smtplib.SMTP", _smtp_factory)
cfg = _make_config()
cfg.auto_reply_enabled = False
channel = EmailChannel(cfg, MessageBus())
# Mark alice as someone who sent us an email (making this a "reply")
channel._last_subject_by_chat["alice@example.com"] = "Previous email"
# Reply should be skipped (auto_reply_enabled=False)
await channel.send(
OutboundMessage(
channel="email",
chat_id="alice@example.com",
content="Should not send.",
)
)
assert fake_instances == []
# Reply with force_send=True should be sent
await channel.send(
OutboundMessage(
channel="email",
chat_id="alice@example.com",
content="Force send.",
metadata={"force_send": True},
)
)
assert len(fake_instances) == 1
assert len(fake_instances[0].sent_messages) == 1
@pytest.mark.asyncio
async def test_send_proactive_email_when_auto_reply_disabled(monkeypatch) -> None:
"""Proactive emails (not replies) should be sent even when auto_reply_enabled=False."""
class FakeSMTP:
def __init__(self, _host: str, _port: int, timeout: int = 30) -> None:
self.sent_messages: list[EmailMessage] = []
def __enter__(self):
return self
def __exit__(self, exc_type, exc, tb):
return False
def starttls(self, context=None):
return None
def login(self, _user: str, _pw: str):
return None
def send_message(self, msg: EmailMessage):
self.sent_messages.append(msg)
fake_instances: list[FakeSMTP] = []
def _smtp_factory(host: str, port: int, timeout: int = 30):
instance = FakeSMTP(host, port, timeout=timeout)
fake_instances.append(instance)
return instance
monkeypatch.setattr("nanobot.channels.email.smtplib.SMTP", _smtp_factory)
cfg = _make_config()
cfg.auto_reply_enabled = False
channel = EmailChannel(cfg, MessageBus())
# bob@example.com has never sent us an email (proactive send)
# This should be sent even with auto_reply_enabled=False
await channel.send(
OutboundMessage(
channel="email",
chat_id="bob@example.com",
content="Hello, this is a proactive email.",
)
)
assert len(fake_instances) == 1
assert len(fake_instances[0].sent_messages) == 1
sent = fake_instances[0].sent_messages[0]
assert sent["To"] == "bob@example.com"
@pytest.mark.asyncio
async def test_send_skips_when_consent_not_granted(monkeypatch) -> None:
class FakeSMTP:
def __init__(self, _host: str, _port: int, timeout: int = 30) -> None:
self.sent_messages: list[EmailMessage] = []
def __enter__(self):
return self
def __exit__(self, exc_type, exc, tb):
return False
def starttls(self, context=None):
return None
def login(self, _user: str, _pw: str):
return None
def send_message(self, msg: EmailMessage):
self.sent_messages.append(msg)
called = {"smtp": False}
def _smtp_factory(host: str, port: int, timeout: int = 30):
called["smtp"] = True
return FakeSMTP(host, port, timeout=timeout)
monkeypatch.setattr("nanobot.channels.email.smtplib.SMTP", _smtp_factory)
cfg = _make_config()
cfg.consent_granted = False
channel = EmailChannel(cfg, MessageBus())
await channel.send(
OutboundMessage(
channel="email",
chat_id="alice@example.com",
content="Should not send.",
metadata={"force_send": True},
)
)
assert called["smtp"] is False
def test_fetch_messages_between_dates_uses_imap_since_before_without_mark_seen(monkeypatch) -> None:
raw = _make_raw_email(subject="Status", body="Yesterday update")
class FakeIMAP:
def __init__(self) -> None:
self.search_args = None
self.store_calls: list[tuple[bytes, str, str]] = []
def login(self, _user: str, _pw: str):
return "OK", [b"logged in"]
def select(self, _mailbox: str):
return "OK", [b"1"]
def search(self, *_args):
self.search_args = _args
return "OK", [b"5"]
def fetch(self, _imap_id: bytes, _parts: str):
return "OK", [(b"5 (UID 999 BODY[] {200})", raw), b")"]
def store(self, imap_id: bytes, op: str, flags: str):
self.store_calls.append((imap_id, op, flags))
return "OK", [b""]
def logout(self):
return "BYE", [b""]
fake = FakeIMAP()
monkeypatch.setattr("nanobot.channels.email.imaplib.IMAP4_SSL", lambda _h, _p: fake)
channel = EmailChannel(_make_config(), MessageBus())
items = channel.fetch_messages_between_dates(
start_date=date(2026, 2, 6),
end_date=date(2026, 2, 7),
limit=10,
)
assert len(items) == 1
assert items[0]["subject"] == "Status"
# search(None, "SINCE", "06-Feb-2026", "BEFORE", "07-Feb-2026")
assert fake.search_args is not None
assert fake.search_args[1:] == ("SINCE", "06-Feb-2026", "BEFORE", "07-Feb-2026")
assert fake.store_calls == []
# ---------------------------------------------------------------------------
# Security: Anti-spoofing tests for Authentication-Results verification
# ---------------------------------------------------------------------------
def _make_fake_imap(raw: bytes):
"""Return a FakeIMAP class pre-loaded with the given raw email."""
class FakeIMAP:
def __init__(self) -> None:
self.store_calls: list[tuple[bytes, str, str]] = []
def login(self, _user: str, _pw: str):
return "OK", [b"logged in"]
def select(self, _mailbox: str):
return "OK", [b"1"]
def search(self, *_args):
return "OK", [b"1"]
def fetch(self, _imap_id: bytes, _parts: str):
return "OK", [(b"1 (UID 500 BODY[] {200})", raw), b")"]
def store(self, imap_id: bytes, op: str, flags: str):
self.store_calls.append((imap_id, op, flags))
return "OK", [b""]
def logout(self):
return "BYE", [b""]
return FakeIMAP()
def test_spoofed_email_rejected_when_verify_enabled(monkeypatch) -> None:
"""An email without Authentication-Results should be rejected when verify_dkim=True."""
raw = _make_raw_email(subject="Spoofed", body="Malicious payload")
fake = _make_fake_imap(raw)
monkeypatch.setattr("nanobot.channels.email.imaplib.IMAP4_SSL", lambda _h, _p: fake)
cfg = _make_config(verify_dkim=True, verify_spf=True)
channel = EmailChannel(cfg, MessageBus())
items = channel._fetch_new_messages()
assert len(items) == 0, "Spoofed email without auth headers should be rejected"
def test_email_with_valid_auth_results_accepted(monkeypatch) -> None:
"""An email with spf=pass and dkim=pass should be accepted."""
raw = _make_raw_email(
subject="Legit",
body="Hello from verified sender",
auth_results="mx.example.com; spf=pass smtp.mailfrom=alice@example.com; dkim=pass header.d=example.com",
)
fake = _make_fake_imap(raw)
monkeypatch.setattr("nanobot.channels.email.imaplib.IMAP4_SSL", lambda _h, _p: fake)
cfg = _make_config(verify_dkim=True, verify_spf=True)
channel = EmailChannel(cfg, MessageBus())
items = channel._fetch_new_messages()
assert len(items) == 1
assert items[0]["sender"] == "alice@example.com"
assert items[0]["subject"] == "Legit"
def test_email_with_partial_auth_rejected(monkeypatch) -> None:
"""An email with only spf=pass but no dkim=pass should be rejected when verify_dkim=True."""
raw = _make_raw_email(
subject="Partial",
body="Only SPF passes",
auth_results="mx.example.com; spf=pass smtp.mailfrom=alice@example.com; dkim=fail",
)
fake = _make_fake_imap(raw)
monkeypatch.setattr("nanobot.channels.email.imaplib.IMAP4_SSL", lambda _h, _p: fake)
cfg = _make_config(verify_dkim=True, verify_spf=True)
channel = EmailChannel(cfg, MessageBus())
items = channel._fetch_new_messages()
assert len(items) == 0, "Email with dkim=fail should be rejected"
def test_backward_compat_verify_disabled(monkeypatch) -> None:
"""When verify_dkim=False and verify_spf=False, emails without auth headers are accepted."""
raw = _make_raw_email(subject="NoAuth", body="No auth headers present")
fake = _make_fake_imap(raw)
monkeypatch.setattr("nanobot.channels.email.imaplib.IMAP4_SSL", lambda _h, _p: fake)
cfg = _make_config(verify_dkim=False, verify_spf=False)
channel = EmailChannel(cfg, MessageBus())
items = channel._fetch_new_messages()
assert len(items) == 1, "With verification disabled, emails should be accepted as before"
def test_email_content_tagged_with_email_context(monkeypatch) -> None:
"""Email content should be prefixed with [EMAIL-CONTEXT] for LLM isolation."""
raw = _make_raw_email(subject="Tagged", body="Check the tag")
fake = _make_fake_imap(raw)
monkeypatch.setattr("nanobot.channels.email.imaplib.IMAP4_SSL", lambda _h, _p: fake)
cfg = _make_config(verify_dkim=False, verify_spf=False)
channel = EmailChannel(cfg, MessageBus())
items = channel._fetch_new_messages()
assert len(items) == 1
assert items[0]["content"].startswith("[EMAIL-CONTEXT]"), (
"Email content must be tagged with [EMAIL-CONTEXT]"
)
def test_check_authentication_results_method() -> None:
"""Unit test for the _check_authentication_results static method."""
from email.parser import BytesParser
from email import policy
# No Authentication-Results header
msg_no_auth = EmailMessage()
msg_no_auth["From"] = "alice@example.com"
msg_no_auth.set_content("test")
parsed = BytesParser(policy=policy.default).parsebytes(msg_no_auth.as_bytes())
spf, dkim = EmailChannel._check_authentication_results(parsed)
assert spf is False
assert dkim is False
# Both pass
msg_both = EmailMessage()
msg_both["From"] = "alice@example.com"
msg_both["Authentication-Results"] = (
"mx.google.com; spf=pass smtp.mailfrom=example.com; dkim=pass header.d=example.com"
)
msg_both.set_content("test")
parsed = BytesParser(policy=policy.default).parsebytes(msg_both.as_bytes())
spf, dkim = EmailChannel._check_authentication_results(parsed)
assert spf is True
assert dkim is True
# SPF pass, DKIM fail
msg_spf_only = EmailMessage()
msg_spf_only["From"] = "alice@example.com"
msg_spf_only["Authentication-Results"] = (
"mx.google.com; spf=pass smtp.mailfrom=example.com; dkim=fail"
)
msg_spf_only.set_content("test")
parsed = BytesParser(policy=policy.default).parsebytes(msg_spf_only.as_bytes())
spf, dkim = EmailChannel._check_authentication_results(parsed)
assert spf is True
assert dkim is False
# DKIM pass, SPF fail
msg_dkim_only = EmailMessage()
msg_dkim_only["From"] = "alice@example.com"
msg_dkim_only["Authentication-Results"] = (
"mx.google.com; spf=fail smtp.mailfrom=example.com; dkim=pass header.d=example.com"
)
msg_dkim_only.set_content("test")
parsed = BytesParser(policy=policy.default).parsebytes(msg_dkim_only.as_bytes())
spf, dkim = EmailChannel._check_authentication_results(parsed)
assert spf is False
assert dkim is True
@@ -0,0 +1,68 @@
# Check optional Feishu dependencies before running tests
try:
from nanobot.channels import feishu
FEISHU_AVAILABLE = getattr(feishu, "FEISHU_AVAILABLE", False)
except ImportError:
FEISHU_AVAILABLE = False
if not FEISHU_AVAILABLE:
import pytest
pytest.skip("Feishu dependencies not installed (lark-oapi)", allow_module_level=True)
from nanobot.channels.feishu import FeishuChannel
def test_parse_md_table_strips_markdown_formatting_in_headers_and_cells() -> None:
table = FeishuChannel._parse_md_table(
"""
| **Name** | __Status__ | *Notes* | ~~State~~ |
| --- | --- | --- | --- |
| **Alice** | __Ready__ | *Fast* | ~~Old~~ |
"""
)
assert table is not None
assert [col["display_name"] for col in table["columns"]] == [
"Name",
"Status",
"Notes",
"State",
]
assert table["rows"] == [
{"c0": "Alice", "c1": "Ready", "c2": "Fast", "c3": "Old"}
]
def test_split_headings_strips_embedded_markdown_before_bolding() -> None:
channel = FeishuChannel.__new__(FeishuChannel)
elements = channel._split_headings("# **Important** *status* ~~update~~")
assert elements == [
{
"tag": "div",
"text": {
"tag": "lark_md",
"content": "**Important status update**",
},
}
]
def test_split_headings_keeps_markdown_body_and_code_blocks_intact() -> None:
channel = FeishuChannel.__new__(FeishuChannel)
elements = channel._split_headings(
"# **Heading**\n\nBody with **bold** text.\n\n```python\nprint('hi')\n```"
)
assert elements[0] == {
"tag": "div",
"text": {
"tag": "lark_md",
"content": "**Heading**",
},
}
assert elements[1]["tag"] == "markdown"
assert "Body with **bold** text." in elements[1]["content"]
assert "```python\nprint('hi')\n```" in elements[1]["content"]
@@ -0,0 +1,76 @@
# Check optional Feishu dependencies before running tests
try:
from nanobot.channels import feishu
FEISHU_AVAILABLE = getattr(feishu, "FEISHU_AVAILABLE", False)
except ImportError:
FEISHU_AVAILABLE = False
if not FEISHU_AVAILABLE:
import pytest
pytest.skip("Feishu dependencies not installed (lark-oapi)", allow_module_level=True)
from nanobot.channels.feishu import FeishuChannel, _extract_post_content
def test_extract_post_content_supports_post_wrapper_shape() -> None:
payload = {
"post": {
"zh_cn": {
"title": "日报",
"content": [
[
{"tag": "text", "text": "完成"},
{"tag": "img", "image_key": "img_1"},
]
],
}
}
}
text, image_keys = _extract_post_content(payload)
assert text == "日报 完成"
assert image_keys == ["img_1"]
def test_extract_post_content_keeps_direct_shape_behavior() -> None:
payload = {
"title": "Daily",
"content": [
[
{"tag": "text", "text": "report"},
{"tag": "img", "image_key": "img_a"},
{"tag": "img", "image_key": "img_b"},
]
],
}
text, image_keys = _extract_post_content(payload)
assert text == "Daily report"
assert image_keys == ["img_a", "img_b"]
def test_register_optional_event_keeps_builder_when_method_missing() -> None:
class Builder:
pass
builder = Builder()
same = FeishuChannel._register_optional_event(builder, "missing", object())
assert same is builder
def test_register_optional_event_calls_supported_method() -> None:
called = []
class Builder:
def register_event(self, handler):
called.append(handler)
return self
builder = Builder()
handler = object()
same = FeishuChannel._register_optional_event(builder, "register_event", handler)
assert same is builder
assert called == [handler]
@@ -0,0 +1,445 @@
"""Tests for Feishu message reply (quote) feature."""
import asyncio
import json
from pathlib import Path
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
# Check optional Feishu dependencies before running tests
try:
from nanobot.channels import feishu
FEISHU_AVAILABLE = getattr(feishu, "FEISHU_AVAILABLE", False)
except ImportError:
FEISHU_AVAILABLE = False
if not FEISHU_AVAILABLE:
pytest.skip("Feishu dependencies not installed (lark-oapi)", allow_module_level=True)
from nanobot.bus.events import OutboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.channels.feishu import FeishuChannel, FeishuConfig
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_feishu_channel(reply_to_message: bool = False) -> FeishuChannel:
config = FeishuConfig(
enabled=True,
app_id="cli_test",
app_secret="secret",
allow_from=["*"],
reply_to_message=reply_to_message,
)
channel = FeishuChannel(config, MessageBus())
channel._client = MagicMock()
# _loop is only used by the WebSocket thread bridge; not needed for unit tests
channel._loop = None
return channel
def _make_feishu_event(
*,
message_id: str = "om_001",
chat_id: str = "oc_abc",
chat_type: str = "p2p",
msg_type: str = "text",
content: str = '{"text": "hello"}',
sender_open_id: str = "ou_alice",
parent_id: str | None = None,
root_id: str | None = None,
):
message = SimpleNamespace(
message_id=message_id,
chat_id=chat_id,
chat_type=chat_type,
message_type=msg_type,
content=content,
parent_id=parent_id,
root_id=root_id,
mentions=[],
)
sender = SimpleNamespace(
sender_type="user",
sender_id=SimpleNamespace(open_id=sender_open_id),
)
return SimpleNamespace(event=SimpleNamespace(message=message, sender=sender))
def _make_get_message_response(text: str, msg_type: str = "text", success: bool = True):
"""Build a fake im.v1.message.get response object."""
body = SimpleNamespace(content=json.dumps({"text": text}))
item = SimpleNamespace(msg_type=msg_type, body=body)
data = SimpleNamespace(items=[item])
resp = MagicMock()
resp.success.return_value = success
resp.data = data
resp.code = 0
resp.msg = "ok"
return resp
# ---------------------------------------------------------------------------
# Config tests
# ---------------------------------------------------------------------------
def test_feishu_config_reply_to_message_defaults_false() -> None:
assert FeishuConfig().reply_to_message is False
def test_feishu_config_reply_to_message_can_be_enabled() -> None:
config = FeishuConfig(reply_to_message=True)
assert config.reply_to_message is True
# ---------------------------------------------------------------------------
# _get_message_content_sync tests
# ---------------------------------------------------------------------------
def test_get_message_content_sync_returns_reply_prefix() -> None:
channel = _make_feishu_channel()
channel._client.im.v1.message.get.return_value = _make_get_message_response("what time is it?")
result = channel._get_message_content_sync("om_parent")
assert result == "[Reply to: what time is it?]"
def test_get_message_content_sync_truncates_long_text() -> None:
channel = _make_feishu_channel()
long_text = "x" * (FeishuChannel._REPLY_CONTEXT_MAX_LEN + 50)
channel._client.im.v1.message.get.return_value = _make_get_message_response(long_text)
result = channel._get_message_content_sync("om_parent")
assert result is not None
assert result.endswith("...]")
inner = result[len("[Reply to: ") : -1]
assert len(inner) == FeishuChannel._REPLY_CONTEXT_MAX_LEN + len("...")
def test_get_message_content_sync_returns_none_on_api_failure() -> None:
channel = _make_feishu_channel()
resp = MagicMock()
resp.success.return_value = False
resp.code = 230002
resp.msg = "bot not in group"
channel._client.im.v1.message.get.return_value = resp
result = channel._get_message_content_sync("om_parent")
assert result is None
def test_get_message_content_sync_returns_none_for_non_text_type() -> None:
channel = _make_feishu_channel()
body = SimpleNamespace(content=json.dumps({"image_key": "img_1"}))
item = SimpleNamespace(msg_type="image", body=body)
data = SimpleNamespace(items=[item])
resp = MagicMock()
resp.success.return_value = True
resp.data = data
channel._client.im.v1.message.get.return_value = resp
result = channel._get_message_content_sync("om_parent")
assert result is None
def test_get_message_content_sync_returns_none_when_empty_text() -> None:
channel = _make_feishu_channel()
channel._client.im.v1.message.get.return_value = _make_get_message_response(" ")
result = channel._get_message_content_sync("om_parent")
assert result is None
# ---------------------------------------------------------------------------
# _reply_message_sync tests
# ---------------------------------------------------------------------------
def test_reply_message_sync_returns_true_on_success() -> None:
channel = _make_feishu_channel()
resp = MagicMock()
resp.success.return_value = True
channel._client.im.v1.message.reply.return_value = resp
ok = channel._reply_message_sync("om_parent", "text", '{"text":"hi"}')
assert ok is True
channel._client.im.v1.message.reply.assert_called_once()
def test_reply_message_sync_returns_false_on_api_error() -> None:
channel = _make_feishu_channel()
resp = MagicMock()
resp.success.return_value = False
resp.code = 400
resp.msg = "bad request"
resp.get_log_id.return_value = "log_x"
channel._client.im.v1.message.reply.return_value = resp
ok = channel._reply_message_sync("om_parent", "text", '{"text":"hi"}')
assert ok is False
def test_reply_message_sync_returns_false_on_exception() -> None:
channel = _make_feishu_channel()
channel._client.im.v1.message.reply.side_effect = RuntimeError("network error")
ok = channel._reply_message_sync("om_parent", "text", '{"text":"hi"}')
assert ok is False
@pytest.mark.asyncio
@pytest.mark.parametrize(
("filename", "expected_msg_type"),
[
("voice.opus", "audio"),
("clip.mp4", "video"),
("report.pdf", "file"),
],
)
async def test_send_uses_expected_feishu_msg_type_for_uploaded_files(
tmp_path: Path, filename: str, expected_msg_type: str
) -> None:
channel = _make_feishu_channel()
file_path = tmp_path / filename
file_path.write_bytes(b"demo")
send_calls: list[tuple[str, str, str, str]] = []
def _record_send(receive_id_type: str, receive_id: str, msg_type: str, content: str) -> None:
send_calls.append((receive_id_type, receive_id, msg_type, content))
with patch.object(channel, "_upload_file_sync", return_value="file-key"), patch.object(
channel, "_send_message_sync", side_effect=_record_send
):
await channel.send(
OutboundMessage(
channel="feishu",
chat_id="oc_test",
content="",
media=[str(file_path)],
metadata={},
)
)
assert len(send_calls) == 1
receive_id_type, receive_id, msg_type, content = send_calls[0]
assert receive_id_type == "chat_id"
assert receive_id == "oc_test"
assert msg_type == expected_msg_type
assert json.loads(content) == {"file_key": "file-key"}
# ---------------------------------------------------------------------------
# send() — reply routing tests
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_send_uses_reply_api_when_configured() -> None:
channel = _make_feishu_channel(reply_to_message=True)
reply_resp = MagicMock()
reply_resp.success.return_value = True
channel._client.im.v1.message.reply.return_value = reply_resp
await channel.send(OutboundMessage(
channel="feishu",
chat_id="oc_abc",
content="hello",
metadata={"message_id": "om_001"},
))
channel._client.im.v1.message.reply.assert_called_once()
channel._client.im.v1.message.create.assert_not_called()
@pytest.mark.asyncio
async def test_send_uses_create_api_when_reply_disabled() -> None:
channel = _make_feishu_channel(reply_to_message=False)
create_resp = MagicMock()
create_resp.success.return_value = True
channel._client.im.v1.message.create.return_value = create_resp
await channel.send(OutboundMessage(
channel="feishu",
chat_id="oc_abc",
content="hello",
metadata={"message_id": "om_001"},
))
channel._client.im.v1.message.create.assert_called_once()
channel._client.im.v1.message.reply.assert_not_called()
@pytest.mark.asyncio
async def test_send_uses_create_api_when_no_message_id() -> None:
channel = _make_feishu_channel(reply_to_message=True)
create_resp = MagicMock()
create_resp.success.return_value = True
channel._client.im.v1.message.create.return_value = create_resp
await channel.send(OutboundMessage(
channel="feishu",
chat_id="oc_abc",
content="hello",
metadata={},
))
channel._client.im.v1.message.create.assert_called_once()
channel._client.im.v1.message.reply.assert_not_called()
@pytest.mark.asyncio
async def test_send_skips_reply_for_progress_messages() -> None:
channel = _make_feishu_channel(reply_to_message=True)
create_resp = MagicMock()
create_resp.success.return_value = True
channel._client.im.v1.message.create.return_value = create_resp
await channel.send(OutboundMessage(
channel="feishu",
chat_id="oc_abc",
content="thinking...",
metadata={"message_id": "om_001", "_progress": True},
))
channel._client.im.v1.message.create.assert_called_once()
channel._client.im.v1.message.reply.assert_not_called()
@pytest.mark.asyncio
async def test_send_fallback_to_create_when_reply_fails() -> None:
channel = _make_feishu_channel(reply_to_message=True)
reply_resp = MagicMock()
reply_resp.success.return_value = False
reply_resp.code = 400
reply_resp.msg = "error"
reply_resp.get_log_id.return_value = "log_x"
channel._client.im.v1.message.reply.return_value = reply_resp
create_resp = MagicMock()
create_resp.success.return_value = True
channel._client.im.v1.message.create.return_value = create_resp
await channel.send(OutboundMessage(
channel="feishu",
chat_id="oc_abc",
content="hello",
metadata={"message_id": "om_001"},
))
# reply attempted first, then falls back to create
channel._client.im.v1.message.reply.assert_called_once()
channel._client.im.v1.message.create.assert_called_once()
# ---------------------------------------------------------------------------
# _on_message — parent_id / root_id metadata tests
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_on_message_captures_parent_and_root_id_in_metadata() -> None:
channel = _make_feishu_channel()
channel._processed_message_ids.clear()
channel._client.im.v1.message.react.return_value = MagicMock(success=lambda: True)
captured = []
async def _capture(**kwargs):
captured.append(kwargs)
channel._handle_message = _capture
with patch.object(channel, "_add_reaction", return_value=None):
await channel._on_message(
_make_feishu_event(
parent_id="om_parent",
root_id="om_root",
)
)
assert len(captured) == 1
meta = captured[0]["metadata"]
assert meta["parent_id"] == "om_parent"
assert meta["root_id"] == "om_root"
assert meta["message_id"] == "om_001"
@pytest.mark.asyncio
async def test_on_message_parent_and_root_id_none_when_absent() -> None:
channel = _make_feishu_channel()
channel._processed_message_ids.clear()
captured = []
async def _capture(**kwargs):
captured.append(kwargs)
channel._handle_message = _capture
with patch.object(channel, "_add_reaction", return_value=None):
await channel._on_message(_make_feishu_event())
assert len(captured) == 1
meta = captured[0]["metadata"]
assert meta["parent_id"] is None
assert meta["root_id"] is None
@pytest.mark.asyncio
async def test_on_message_prepends_reply_context_when_parent_id_present() -> None:
channel = _make_feishu_channel()
channel._processed_message_ids.clear()
channel._client.im.v1.message.get.return_value = _make_get_message_response("original question")
captured = []
async def _capture(**kwargs):
captured.append(kwargs)
channel._handle_message = _capture
with patch.object(channel, "_add_reaction", return_value=None):
await channel._on_message(
_make_feishu_event(
content='{"text": "my answer"}',
parent_id="om_parent",
)
)
assert len(captured) == 1
content = captured[0]["content"]
assert content.startswith("[Reply to: original question]")
assert "my answer" in content
@pytest.mark.asyncio
async def test_on_message_no_extra_api_call_when_no_parent_id() -> None:
channel = _make_feishu_channel()
channel._processed_message_ids.clear()
captured = []
async def _capture(**kwargs):
captured.append(kwargs)
channel._handle_message = _capture
with patch.object(channel, "_add_reaction", return_value=None):
await channel._on_message(_make_feishu_event())
channel._client.im.v1.message.get.assert_not_called()
assert len(captured) == 1
@@ -0,0 +1,258 @@
"""Tests for Feishu streaming (send_delta) via CardKit streaming API."""
import time
from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
from nanobot.bus.queue import MessageBus
from nanobot.channels.feishu import FeishuChannel, FeishuConfig, _FeishuStreamBuf
def _make_channel(streaming: bool = True) -> FeishuChannel:
config = FeishuConfig(
enabled=True,
app_id="cli_test",
app_secret="secret",
allow_from=["*"],
streaming=streaming,
)
ch = FeishuChannel(config, MessageBus())
ch._client = MagicMock()
ch._loop = None
return ch
def _mock_create_card_response(card_id: str = "card_stream_001"):
resp = MagicMock()
resp.success.return_value = True
resp.data = SimpleNamespace(card_id=card_id)
return resp
def _mock_send_response(message_id: str = "om_stream_001"):
resp = MagicMock()
resp.success.return_value = True
resp.data = SimpleNamespace(message_id=message_id)
return resp
def _mock_content_response(success: bool = True):
resp = MagicMock()
resp.success.return_value = success
resp.code = 0 if success else 99999
resp.msg = "ok" if success else "error"
return resp
class TestFeishuStreamingConfig:
def test_streaming_default_true(self):
assert FeishuConfig().streaming is True
def test_supports_streaming_when_enabled(self):
ch = _make_channel(streaming=True)
assert ch.supports_streaming is True
def test_supports_streaming_disabled(self):
ch = _make_channel(streaming=False)
assert ch.supports_streaming is False
class TestCreateStreamingCard:
def test_returns_card_id_on_success(self):
ch = _make_channel()
ch._client.cardkit.v1.card.create.return_value = _mock_create_card_response("card_123")
ch._client.im.v1.message.create.return_value = _mock_send_response()
result = ch._create_streaming_card_sync("chat_id", "oc_chat1")
assert result == "card_123"
ch._client.cardkit.v1.card.create.assert_called_once()
ch._client.im.v1.message.create.assert_called_once()
def test_returns_none_on_failure(self):
ch = _make_channel()
resp = MagicMock()
resp.success.return_value = False
resp.code = 99999
resp.msg = "error"
ch._client.cardkit.v1.card.create.return_value = resp
assert ch._create_streaming_card_sync("chat_id", "oc_chat1") is None
def test_returns_none_on_exception(self):
ch = _make_channel()
ch._client.cardkit.v1.card.create.side_effect = RuntimeError("network")
assert ch._create_streaming_card_sync("chat_id", "oc_chat1") is None
def test_returns_none_when_card_send_fails(self):
ch = _make_channel()
ch._client.cardkit.v1.card.create.return_value = _mock_create_card_response("card_123")
resp = MagicMock()
resp.success.return_value = False
resp.code = 99999
resp.msg = "error"
resp.get_log_id.return_value = "log1"
ch._client.im.v1.message.create.return_value = resp
assert ch._create_streaming_card_sync("chat_id", "oc_chat1") is None
class TestCloseStreamingMode:
def test_returns_true_on_success(self):
ch = _make_channel()
ch._client.cardkit.v1.card.settings.return_value = _mock_content_response(True)
assert ch._close_streaming_mode_sync("card_1", 10) is True
def test_returns_false_on_failure(self):
ch = _make_channel()
ch._client.cardkit.v1.card.settings.return_value = _mock_content_response(False)
assert ch._close_streaming_mode_sync("card_1", 10) is False
def test_returns_false_on_exception(self):
ch = _make_channel()
ch._client.cardkit.v1.card.settings.side_effect = RuntimeError("err")
assert ch._close_streaming_mode_sync("card_1", 10) is False
class TestStreamUpdateText:
def test_returns_true_on_success(self):
ch = _make_channel()
ch._client.cardkit.v1.card_element.content.return_value = _mock_content_response(True)
assert ch._stream_update_text_sync("card_1", "hello", 1) is True
def test_returns_false_on_failure(self):
ch = _make_channel()
ch._client.cardkit.v1.card_element.content.return_value = _mock_content_response(False)
assert ch._stream_update_text_sync("card_1", "hello", 1) is False
def test_returns_false_on_exception(self):
ch = _make_channel()
ch._client.cardkit.v1.card_element.content.side_effect = RuntimeError("err")
assert ch._stream_update_text_sync("card_1", "hello", 1) is False
class TestSendDelta:
@pytest.mark.asyncio
async def test_first_delta_creates_card_and_sends(self):
ch = _make_channel()
ch._client.cardkit.v1.card.create.return_value = _mock_create_card_response("card_new")
ch._client.im.v1.message.create.return_value = _mock_send_response("om_new")
ch._client.cardkit.v1.card_element.content.return_value = _mock_content_response()
await ch.send_delta("oc_chat1", "Hello ")
assert "oc_chat1" in ch._stream_bufs
buf = ch._stream_bufs["oc_chat1"]
assert buf.text == "Hello "
assert buf.card_id == "card_new"
assert buf.sequence == 1
ch._client.cardkit.v1.card.create.assert_called_once()
ch._client.im.v1.message.create.assert_called_once()
ch._client.cardkit.v1.card_element.content.assert_called_once()
@pytest.mark.asyncio
async def test_second_delta_within_interval_skips_update(self):
ch = _make_channel()
buf = _FeishuStreamBuf(text="Hello ", card_id="card_1", sequence=1, last_edit=time.monotonic())
ch._stream_bufs["oc_chat1"] = buf
await ch.send_delta("oc_chat1", "world")
assert buf.text == "Hello world"
ch._client.cardkit.v1.card_element.content.assert_not_called()
@pytest.mark.asyncio
async def test_delta_after_interval_updates_text(self):
ch = _make_channel()
buf = _FeishuStreamBuf(text="Hello ", card_id="card_1", sequence=1, last_edit=time.monotonic() - 1.0)
ch._stream_bufs["oc_chat1"] = buf
ch._client.cardkit.v1.card_element.content.return_value = _mock_content_response()
await ch.send_delta("oc_chat1", "world")
assert buf.text == "Hello world"
assert buf.sequence == 2
ch._client.cardkit.v1.card_element.content.assert_called_once()
@pytest.mark.asyncio
async def test_stream_end_sends_final_update(self):
ch = _make_channel()
ch._stream_bufs["oc_chat1"] = _FeishuStreamBuf(
text="Final content", card_id="card_1", sequence=3, last_edit=0.0,
)
ch._client.cardkit.v1.card_element.content.return_value = _mock_content_response()
ch._client.cardkit.v1.card.settings.return_value = _mock_content_response()
await ch.send_delta("oc_chat1", "", metadata={"_stream_end": True})
assert "oc_chat1" not in ch._stream_bufs
ch._client.cardkit.v1.card_element.content.assert_called_once()
ch._client.cardkit.v1.card.settings.assert_called_once()
settings_call = ch._client.cardkit.v1.card.settings.call_args[0][0]
assert settings_call.body.sequence == 5 # after final content seq 4
@pytest.mark.asyncio
async def test_stream_end_fallback_when_no_card_id(self):
"""If card creation failed, stream_end falls back to a plain card message."""
ch = _make_channel()
ch._stream_bufs["oc_chat1"] = _FeishuStreamBuf(
text="Fallback content", card_id=None, sequence=0, last_edit=0.0,
)
ch._client.im.v1.message.create.return_value = _mock_send_response("om_fb")
await ch.send_delta("oc_chat1", "", metadata={"_stream_end": True})
assert "oc_chat1" not in ch._stream_bufs
ch._client.cardkit.v1.card_element.content.assert_not_called()
ch._client.im.v1.message.create.assert_called_once()
@pytest.mark.asyncio
async def test_stream_end_without_buf_is_noop(self):
ch = _make_channel()
await ch.send_delta("oc_chat1", "", metadata={"_stream_end": True})
ch._client.cardkit.v1.card_element.content.assert_not_called()
@pytest.mark.asyncio
async def test_empty_delta_skips_send(self):
ch = _make_channel()
await ch.send_delta("oc_chat1", " ")
assert "oc_chat1" in ch._stream_bufs
ch._client.cardkit.v1.card.create.assert_not_called()
@pytest.mark.asyncio
async def test_no_client_returns_early(self):
ch = _make_channel()
ch._client = None
await ch.send_delta("oc_chat1", "text")
assert "oc_chat1" not in ch._stream_bufs
@pytest.mark.asyncio
async def test_sequence_increments_correctly(self):
ch = _make_channel()
buf = _FeishuStreamBuf(text="a", card_id="card_1", sequence=5, last_edit=0.0)
ch._stream_bufs["oc_chat1"] = buf
ch._client.cardkit.v1.card_element.content.return_value = _mock_content_response()
await ch.send_delta("oc_chat1", "b")
assert buf.sequence == 6
buf.last_edit = 0.0 # reset to bypass throttle
await ch.send_delta("oc_chat1", "c")
assert buf.sequence == 7
class TestSendMessageReturnsId:
def test_returns_message_id_on_success(self):
ch = _make_channel()
ch._client.im.v1.message.create.return_value = _mock_send_response("om_abc")
result = ch._send_message_sync("chat_id", "oc_chat1", "text", '{"text":"hi"}')
assert result == "om_abc"
def test_returns_none_on_failure(self):
ch = _make_channel()
resp = MagicMock()
resp.success.return_value = False
resp.code = 99999
resp.msg = "error"
resp.get_log_id.return_value = "log1"
ch._client.im.v1.message.create.return_value = resp
result = ch._send_message_sync("chat_id", "oc_chat1", "text", '{"text":"hi"}')
assert result is None
@@ -0,0 +1,115 @@
"""Tests for FeishuChannel._split_elements_by_table_limit.
Feishu cards reject messages that contain more than one table element
(API error 11310: card table number over limit). The helper splits a flat
list of card elements into groups so that each group contains at most one
table, allowing nanobot to send multiple cards instead of failing.
"""
# Check optional Feishu dependencies before running tests
try:
from nanobot.channels import feishu
FEISHU_AVAILABLE = getattr(feishu, "FEISHU_AVAILABLE", False)
except ImportError:
FEISHU_AVAILABLE = False
if not FEISHU_AVAILABLE:
import pytest
pytest.skip("Feishu dependencies not installed (lark-oapi)", allow_module_level=True)
from nanobot.channels.feishu import FeishuChannel
def _md(text: str) -> dict:
return {"tag": "markdown", "content": text}
def _table() -> dict:
return {
"tag": "table",
"columns": [{"tag": "column", "name": "c0", "display_name": "A", "width": "auto"}],
"rows": [{"c0": "v"}],
"page_size": 2,
}
split = FeishuChannel._split_elements_by_table_limit
def test_empty_list_returns_single_empty_group() -> None:
assert split([]) == [[]]
def test_no_tables_returns_single_group() -> None:
els = [_md("hello"), _md("world")]
result = split(els)
assert result == [els]
def test_single_table_stays_in_one_group() -> None:
els = [_md("intro"), _table(), _md("outro")]
result = split(els)
assert len(result) == 1
assert result[0] == els
def test_two_tables_split_into_two_groups() -> None:
# Use different row values so the two tables are not equal
t1 = {
"tag": "table",
"columns": [{"tag": "column", "name": "c0", "display_name": "A", "width": "auto"}],
"rows": [{"c0": "table-one"}],
"page_size": 2,
}
t2 = {
"tag": "table",
"columns": [{"tag": "column", "name": "c0", "display_name": "B", "width": "auto"}],
"rows": [{"c0": "table-two"}],
"page_size": 2,
}
els = [_md("before"), t1, _md("between"), t2, _md("after")]
result = split(els)
assert len(result) == 2
# First group: text before table-1 + table-1
assert t1 in result[0]
assert t2 not in result[0]
# Second group: text between tables + table-2 + text after
assert t2 in result[1]
assert t1 not in result[1]
def test_three_tables_split_into_three_groups() -> None:
tables = [
{"tag": "table", "columns": [], "rows": [{"c0": f"t{i}"}], "page_size": 1}
for i in range(3)
]
els = tables[:]
result = split(els)
assert len(result) == 3
for i, group in enumerate(result):
assert tables[i] in group
def test_leading_markdown_stays_with_first_table() -> None:
intro = _md("intro")
t = _table()
result = split([intro, t])
assert len(result) == 1
assert result[0] == [intro, t]
def test_trailing_markdown_after_second_table() -> None:
t1, t2 = _table(), _table()
tail = _md("end")
result = split([t1, t2, tail])
assert len(result) == 2
assert result[1] == [t2, tail]
def test_non_table_elements_before_first_table_kept_in_first_group() -> None:
head = _md("head")
t1, t2 = _table(), _table()
result = split([head, t1, t2])
# head + t1 in group 0; t2 in group 1
assert result[0] == [head, t1]
assert result[1] == [t2]
@@ -0,0 +1,148 @@
"""Tests for FeishuChannel tool hint code block formatting."""
import json
from unittest.mock import MagicMock, patch
import pytest
from pytest import mark
# Check optional Feishu dependencies before running tests
try:
from nanobot.channels import feishu
FEISHU_AVAILABLE = getattr(feishu, "FEISHU_AVAILABLE", False)
except ImportError:
FEISHU_AVAILABLE = False
if not FEISHU_AVAILABLE:
pytest.skip("Feishu dependencies not installed (lark-oapi)", allow_module_level=True)
from nanobot.bus.events import OutboundMessage
from nanobot.channels.feishu import FeishuChannel
@pytest.fixture
def mock_feishu_channel():
"""Create a FeishuChannel with mocked client."""
config = MagicMock()
config.app_id = "test_app_id"
config.app_secret = "test_app_secret"
config.encrypt_key = None
config.verification_token = None
bus = MagicMock()
channel = FeishuChannel(config, bus)
channel._client = MagicMock() # Simulate initialized client
return channel
@mark.asyncio
async def test_tool_hint_sends_code_message(mock_feishu_channel):
"""Tool hint messages should be sent as interactive cards with code blocks."""
msg = OutboundMessage(
channel="feishu",
chat_id="oc_123456",
content='web_search("test query")',
metadata={"_tool_hint": True}
)
with patch.object(mock_feishu_channel, '_send_message_sync') as mock_send:
await mock_feishu_channel.send(msg)
# Verify interactive message with card was sent
assert mock_send.call_count == 1
call_args = mock_send.call_args[0]
receive_id_type, receive_id, msg_type, content = call_args
assert receive_id_type == "chat_id"
assert receive_id == "oc_123456"
assert msg_type == "interactive"
# Parse content to verify card structure
card = json.loads(content)
assert card["config"]["wide_screen_mode"] is True
assert len(card["elements"]) == 1
assert card["elements"][0]["tag"] == "markdown"
# Check that code block is properly formatted with language hint
expected_md = "**Tool Calls**\n\n```text\nweb_search(\"test query\")\n```"
assert card["elements"][0]["content"] == expected_md
@mark.asyncio
async def test_tool_hint_empty_content_does_not_send(mock_feishu_channel):
"""Empty tool hint messages should not be sent."""
msg = OutboundMessage(
channel="feishu",
chat_id="oc_123456",
content=" ", # whitespace only
metadata={"_tool_hint": True}
)
with patch.object(mock_feishu_channel, '_send_message_sync') as mock_send:
await mock_feishu_channel.send(msg)
# Should not send any message
mock_send.assert_not_called()
@mark.asyncio
async def test_tool_hint_without_metadata_sends_as_normal(mock_feishu_channel):
"""Regular messages without _tool_hint should use normal formatting."""
msg = OutboundMessage(
channel="feishu",
chat_id="oc_123456",
content="Hello, world!",
metadata={}
)
with patch.object(mock_feishu_channel, '_send_message_sync') as mock_send:
await mock_feishu_channel.send(msg)
# Should send as text message (detected format)
assert mock_send.call_count == 1
call_args = mock_send.call_args[0]
_, _, msg_type, content = call_args
assert msg_type == "text"
assert json.loads(content) == {"text": "Hello, world!"}
@mark.asyncio
async def test_tool_hint_multiple_tools_in_one_message(mock_feishu_channel):
"""Multiple tool calls should be displayed each on its own line in a code block."""
msg = OutboundMessage(
channel="feishu",
chat_id="oc_123456",
content='web_search("query"), read_file("/path/to/file")',
metadata={"_tool_hint": True}
)
with patch.object(mock_feishu_channel, '_send_message_sync') as mock_send:
await mock_feishu_channel.send(msg)
call_args = mock_send.call_args[0]
msg_type = call_args[2]
content = json.loads(call_args[3])
assert msg_type == "interactive"
# Each tool call should be on its own line
expected_md = "**Tool Calls**\n\n```text\nweb_search(\"query\"),\nread_file(\"/path/to/file\")\n```"
assert content["elements"][0]["content"] == expected_md
@mark.asyncio
async def test_tool_hint_keeps_commas_inside_arguments(mock_feishu_channel):
"""Commas inside a single tool argument must not be split onto a new line."""
msg = OutboundMessage(
channel="feishu",
chat_id="oc_123456",
content='web_search("foo, bar"), read_file("/path/to/file")',
metadata={"_tool_hint": True}
)
with patch.object(mock_feishu_channel, '_send_message_sync') as mock_send:
await mock_feishu_channel.send(msg)
content = json.loads(mock_send.call_args[0][3])
expected_md = (
"**Tool Calls**\n\n```text\n"
"web_search(\"foo, bar\"),\n"
"read_file(\"/path/to/file\")\n```"
)
assert content["elements"][0]["content"] == expected_md
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,172 @@
import tempfile
from pathlib import Path
from types import SimpleNamespace
import pytest
# Check optional QQ dependencies before running tests
try:
from nanobot.channels import qq
QQ_AVAILABLE = getattr(qq, "QQ_AVAILABLE", False)
except ImportError:
QQ_AVAILABLE = False
if not QQ_AVAILABLE:
pytest.skip("QQ dependencies not installed (qq-botpy)", allow_module_level=True)
from nanobot.bus.events import OutboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.channels.qq import QQChannel, QQConfig
class _FakeApi:
def __init__(self) -> None:
self.c2c_calls: list[dict] = []
self.group_calls: list[dict] = []
async def post_c2c_message(self, **kwargs) -> None:
self.c2c_calls.append(kwargs)
async def post_group_message(self, **kwargs) -> None:
self.group_calls.append(kwargs)
class _FakeClient:
def __init__(self) -> None:
self.api = _FakeApi()
@pytest.mark.asyncio
async def test_on_group_message_routes_to_group_chat_id() -> None:
channel = QQChannel(QQConfig(app_id="app", secret="secret", allow_from=["user1"]), MessageBus())
data = SimpleNamespace(
id="msg1",
content="hello",
group_openid="group123",
author=SimpleNamespace(member_openid="user1"),
attachments=[],
)
await channel._on_message(data, is_group=True)
msg = await channel.bus.consume_inbound()
assert msg.sender_id == "user1"
assert msg.chat_id == "group123"
@pytest.mark.asyncio
async def test_send_group_message_uses_plain_text_group_api_with_msg_seq() -> None:
channel = QQChannel(QQConfig(app_id="app", secret="secret", allow_from=["*"]), MessageBus())
channel._client = _FakeClient()
channel._chat_type_cache["group123"] = "group"
await channel.send(
OutboundMessage(
channel="qq",
chat_id="group123",
content="hello",
metadata={"message_id": "msg1"},
)
)
assert len(channel._client.api.group_calls) == 1
call = channel._client.api.group_calls[0]
assert call == {
"group_openid": "group123",
"msg_type": 0,
"content": "hello",
"msg_id": "msg1",
"msg_seq": 2,
}
assert not channel._client.api.c2c_calls
@pytest.mark.asyncio
async def test_send_c2c_message_uses_plain_text_c2c_api_with_msg_seq() -> None:
channel = QQChannel(QQConfig(app_id="app", secret="secret", allow_from=["*"]), MessageBus())
channel._client = _FakeClient()
await channel.send(
OutboundMessage(
channel="qq",
chat_id="user123",
content="hello",
metadata={"message_id": "msg1"},
)
)
assert len(channel._client.api.c2c_calls) == 1
call = channel._client.api.c2c_calls[0]
assert call == {
"openid": "user123",
"msg_type": 0,
"content": "hello",
"msg_id": "msg1",
"msg_seq": 2,
}
assert not channel._client.api.group_calls
@pytest.mark.asyncio
async def test_send_group_message_uses_markdown_when_configured() -> None:
channel = QQChannel(
QQConfig(app_id="app", secret="secret", allow_from=["*"], msg_format="markdown"),
MessageBus(),
)
channel._client = _FakeClient()
channel._chat_type_cache["group123"] = "group"
await channel.send(
OutboundMessage(
channel="qq",
chat_id="group123",
content="**hello**",
metadata={"message_id": "msg1"},
)
)
assert len(channel._client.api.group_calls) == 1
call = channel._client.api.group_calls[0]
assert call == {
"group_openid": "group123",
"msg_type": 2,
"markdown": {"content": "**hello**"},
"msg_id": "msg1",
"msg_seq": 2,
}
@pytest.mark.asyncio
async def test_read_media_bytes_local_path() -> None:
channel = QQChannel(QQConfig(app_id="app", secret="secret"), MessageBus())
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f:
f.write(b"\x89PNG\r\n")
tmp_path = f.name
data, filename = await channel._read_media_bytes(tmp_path)
assert data == b"\x89PNG\r\n"
assert filename == Path(tmp_path).name
@pytest.mark.asyncio
async def test_read_media_bytes_file_uri() -> None:
channel = QQChannel(QQConfig(app_id="app", secret="secret"), MessageBus())
with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as f:
f.write(b"JFIF")
tmp_path = f.name
data, filename = await channel._read_media_bytes(f"file://{tmp_path}")
assert data == b"JFIF"
assert filename == Path(tmp_path).name
@pytest.mark.asyncio
async def test_read_media_bytes_missing_file() -> None:
channel = QQChannel(QQConfig(app_id="app", secret="secret"), MessageBus())
data, filename = await channel._read_media_bytes("/nonexistent/path/image.png")
assert data is None
assert filename is None
@@ -0,0 +1,153 @@
from __future__ import annotations
import pytest
# Check optional Slack dependencies before running tests
try:
import slack_sdk # noqa: F401
except ImportError:
pytest.skip("Slack dependencies not installed (slack-sdk)", allow_module_level=True)
from nanobot.bus.events import OutboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.channels.slack import SlackChannel
from nanobot.channels.slack import SlackConfig
class _FakeAsyncWebClient:
def __init__(self) -> None:
self.chat_post_calls: list[dict[str, object | None]] = []
self.file_upload_calls: list[dict[str, object | None]] = []
self.reactions_add_calls: list[dict[str, object | None]] = []
self.reactions_remove_calls: list[dict[str, object | None]] = []
async def chat_postMessage(
self,
*,
channel: str,
text: str,
thread_ts: str | None = None,
) -> None:
self.chat_post_calls.append(
{
"channel": channel,
"text": text,
"thread_ts": thread_ts,
}
)
async def files_upload_v2(
self,
*,
channel: str,
file: str,
thread_ts: str | None = None,
) -> None:
self.file_upload_calls.append(
{
"channel": channel,
"file": file,
"thread_ts": thread_ts,
}
)
async def reactions_add(
self,
*,
channel: str,
name: str,
timestamp: str,
) -> None:
self.reactions_add_calls.append(
{
"channel": channel,
"name": name,
"timestamp": timestamp,
}
)
async def reactions_remove(
self,
*,
channel: str,
name: str,
timestamp: str,
) -> None:
self.reactions_remove_calls.append(
{
"channel": channel,
"name": name,
"timestamp": timestamp,
}
)
@pytest.mark.asyncio
async def test_send_uses_thread_for_channel_messages() -> None:
channel = SlackChannel(SlackConfig(enabled=True), MessageBus())
fake_web = _FakeAsyncWebClient()
channel._web_client = fake_web
await channel.send(
OutboundMessage(
channel="slack",
chat_id="C123",
content="hello",
media=["/tmp/demo.txt"],
metadata={"slack": {"thread_ts": "1700000000.000100", "channel_type": "channel"}},
)
)
assert len(fake_web.chat_post_calls) == 1
assert fake_web.chat_post_calls[0]["text"] == "hello\n"
assert fake_web.chat_post_calls[0]["thread_ts"] == "1700000000.000100"
assert len(fake_web.file_upload_calls) == 1
assert fake_web.file_upload_calls[0]["thread_ts"] == "1700000000.000100"
@pytest.mark.asyncio
async def test_send_omits_thread_for_dm_messages() -> None:
channel = SlackChannel(SlackConfig(enabled=True), MessageBus())
fake_web = _FakeAsyncWebClient()
channel._web_client = fake_web
await channel.send(
OutboundMessage(
channel="slack",
chat_id="D123",
content="hello",
media=["/tmp/demo.txt"],
metadata={"slack": {"thread_ts": "1700000000.000100", "channel_type": "im"}},
)
)
assert len(fake_web.chat_post_calls) == 1
assert fake_web.chat_post_calls[0]["text"] == "hello\n"
assert fake_web.chat_post_calls[0]["thread_ts"] is None
assert len(fake_web.file_upload_calls) == 1
assert fake_web.file_upload_calls[0]["thread_ts"] is None
@pytest.mark.asyncio
async def test_send_updates_reaction_when_final_response_sent() -> None:
channel = SlackChannel(SlackConfig(enabled=True, react_emoji="eyes"), MessageBus())
fake_web = _FakeAsyncWebClient()
channel._web_client = fake_web
await channel.send(
OutboundMessage(
channel="slack",
chat_id="C123",
content="done",
metadata={
"slack": {"event": {"ts": "1700000000.000100"}, "channel_type": "channel"},
},
)
)
assert fake_web.reactions_remove_calls == [
{"channel": "C123", "name": "eyes", "timestamp": "1700000000.000100"}
]
assert fake_web.reactions_add_calls == [
{"channel": "C123", "name": "white_check_mark", "timestamp": "1700000000.000100"}
]
@@ -0,0 +1,966 @@
import asyncio
from pathlib import Path
from types import SimpleNamespace
from unittest.mock import AsyncMock
import pytest
# Check optional Telegram dependencies before running tests
try:
import telegram # noqa: F401
except ImportError:
pytest.skip("Telegram dependencies not installed (python-telegram-bot)", allow_module_level=True)
from nanobot.bus.events import OutboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.channels.telegram import TELEGRAM_REPLY_CONTEXT_MAX_LEN, TelegramChannel, _StreamBuf
from nanobot.channels.telegram import TelegramConfig
class _FakeHTTPXRequest:
instances: list["_FakeHTTPXRequest"] = []
def __init__(self, **kwargs) -> None:
self.kwargs = kwargs
self.__class__.instances.append(self)
@classmethod
def clear(cls) -> None:
cls.instances.clear()
class _FakeUpdater:
def __init__(self, on_start_polling) -> None:
self._on_start_polling = on_start_polling
async def start_polling(self, **kwargs) -> None:
self._on_start_polling()
class _FakeBot:
def __init__(self) -> None:
self.sent_messages: list[dict] = []
self.sent_media: list[dict] = []
self.get_me_calls = 0
async def get_me(self):
self.get_me_calls += 1
return SimpleNamespace(id=999, username="nanobot_test")
async def set_my_commands(self, commands) -> None:
self.commands = commands
async def send_message(self, **kwargs):
self.sent_messages.append(kwargs)
return SimpleNamespace(message_id=len(self.sent_messages))
async def send_photo(self, **kwargs) -> None:
self.sent_media.append({"kind": "photo", **kwargs})
async def send_voice(self, **kwargs) -> None:
self.sent_media.append({"kind": "voice", **kwargs})
async def send_audio(self, **kwargs) -> None:
self.sent_media.append({"kind": "audio", **kwargs})
async def send_document(self, **kwargs) -> None:
self.sent_media.append({"kind": "document", **kwargs})
async def send_chat_action(self, **kwargs) -> None:
pass
async def get_file(self, file_id: str):
"""Return a fake file that 'downloads' to a path (for reply-to-media tests)."""
async def _fake_download(path) -> None:
pass
return SimpleNamespace(download_to_drive=_fake_download)
class _FakeApp:
def __init__(self, on_start_polling) -> None:
self.bot = _FakeBot()
self.updater = _FakeUpdater(on_start_polling)
self.handlers = []
self.error_handlers = []
def add_error_handler(self, handler) -> None:
self.error_handlers.append(handler)
def add_handler(self, handler) -> None:
self.handlers.append(handler)
async def initialize(self) -> None:
pass
async def start(self) -> None:
pass
class _FakeBuilder:
def __init__(self, app: _FakeApp) -> None:
self.app = app
self.token_value = None
self.request_value = None
self.get_updates_request_value = None
def token(self, token: str):
self.token_value = token
return self
def request(self, request):
self.request_value = request
return self
def get_updates_request(self, request):
self.get_updates_request_value = request
return self
def proxy(self, _proxy):
raise AssertionError("builder.proxy should not be called when request is set")
def get_updates_proxy(self, _proxy):
raise AssertionError("builder.get_updates_proxy should not be called when request is set")
def build(self):
return self.app
def _make_telegram_update(
*,
chat_type: str = "group",
text: str | None = None,
caption: str | None = None,
entities=None,
caption_entities=None,
reply_to_message=None,
):
user = SimpleNamespace(id=12345, username="alice", first_name="Alice")
message = SimpleNamespace(
chat=SimpleNamespace(type=chat_type, is_forum=False),
chat_id=-100123,
text=text,
caption=caption,
entities=entities or [],
caption_entities=caption_entities or [],
reply_to_message=reply_to_message,
photo=None,
voice=None,
audio=None,
document=None,
media_group_id=None,
message_thread_id=None,
message_id=1,
)
return SimpleNamespace(message=message, effective_user=user)
@pytest.mark.asyncio
async def test_start_creates_separate_pools_with_proxy(monkeypatch) -> None:
_FakeHTTPXRequest.clear()
config = TelegramConfig(
enabled=True,
token="123:abc",
allow_from=["*"],
proxy="http://127.0.0.1:7890",
)
bus = MessageBus()
channel = TelegramChannel(config, bus)
app = _FakeApp(lambda: setattr(channel, "_running", False))
builder = _FakeBuilder(app)
monkeypatch.setattr("nanobot.channels.telegram.HTTPXRequest", _FakeHTTPXRequest)
monkeypatch.setattr(
"nanobot.channels.telegram.Application",
SimpleNamespace(builder=lambda: builder),
)
await channel.start()
assert len(_FakeHTTPXRequest.instances) == 2
api_req, poll_req = _FakeHTTPXRequest.instances
assert api_req.kwargs["proxy"] == config.proxy
assert poll_req.kwargs["proxy"] == config.proxy
assert api_req.kwargs["connection_pool_size"] == 32
assert poll_req.kwargs["connection_pool_size"] == 4
assert builder.request_value is api_req
assert builder.get_updates_request_value is poll_req
assert any(cmd.command == "status" for cmd in app.bot.commands)
@pytest.mark.asyncio
async def test_start_respects_custom_pool_config(monkeypatch) -> None:
_FakeHTTPXRequest.clear()
config = TelegramConfig(
enabled=True,
token="123:abc",
allow_from=["*"],
connection_pool_size=32,
pool_timeout=10.0,
)
bus = MessageBus()
channel = TelegramChannel(config, bus)
app = _FakeApp(lambda: setattr(channel, "_running", False))
builder = _FakeBuilder(app)
monkeypatch.setattr("nanobot.channels.telegram.HTTPXRequest", _FakeHTTPXRequest)
monkeypatch.setattr(
"nanobot.channels.telegram.Application",
SimpleNamespace(builder=lambda: builder),
)
await channel.start()
api_req = _FakeHTTPXRequest.instances[0]
poll_req = _FakeHTTPXRequest.instances[1]
assert api_req.kwargs["connection_pool_size"] == 32
assert api_req.kwargs["pool_timeout"] == 10.0
assert poll_req.kwargs["pool_timeout"] == 10.0
@pytest.mark.asyncio
async def test_send_text_retries_on_timeout() -> None:
"""_send_text retries on TimedOut before succeeding."""
from telegram.error import TimedOut
channel = TelegramChannel(
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"]),
MessageBus(),
)
channel._app = _FakeApp(lambda: None)
call_count = 0
original_send = channel._app.bot.send_message
async def flaky_send(**kwargs):
nonlocal call_count
call_count += 1
if call_count <= 2:
raise TimedOut()
return await original_send(**kwargs)
channel._app.bot.send_message = flaky_send
import nanobot.channels.telegram as tg_mod
orig_delay = tg_mod._SEND_RETRY_BASE_DELAY
tg_mod._SEND_RETRY_BASE_DELAY = 0.01
try:
await channel._send_text(123, "hello", None, {})
finally:
tg_mod._SEND_RETRY_BASE_DELAY = orig_delay
assert call_count == 3
assert len(channel._app.bot.sent_messages) == 1
@pytest.mark.asyncio
async def test_send_text_gives_up_after_max_retries() -> None:
"""_send_text raises TimedOut after exhausting all retries."""
from telegram.error import TimedOut
channel = TelegramChannel(
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"]),
MessageBus(),
)
channel._app = _FakeApp(lambda: None)
async def always_timeout(**kwargs):
raise TimedOut()
channel._app.bot.send_message = always_timeout
import nanobot.channels.telegram as tg_mod
orig_delay = tg_mod._SEND_RETRY_BASE_DELAY
tg_mod._SEND_RETRY_BASE_DELAY = 0.01
try:
with pytest.raises(TimedOut):
await channel._send_text(123, "hello", None, {})
finally:
tg_mod._SEND_RETRY_BASE_DELAY = orig_delay
assert channel._app.bot.sent_messages == []
@pytest.mark.asyncio
async def test_on_error_logs_network_issues_as_warning(monkeypatch) -> None:
from telegram.error import NetworkError
channel = TelegramChannel(
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"]),
MessageBus(),
)
recorded: list[tuple[str, str]] = []
monkeypatch.setattr(
"nanobot.channels.telegram.logger.warning",
lambda message, error: recorded.append(("warning", message.format(error))),
)
monkeypatch.setattr(
"nanobot.channels.telegram.logger.error",
lambda message, error: recorded.append(("error", message.format(error))),
)
await channel._on_error(object(), SimpleNamespace(error=NetworkError("proxy disconnected")))
assert recorded == [("warning", "Telegram network issue: proxy disconnected")]
@pytest.mark.asyncio
async def test_on_error_keeps_non_network_exceptions_as_error(monkeypatch) -> None:
channel = TelegramChannel(
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"]),
MessageBus(),
)
recorded: list[tuple[str, str]] = []
monkeypatch.setattr(
"nanobot.channels.telegram.logger.warning",
lambda message, error: recorded.append(("warning", message.format(error))),
)
monkeypatch.setattr(
"nanobot.channels.telegram.logger.error",
lambda message, error: recorded.append(("error", message.format(error))),
)
await channel._on_error(object(), SimpleNamespace(error=RuntimeError("boom")))
assert recorded == [("error", "Telegram error: boom")]
@pytest.mark.asyncio
async def test_send_delta_stream_end_raises_and_keeps_buffer_on_failure() -> None:
channel = TelegramChannel(
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"]),
MessageBus(),
)
channel._app = _FakeApp(lambda: None)
channel._app.bot.edit_message_text = AsyncMock(side_effect=RuntimeError("boom"))
channel._stream_bufs["123"] = _StreamBuf(text="hello", message_id=7, last_edit=0.0)
with pytest.raises(RuntimeError, match="boom"):
await channel.send_delta("123", "", {"_stream_end": True})
assert "123" in channel._stream_bufs
@pytest.mark.asyncio
async def test_send_delta_stream_end_treats_not_modified_as_success() -> None:
from telegram.error import BadRequest
channel = TelegramChannel(
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"]),
MessageBus(),
)
channel._app = _FakeApp(lambda: None)
channel._app.bot.edit_message_text = AsyncMock(side_effect=BadRequest("Message is not modified"))
channel._stream_bufs["123"] = _StreamBuf(text="hello", message_id=7, last_edit=0.0, stream_id="s:0")
await channel.send_delta("123", "", {"_stream_end": True, "_stream_id": "s:0"})
assert "123" not in channel._stream_bufs
@pytest.mark.asyncio
async def test_send_delta_new_stream_id_replaces_stale_buffer() -> None:
channel = TelegramChannel(
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"]),
MessageBus(),
)
channel._app = _FakeApp(lambda: None)
channel._stream_bufs["123"] = _StreamBuf(
text="hello",
message_id=7,
last_edit=0.0,
stream_id="old:0",
)
await channel.send_delta("123", "world", {"_stream_delta": True, "_stream_id": "new:0"})
buf = channel._stream_bufs["123"]
assert buf.text == "world"
assert buf.stream_id == "new:0"
assert buf.message_id == 1
@pytest.mark.asyncio
async def test_send_delta_incremental_edit_treats_not_modified_as_success() -> None:
from telegram.error import BadRequest
channel = TelegramChannel(
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"]),
MessageBus(),
)
channel._app = _FakeApp(lambda: None)
channel._stream_bufs["123"] = _StreamBuf(text="hello", message_id=7, last_edit=0.0, stream_id="s:0")
channel._app.bot.edit_message_text = AsyncMock(side_effect=BadRequest("Message is not modified"))
await channel.send_delta("123", "", {"_stream_delta": True, "_stream_id": "s:0"})
assert channel._stream_bufs["123"].last_edit > 0.0
def test_derive_topic_session_key_uses_thread_id() -> None:
message = SimpleNamespace(
chat=SimpleNamespace(type="supergroup"),
chat_id=-100123,
message_thread_id=42,
)
assert TelegramChannel._derive_topic_session_key(message) == "telegram:-100123:topic:42"
def test_get_extension_falls_back_to_original_filename() -> None:
channel = TelegramChannel(TelegramConfig(), MessageBus())
assert channel._get_extension("file", None, "report.pdf") == ".pdf"
assert channel._get_extension("file", None, "archive.tar.gz") == ".tar.gz"
def test_telegram_group_policy_defaults_to_mention() -> None:
assert TelegramConfig().group_policy == "mention"
def test_is_allowed_accepts_legacy_telegram_id_username_formats() -> None:
channel = TelegramChannel(TelegramConfig(allow_from=["12345", "alice", "67890|bob"]), MessageBus())
assert channel.is_allowed("12345|carol") is True
assert channel.is_allowed("99999|alice") is True
assert channel.is_allowed("67890|bob") is True
def test_is_allowed_rejects_invalid_legacy_telegram_sender_shapes() -> None:
channel = TelegramChannel(TelegramConfig(allow_from=["alice"]), MessageBus())
assert channel.is_allowed("attacker|alice|extra") is False
assert channel.is_allowed("not-a-number|alice") is False
@pytest.mark.asyncio
async def test_send_progress_keeps_message_in_topic() -> None:
config = TelegramConfig(enabled=True, token="123:abc", allow_from=["*"])
channel = TelegramChannel(config, MessageBus())
channel._app = _FakeApp(lambda: None)
await channel.send(
OutboundMessage(
channel="telegram",
chat_id="123",
content="hello",
metadata={"_progress": True, "message_thread_id": 42},
)
)
assert channel._app.bot.sent_messages[0]["message_thread_id"] == 42
@pytest.mark.asyncio
async def test_send_reply_infers_topic_from_message_id_cache() -> None:
config = TelegramConfig(enabled=True, token="123:abc", allow_from=["*"], reply_to_message=True)
channel = TelegramChannel(config, MessageBus())
channel._app = _FakeApp(lambda: None)
channel._message_threads[("123", 10)] = 42
await channel.send(
OutboundMessage(
channel="telegram",
chat_id="123",
content="hello",
metadata={"message_id": 10},
)
)
assert channel._app.bot.sent_messages[0]["message_thread_id"] == 42
assert channel._app.bot.sent_messages[0]["reply_parameters"].message_id == 10
@pytest.mark.asyncio
async def test_send_remote_media_url_after_security_validation(monkeypatch) -> None:
channel = TelegramChannel(
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"]),
MessageBus(),
)
channel._app = _FakeApp(lambda: None)
monkeypatch.setattr("nanobot.channels.telegram.validate_url_target", lambda url: (True, ""))
await channel.send(
OutboundMessage(
channel="telegram",
chat_id="123",
content="",
media=["https://example.com/cat.jpg"],
)
)
assert channel._app.bot.sent_media == [
{
"kind": "photo",
"chat_id": 123,
"photo": "https://example.com/cat.jpg",
"reply_parameters": None,
}
]
@pytest.mark.asyncio
async def test_send_blocks_unsafe_remote_media_url(monkeypatch) -> None:
channel = TelegramChannel(
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"]),
MessageBus(),
)
channel._app = _FakeApp(lambda: None)
monkeypatch.setattr(
"nanobot.channels.telegram.validate_url_target",
lambda url: (False, "Blocked: example.com resolves to private/internal address 127.0.0.1"),
)
await channel.send(
OutboundMessage(
channel="telegram",
chat_id="123",
content="",
media=["http://example.com/internal.jpg"],
)
)
assert channel._app.bot.sent_media == []
assert channel._app.bot.sent_messages == [
{
"chat_id": 123,
"text": "[Failed to send: internal.jpg]",
"reply_parameters": None,
}
]
@pytest.mark.asyncio
async def test_group_policy_mention_ignores_unmentioned_group_message() -> None:
channel = TelegramChannel(
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"], group_policy="mention"),
MessageBus(),
)
channel._app = _FakeApp(lambda: None)
handled = []
async def capture_handle(**kwargs) -> None:
handled.append(kwargs)
channel._handle_message = capture_handle
channel._start_typing = lambda _chat_id: None
await channel._on_message(_make_telegram_update(text="hello everyone"), None)
assert handled == []
assert channel._app.bot.get_me_calls == 1
@pytest.mark.asyncio
async def test_group_policy_mention_accepts_text_mention_and_caches_bot_identity() -> None:
channel = TelegramChannel(
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"], group_policy="mention"),
MessageBus(),
)
channel._app = _FakeApp(lambda: None)
handled = []
async def capture_handle(**kwargs) -> None:
handled.append(kwargs)
channel._handle_message = capture_handle
channel._start_typing = lambda _chat_id: None
mention = SimpleNamespace(type="mention", offset=0, length=13)
await channel._on_message(_make_telegram_update(text="@nanobot_test hi", entities=[mention]), None)
await channel._on_message(_make_telegram_update(text="@nanobot_test again", entities=[mention]), None)
assert len(handled) == 2
assert channel._app.bot.get_me_calls == 1
@pytest.mark.asyncio
async def test_group_policy_mention_accepts_caption_mention() -> None:
channel = TelegramChannel(
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"], group_policy="mention"),
MessageBus(),
)
channel._app = _FakeApp(lambda: None)
handled = []
async def capture_handle(**kwargs) -> None:
handled.append(kwargs)
channel._handle_message = capture_handle
channel._start_typing = lambda _chat_id: None
mention = SimpleNamespace(type="mention", offset=0, length=13)
await channel._on_message(
_make_telegram_update(caption="@nanobot_test photo", caption_entities=[mention]),
None,
)
assert len(handled) == 1
assert handled[0]["content"] == "@nanobot_test photo"
@pytest.mark.asyncio
async def test_group_policy_mention_accepts_reply_to_bot() -> None:
channel = TelegramChannel(
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"], group_policy="mention"),
MessageBus(),
)
channel._app = _FakeApp(lambda: None)
handled = []
async def capture_handle(**kwargs) -> None:
handled.append(kwargs)
channel._handle_message = capture_handle
channel._start_typing = lambda _chat_id: None
reply = SimpleNamespace(from_user=SimpleNamespace(id=999))
await channel._on_message(_make_telegram_update(text="reply", reply_to_message=reply), None)
assert len(handled) == 1
@pytest.mark.asyncio
async def test_group_policy_open_accepts_plain_group_message() -> None:
channel = TelegramChannel(
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"], group_policy="open"),
MessageBus(),
)
channel._app = _FakeApp(lambda: None)
handled = []
async def capture_handle(**kwargs) -> None:
handled.append(kwargs)
channel._handle_message = capture_handle
channel._start_typing = lambda _chat_id: None
await channel._on_message(_make_telegram_update(text="hello group"), None)
assert len(handled) == 1
assert channel._app.bot.get_me_calls == 0
def test_extract_reply_context_no_reply() -> None:
"""When there is no reply_to_message, _extract_reply_context returns None."""
message = SimpleNamespace(reply_to_message=None)
assert TelegramChannel._extract_reply_context(message) is None
def test_extract_reply_context_with_text() -> None:
"""When reply has text, return prefixed string."""
reply = SimpleNamespace(text="Hello world", caption=None)
message = SimpleNamespace(reply_to_message=reply)
assert TelegramChannel._extract_reply_context(message) == "[Reply to: Hello world]"
def test_extract_reply_context_with_caption_only() -> None:
"""When reply has only caption (no text), caption is used."""
reply = SimpleNamespace(text=None, caption="Photo caption")
message = SimpleNamespace(reply_to_message=reply)
assert TelegramChannel._extract_reply_context(message) == "[Reply to: Photo caption]"
def test_extract_reply_context_truncation() -> None:
"""Reply text is truncated at TELEGRAM_REPLY_CONTEXT_MAX_LEN."""
long_text = "x" * (TELEGRAM_REPLY_CONTEXT_MAX_LEN + 100)
reply = SimpleNamespace(text=long_text, caption=None)
message = SimpleNamespace(reply_to_message=reply)
result = TelegramChannel._extract_reply_context(message)
assert result is not None
assert result.startswith("[Reply to: ")
assert result.endswith("...]")
assert len(result) == len("[Reply to: ]") + TELEGRAM_REPLY_CONTEXT_MAX_LEN + len("...")
def test_extract_reply_context_no_text_returns_none() -> None:
"""When reply has no text/caption, _extract_reply_context returns None (media handled separately)."""
reply = SimpleNamespace(text=None, caption=None)
message = SimpleNamespace(reply_to_message=reply)
assert TelegramChannel._extract_reply_context(message) is None
@pytest.mark.asyncio
async def test_on_message_includes_reply_context() -> None:
"""When user replies to a message, content passed to bus starts with reply context."""
channel = TelegramChannel(
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"], group_policy="open"),
MessageBus(),
)
channel._app = _FakeApp(lambda: None)
handled = []
async def capture_handle(**kwargs) -> None:
handled.append(kwargs)
channel._handle_message = capture_handle
channel._start_typing = lambda _chat_id: None
reply = SimpleNamespace(text="Hello", message_id=2, from_user=SimpleNamespace(id=1))
update = _make_telegram_update(text="translate this", reply_to_message=reply)
await channel._on_message(update, None)
assert len(handled) == 1
assert handled[0]["content"].startswith("[Reply to: Hello]")
assert "translate this" in handled[0]["content"]
@pytest.mark.asyncio
async def test_download_message_media_returns_path_when_download_succeeds(
monkeypatch, tmp_path
) -> None:
"""_download_message_media returns (paths, content_parts) when bot.get_file and download succeed."""
media_dir = tmp_path / "media" / "telegram"
media_dir.mkdir(parents=True)
monkeypatch.setattr(
"nanobot.channels.telegram.get_media_dir",
lambda channel=None: media_dir if channel else tmp_path / "media",
)
channel = TelegramChannel(
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"]),
MessageBus(),
)
channel._app = _FakeApp(lambda: None)
channel._app.bot.get_file = AsyncMock(
return_value=SimpleNamespace(download_to_drive=AsyncMock(return_value=None))
)
msg = SimpleNamespace(
photo=[SimpleNamespace(file_id="fid123", mime_type="image/jpeg")],
voice=None,
audio=None,
document=None,
video=None,
video_note=None,
animation=None,
)
paths, parts = await channel._download_message_media(msg)
assert len(paths) == 1
assert len(parts) == 1
assert "fid123" in paths[0]
assert "[image:" in parts[0]
@pytest.mark.asyncio
async def test_download_message_media_uses_file_unique_id_when_available(
monkeypatch, tmp_path
) -> None:
media_dir = tmp_path / "media" / "telegram"
media_dir.mkdir(parents=True)
monkeypatch.setattr(
"nanobot.channels.telegram.get_media_dir",
lambda channel=None: media_dir if channel else tmp_path / "media",
)
downloaded: dict[str, str] = {}
async def _download_to_drive(path: str) -> None:
downloaded["path"] = path
channel = TelegramChannel(
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"]),
MessageBus(),
)
app = _FakeApp(lambda: None)
app.bot.get_file = AsyncMock(
return_value=SimpleNamespace(download_to_drive=_download_to_drive)
)
channel._app = app
msg = SimpleNamespace(
photo=[
SimpleNamespace(
file_id="file-id-that-should-not-be-used",
file_unique_id="stable-unique-id",
mime_type="image/jpeg",
file_name=None,
)
],
voice=None,
audio=None,
document=None,
video=None,
video_note=None,
animation=None,
)
paths, parts = await channel._download_message_media(msg)
assert downloaded["path"].endswith("stable-unique-id.jpg")
assert paths == [str(media_dir / "stable-unique-id.jpg")]
assert parts == [f"[image: {media_dir / 'stable-unique-id.jpg'}]"]
@pytest.mark.asyncio
async def test_on_message_attaches_reply_to_media_when_available(monkeypatch, tmp_path) -> None:
"""When user replies to a message with media, that media is downloaded and attached to the turn."""
media_dir = tmp_path / "media" / "telegram"
media_dir.mkdir(parents=True)
monkeypatch.setattr(
"nanobot.channels.telegram.get_media_dir",
lambda channel=None: media_dir if channel else tmp_path / "media",
)
channel = TelegramChannel(
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"], group_policy="open"),
MessageBus(),
)
app = _FakeApp(lambda: None)
app.bot.get_file = AsyncMock(
return_value=SimpleNamespace(download_to_drive=AsyncMock(return_value=None))
)
channel._app = app
handled = []
async def capture_handle(**kwargs) -> None:
handled.append(kwargs)
channel._handle_message = capture_handle
channel._start_typing = lambda _chat_id: None
reply_with_photo = SimpleNamespace(
text=None,
caption=None,
photo=[SimpleNamespace(file_id="reply_photo_fid", mime_type="image/jpeg")],
document=None,
voice=None,
audio=None,
video=None,
video_note=None,
animation=None,
)
update = _make_telegram_update(
text="what is the image?",
reply_to_message=reply_with_photo,
)
await channel._on_message(update, None)
assert len(handled) == 1
assert handled[0]["content"].startswith("[Reply to: [image:")
assert "what is the image?" in handled[0]["content"]
assert len(handled[0]["media"]) == 1
assert "reply_photo_fid" in handled[0]["media"][0]
@pytest.mark.asyncio
async def test_on_message_reply_to_media_fallback_when_download_fails() -> None:
"""When reply has media but download fails, no media attached and no reply tag."""
channel = TelegramChannel(
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"], group_policy="open"),
MessageBus(),
)
channel._app = _FakeApp(lambda: None)
channel._app.bot.get_file = None
handled = []
async def capture_handle(**kwargs) -> None:
handled.append(kwargs)
channel._handle_message = capture_handle
channel._start_typing = lambda _chat_id: None
reply_with_photo = SimpleNamespace(
text=None,
caption=None,
photo=[SimpleNamespace(file_id="x", mime_type="image/jpeg")],
document=None,
voice=None,
audio=None,
video=None,
video_note=None,
animation=None,
)
update = _make_telegram_update(text="what is this?", reply_to_message=reply_with_photo)
await channel._on_message(update, None)
assert len(handled) == 1
assert "what is this?" in handled[0]["content"]
assert handled[0]["media"] == []
@pytest.mark.asyncio
async def test_on_message_reply_to_caption_and_media(monkeypatch, tmp_path) -> None:
"""When replying to a message with caption + photo, both text context and media are included."""
media_dir = tmp_path / "media" / "telegram"
media_dir.mkdir(parents=True)
monkeypatch.setattr(
"nanobot.channels.telegram.get_media_dir",
lambda channel=None: media_dir if channel else tmp_path / "media",
)
channel = TelegramChannel(
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"], group_policy="open"),
MessageBus(),
)
app = _FakeApp(lambda: None)
app.bot.get_file = AsyncMock(
return_value=SimpleNamespace(download_to_drive=AsyncMock(return_value=None))
)
channel._app = app
handled = []
async def capture_handle(**kwargs) -> None:
handled.append(kwargs)
channel._handle_message = capture_handle
channel._start_typing = lambda _chat_id: None
reply_with_caption_and_photo = SimpleNamespace(
text=None,
caption="A cute cat",
photo=[SimpleNamespace(file_id="cat_fid", mime_type="image/jpeg")],
document=None,
voice=None,
audio=None,
video=None,
video_note=None,
animation=None,
)
update = _make_telegram_update(
text="what breed is this?",
reply_to_message=reply_with_caption_and_photo,
)
await channel._on_message(update, None)
assert len(handled) == 1
assert "[Reply to: A cute cat]" in handled[0]["content"]
assert "what breed is this?" in handled[0]["content"]
assert len(handled[0]["media"]) == 1
assert "cat_fid" in handled[0]["media"][0]
@pytest.mark.asyncio
async def test_forward_command_does_not_inject_reply_context() -> None:
"""Slash commands forwarded via _forward_command must not include reply context."""
channel = TelegramChannel(
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"], group_policy="open"),
MessageBus(),
)
channel._app = _FakeApp(lambda: None)
handled = []
async def capture_handle(**kwargs) -> None:
handled.append(kwargs)
channel._handle_message = capture_handle
reply = SimpleNamespace(text="some old message", message_id=2, from_user=SimpleNamespace(id=1))
update = _make_telegram_update(text="/new", reply_to_message=reply)
await channel._forward_command(update, None)
assert len(handled) == 1
assert handled[0]["content"] == "/new"
@pytest.mark.asyncio
async def test_on_help_includes_restart_command() -> None:
channel = TelegramChannel(
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"], group_policy="open"),
MessageBus(),
)
update = _make_telegram_update(text="/help", chat_type="private")
update.message.reply_text = AsyncMock()
await channel._on_help(update, None)
update.message.reply_text.assert_awaited_once()
help_text = update.message.reply_text.await_args.args[0]
assert "/restart" in help_text
assert "/status" in help_text
@@ -0,0 +1,280 @@
import asyncio
import json
import tempfile
from types import SimpleNamespace
from unittest.mock import AsyncMock
import pytest
from nanobot.bus.queue import MessageBus
from nanobot.channels.weixin import (
ITEM_IMAGE,
ITEM_TEXT,
MESSAGE_TYPE_BOT,
WEIXIN_CHANNEL_VERSION,
WeixinChannel,
WeixinConfig,
)
def _make_channel() -> tuple[WeixinChannel, MessageBus]:
bus = MessageBus()
channel = WeixinChannel(
WeixinConfig(
enabled=True,
allow_from=["*"],
state_dir=tempfile.mkdtemp(prefix="nanobot-weixin-test-"),
),
bus,
)
return channel, bus
def test_make_headers_includes_route_tag_when_configured() -> None:
bus = MessageBus()
channel = WeixinChannel(
WeixinConfig(enabled=True, allow_from=["*"], route_tag=123),
bus,
)
channel._token = "token"
headers = channel._make_headers()
assert headers["Authorization"] == "Bearer token"
assert headers["SKRouteTag"] == "123"
def test_channel_version_matches_reference_plugin_version() -> None:
assert WEIXIN_CHANNEL_VERSION == "1.0.3"
def test_save_and_load_state_persists_context_tokens(tmp_path) -> None:
bus = MessageBus()
channel = WeixinChannel(
WeixinConfig(enabled=True, allow_from=["*"], state_dir=str(tmp_path)),
bus,
)
channel._token = "token"
channel._get_updates_buf = "cursor"
channel._context_tokens = {"wx-user": "ctx-1"}
channel._save_state()
saved = json.loads((tmp_path / "account.json").read_text())
assert saved["context_tokens"] == {"wx-user": "ctx-1"}
restored = WeixinChannel(
WeixinConfig(enabled=True, allow_from=["*"], state_dir=str(tmp_path)),
bus,
)
assert restored._load_state() is True
assert restored._context_tokens == {"wx-user": "ctx-1"}
@pytest.mark.asyncio
async def test_process_message_deduplicates_inbound_ids() -> None:
channel, bus = _make_channel()
msg = {
"message_type": 1,
"message_id": "m1",
"from_user_id": "wx-user",
"context_token": "ctx-1",
"item_list": [
{"type": ITEM_TEXT, "text_item": {"text": "hello"}},
],
}
await channel._process_message(msg)
first = await asyncio.wait_for(bus.consume_inbound(), timeout=1.0)
await channel._process_message(msg)
assert first.sender_id == "wx-user"
assert first.chat_id == "wx-user"
assert first.content == "hello"
assert bus.inbound_size == 0
@pytest.mark.asyncio
async def test_process_message_caches_context_token_and_send_uses_it() -> None:
channel, _bus = _make_channel()
channel._client = object()
channel._token = "token"
channel._send_text = AsyncMock()
await channel._process_message(
{
"message_type": 1,
"message_id": "m2",
"from_user_id": "wx-user",
"context_token": "ctx-2",
"item_list": [
{"type": ITEM_TEXT, "text_item": {"text": "ping"}},
],
}
)
await channel.send(
type("Msg", (), {"chat_id": "wx-user", "content": "pong", "media": [], "metadata": {}})()
)
channel._send_text.assert_awaited_once_with("wx-user", "pong", "ctx-2")
@pytest.mark.asyncio
async def test_process_message_persists_context_token_to_state_file(tmp_path) -> None:
bus = MessageBus()
channel = WeixinChannel(
WeixinConfig(enabled=True, allow_from=["*"], state_dir=str(tmp_path)),
bus,
)
await channel._process_message(
{
"message_type": 1,
"message_id": "m2b",
"from_user_id": "wx-user",
"context_token": "ctx-2b",
"item_list": [
{"type": ITEM_TEXT, "text_item": {"text": "ping"}},
],
}
)
saved = json.loads((tmp_path / "account.json").read_text())
assert saved["context_tokens"] == {"wx-user": "ctx-2b"}
@pytest.mark.asyncio
async def test_process_message_extracts_media_and_preserves_paths() -> None:
channel, bus = _make_channel()
channel._download_media_item = AsyncMock(return_value="/tmp/test.jpg")
await channel._process_message(
{
"message_type": 1,
"message_id": "m3",
"from_user_id": "wx-user",
"context_token": "ctx-3",
"item_list": [
{"type": ITEM_IMAGE, "image_item": {"media": {"encrypt_query_param": "x"}}},
],
}
)
inbound = await asyncio.wait_for(bus.consume_inbound(), timeout=1.0)
assert "[image]" in inbound.content
assert "/tmp/test.jpg" in inbound.content
assert inbound.media == ["/tmp/test.jpg"]
@pytest.mark.asyncio
async def test_send_without_context_token_does_not_send_text() -> None:
channel, _bus = _make_channel()
channel._client = object()
channel._token = "token"
channel._send_text = AsyncMock()
await channel.send(
type("Msg", (), {"chat_id": "unknown-user", "content": "pong", "media": [], "metadata": {}})()
)
channel._send_text.assert_not_awaited()
@pytest.mark.asyncio
async def test_send_does_not_send_when_session_is_paused() -> None:
channel, _bus = _make_channel()
channel._client = object()
channel._token = "token"
channel._context_tokens["wx-user"] = "ctx-2"
channel._pause_session(60)
channel._send_text = AsyncMock()
await channel.send(
type("Msg", (), {"chat_id": "wx-user", "content": "pong", "media": [], "metadata": {}})()
)
channel._send_text.assert_not_awaited()
@pytest.mark.asyncio
async def test_poll_once_pauses_session_on_expired_errcode() -> None:
channel, _bus = _make_channel()
channel._client = SimpleNamespace(timeout=None)
channel._token = "token"
channel._api_post = AsyncMock(return_value={"ret": 0, "errcode": -14, "errmsg": "expired"})
await channel._poll_once()
assert channel._session_pause_remaining_s() > 0
@pytest.mark.asyncio
async def test_qr_login_refreshes_expired_qr_and_then_succeeds() -> None:
channel, _bus = _make_channel()
channel._running = True
channel._save_state = lambda: None
channel._print_qr_code = lambda url: None
channel._api_get = AsyncMock(
side_effect=[
{"qrcode": "qr-1", "qrcode_img_content": "url-1"},
{"status": "expired"},
{"qrcode": "qr-2", "qrcode_img_content": "url-2"},
{
"status": "confirmed",
"bot_token": "token-2",
"ilink_bot_id": "bot-2",
"baseurl": "https://example.test",
"ilink_user_id": "wx-user",
},
]
)
ok = await channel._qr_login()
assert ok is True
assert channel._token == "token-2"
assert channel.config.base_url == "https://example.test"
@pytest.mark.asyncio
async def test_qr_login_returns_false_after_too_many_expired_qr_codes() -> None:
channel, _bus = _make_channel()
channel._running = True
channel._print_qr_code = lambda url: None
channel._api_get = AsyncMock(
side_effect=[
{"qrcode": "qr-1", "qrcode_img_content": "url-1"},
{"status": "expired"},
{"qrcode": "qr-2", "qrcode_img_content": "url-2"},
{"status": "expired"},
{"qrcode": "qr-3", "qrcode_img_content": "url-3"},
{"status": "expired"},
{"qrcode": "qr-4", "qrcode_img_content": "url-4"},
{"status": "expired"},
]
)
ok = await channel._qr_login()
assert ok is False
@pytest.mark.asyncio
async def test_process_message_skips_bot_messages() -> None:
channel, bus = _make_channel()
await channel._process_message(
{
"message_type": MESSAGE_TYPE_BOT,
"message_id": "m4",
"from_user_id": "wx-user",
"item_list": [
{"type": ITEM_TEXT, "text_item": {"text": "hello"}},
],
}
)
assert bus.inbound_size == 0
@@ -0,0 +1,157 @@
"""Tests for WhatsApp channel outbound media support."""
import json
from unittest.mock import AsyncMock, MagicMock
import pytest
from nanobot.bus.events import OutboundMessage
from nanobot.channels.whatsapp import WhatsAppChannel
def _make_channel() -> WhatsAppChannel:
bus = MagicMock()
ch = WhatsAppChannel({"enabled": True}, bus)
ch._ws = AsyncMock()
ch._connected = True
return ch
@pytest.mark.asyncio
async def test_send_text_only():
ch = _make_channel()
msg = OutboundMessage(channel="whatsapp", chat_id="123@s.whatsapp.net", content="hello")
await ch.send(msg)
ch._ws.send.assert_called_once()
payload = json.loads(ch._ws.send.call_args[0][0])
assert payload["type"] == "send"
assert payload["text"] == "hello"
@pytest.mark.asyncio
async def test_send_media_dispatches_send_media_command():
ch = _make_channel()
msg = OutboundMessage(
channel="whatsapp",
chat_id="123@s.whatsapp.net",
content="check this out",
media=["/tmp/photo.jpg"],
)
await ch.send(msg)
assert ch._ws.send.call_count == 2
text_payload = json.loads(ch._ws.send.call_args_list[0][0][0])
media_payload = json.loads(ch._ws.send.call_args_list[1][0][0])
assert text_payload["type"] == "send"
assert text_payload["text"] == "check this out"
assert media_payload["type"] == "send_media"
assert media_payload["filePath"] == "/tmp/photo.jpg"
assert media_payload["mimetype"] == "image/jpeg"
assert media_payload["fileName"] == "photo.jpg"
@pytest.mark.asyncio
async def test_send_media_only_no_text():
ch = _make_channel()
msg = OutboundMessage(
channel="whatsapp",
chat_id="123@s.whatsapp.net",
content="",
media=["/tmp/doc.pdf"],
)
await ch.send(msg)
ch._ws.send.assert_called_once()
payload = json.loads(ch._ws.send.call_args[0][0])
assert payload["type"] == "send_media"
assert payload["mimetype"] == "application/pdf"
@pytest.mark.asyncio
async def test_send_multiple_media():
ch = _make_channel()
msg = OutboundMessage(
channel="whatsapp",
chat_id="123@s.whatsapp.net",
content="",
media=["/tmp/a.png", "/tmp/b.mp4"],
)
await ch.send(msg)
assert ch._ws.send.call_count == 2
p1 = json.loads(ch._ws.send.call_args_list[0][0][0])
p2 = json.loads(ch._ws.send.call_args_list[1][0][0])
assert p1["mimetype"] == "image/png"
assert p2["mimetype"] == "video/mp4"
@pytest.mark.asyncio
async def test_send_when_disconnected_is_noop():
ch = _make_channel()
ch._connected = False
msg = OutboundMessage(
channel="whatsapp",
chat_id="123@s.whatsapp.net",
content="hello",
media=["/tmp/x.jpg"],
)
await ch.send(msg)
ch._ws.send.assert_not_called()
@pytest.mark.asyncio
async def test_group_policy_mention_skips_unmentioned_group_message():
ch = WhatsAppChannel({"enabled": True, "groupPolicy": "mention"}, MagicMock())
ch._handle_message = AsyncMock()
await ch._handle_bridge_message(
json.dumps(
{
"type": "message",
"id": "m1",
"sender": "12345@g.us",
"pn": "user@s.whatsapp.net",
"content": "hello group",
"timestamp": 1,
"isGroup": True,
"wasMentioned": False,
}
)
)
ch._handle_message.assert_not_called()
@pytest.mark.asyncio
async def test_group_policy_mention_accepts_mentioned_group_message():
ch = WhatsAppChannel({"enabled": True, "groupPolicy": "mention"}, MagicMock())
ch._handle_message = AsyncMock()
await ch._handle_bridge_message(
json.dumps(
{
"type": "message",
"id": "m1",
"sender": "12345@g.us",
"pn": "user@s.whatsapp.net",
"content": "hello @bot",
"timestamp": 1,
"isGroup": True,
"wasMentioned": True,
}
)
)
ch._handle_message.assert_awaited_once()
kwargs = ch._handle_message.await_args.kwargs
assert kwargs["chat_id"] == "12345@g.us"
assert kwargs["sender_id"] == "user"
+147
View File
@@ -0,0 +1,147 @@
import asyncio
from unittest.mock import AsyncMock, MagicMock, call, patch
import pytest
from prompt_toolkit.formatted_text import HTML
from nanobot.cli import commands
from nanobot.cli import stream as stream_mod
@pytest.fixture
def mock_prompt_session():
"""Mock the global prompt session."""
mock_session = MagicMock()
mock_session.prompt_async = AsyncMock()
with patch("nanobot.cli.commands._PROMPT_SESSION", mock_session), \
patch("nanobot.cli.commands.patch_stdout"):
yield mock_session
@pytest.mark.asyncio
async def test_read_interactive_input_async_returns_input(mock_prompt_session):
"""Test that _read_interactive_input_async returns the user input from prompt_session."""
mock_prompt_session.prompt_async.return_value = "hello world"
result = await commands._read_interactive_input_async()
assert result == "hello world"
mock_prompt_session.prompt_async.assert_called_once()
args, _ = mock_prompt_session.prompt_async.call_args
assert isinstance(args[0], HTML) # Verify HTML prompt is used
@pytest.mark.asyncio
async def test_read_interactive_input_async_handles_eof(mock_prompt_session):
"""Test that EOFError converts to KeyboardInterrupt."""
mock_prompt_session.prompt_async.side_effect = EOFError()
with pytest.raises(KeyboardInterrupt):
await commands._read_interactive_input_async()
def test_init_prompt_session_creates_session():
"""Test that _init_prompt_session initializes the global session."""
# Ensure global is None before test
commands._PROMPT_SESSION = None
with patch("nanobot.cli.commands.PromptSession") as MockSession, \
patch("nanobot.cli.commands.FileHistory") as MockHistory, \
patch("pathlib.Path.home") as mock_home:
mock_home.return_value = MagicMock()
commands._init_prompt_session()
assert commands._PROMPT_SESSION is not None
MockSession.assert_called_once()
_, kwargs = MockSession.call_args
assert kwargs["multiline"] is False
assert kwargs["enable_open_in_editor"] is False
def test_thinking_spinner_pause_stops_and_restarts():
"""Pause should stop the active spinner and restart it afterward."""
spinner = MagicMock()
mock_console = MagicMock()
mock_console.status.return_value = spinner
thinking = stream_mod.ThinkingSpinner(console=mock_console)
with thinking:
with thinking.pause():
pass
assert spinner.method_calls == [
call.start(),
call.stop(),
call.start(),
call.stop(),
]
def test_print_cli_progress_line_pauses_spinner_before_printing():
"""CLI progress output should pause spinner to avoid garbled lines."""
order: list[str] = []
spinner = MagicMock()
spinner.start.side_effect = lambda: order.append("start")
spinner.stop.side_effect = lambda: order.append("stop")
mock_console = MagicMock()
mock_console.status.return_value = spinner
with patch.object(commands.console, "print", side_effect=lambda *_args, **_kwargs: order.append("print")):
thinking = stream_mod.ThinkingSpinner(console=mock_console)
with thinking:
commands._print_cli_progress_line("tool running", thinking)
assert order == ["start", "stop", "print", "start", "stop"]
@pytest.mark.asyncio
async def test_print_interactive_progress_line_pauses_spinner_before_printing():
"""Interactive progress output should also pause spinner cleanly."""
order: list[str] = []
spinner = MagicMock()
spinner.start.side_effect = lambda: order.append("start")
spinner.stop.side_effect = lambda: order.append("stop")
mock_console = MagicMock()
mock_console.status.return_value = spinner
async def fake_print(_text: str) -> None:
order.append("print")
with patch("nanobot.cli.commands._print_interactive_line", side_effect=fake_print):
thinking = stream_mod.ThinkingSpinner(console=mock_console)
with thinking:
await commands._print_interactive_progress_line("tool running", thinking)
assert order == ["start", "stop", "print", "start", "stop"]
def test_response_renderable_uses_text_for_explicit_plain_rendering():
status = (
"🐈 nanobot v0.1.4.post5\n"
"🧠 Model: MiniMax-M2.7\n"
"📊 Tokens: 20639 in / 29 out"
)
renderable = commands._response_renderable(
status,
render_markdown=True,
metadata={"render_as": "text"},
)
assert renderable.__class__.__name__ == "Text"
def test_response_renderable_preserves_normal_markdown_rendering():
renderable = commands._response_renderable("**bold**", render_markdown=True)
assert renderable.__class__.__name__ == "Markdown"
def test_response_renderable_without_metadata_keeps_markdown_path():
help_text = "🐈 nanobot commands:\n/status — Show bot status\n/help — Show available commands"
renderable = commands._response_renderable(help_text, render_markdown=True)
assert renderable.__class__.__name__ == "Markdown"
+905
View File
@@ -0,0 +1,905 @@
import json
import re
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from typer.testing import CliRunner
from nanobot.bus.events import OutboundMessage
from nanobot.cli.commands import _make_provider, app
from nanobot.config.schema import Config
from nanobot.providers.openai_codex_provider import _strip_model_prefix
from nanobot.providers.registry import find_by_name
runner = CliRunner()
class _StopGatewayError(RuntimeError):
pass
import shutil
import pytest
@pytest.fixture
def mock_paths():
"""Mock config/workspace paths for test isolation."""
with patch("nanobot.config.loader.get_config_path") as mock_cp, \
patch("nanobot.config.loader.save_config") as mock_sc, \
patch("nanobot.config.loader.load_config") as mock_lc, \
patch("nanobot.cli.commands.get_workspace_path") as mock_ws:
base_dir = Path("./test_onboard_data")
if base_dir.exists():
shutil.rmtree(base_dir)
base_dir.mkdir()
config_file = base_dir / "config.json"
workspace_dir = base_dir / "workspace"
mock_cp.return_value = config_file
mock_ws.return_value = workspace_dir
mock_lc.side_effect = lambda _config_path=None: Config()
def _save_config(config: Config, config_path: Path | None = None):
target = config_path or config_file
target.parent.mkdir(parents=True, exist_ok=True)
target.write_text(json.dumps(config.model_dump(by_alias=True)), encoding="utf-8")
mock_sc.side_effect = _save_config
yield config_file, workspace_dir, mock_ws
if base_dir.exists():
shutil.rmtree(base_dir)
def test_onboard_fresh_install(mock_paths):
"""No existing config — should create from scratch."""
config_file, workspace_dir, mock_ws = mock_paths
result = runner.invoke(app, ["onboard"])
assert result.exit_code == 0
assert "Created config" in result.stdout
assert "Created workspace" in result.stdout
assert "nanobot is ready" in result.stdout
assert config_file.exists()
assert (workspace_dir / "AGENTS.md").exists()
assert (workspace_dir / "memory" / "MEMORY.md").exists()
expected_workspace = Config().workspace_path
assert mock_ws.call_args.args == (expected_workspace,)
def test_onboard_existing_config_refresh(mock_paths):
"""Config exists, user declines overwrite — should refresh (load-merge-save)."""
config_file, workspace_dir, _ = mock_paths
config_file.write_text('{"existing": true}')
result = runner.invoke(app, ["onboard"], input="n\n")
assert result.exit_code == 0
assert "Config already exists" in result.stdout
assert "existing values preserved" in result.stdout
assert workspace_dir.exists()
assert (workspace_dir / "AGENTS.md").exists()
def test_onboard_existing_config_overwrite(mock_paths):
"""Config exists, user confirms overwrite — should reset to defaults."""
config_file, workspace_dir, _ = mock_paths
config_file.write_text('{"existing": true}')
result = runner.invoke(app, ["onboard"], input="y\n")
assert result.exit_code == 0
assert "Config already exists" in result.stdout
assert "Config reset to defaults" in result.stdout
assert workspace_dir.exists()
def test_onboard_existing_workspace_safe_create(mock_paths):
"""Workspace exists — should not recreate, but still add missing templates."""
config_file, workspace_dir, _ = mock_paths
workspace_dir.mkdir(parents=True)
config_file.write_text("{}")
result = runner.invoke(app, ["onboard"], input="n\n")
assert result.exit_code == 0
assert "Created workspace" not in result.stdout
assert "Created AGENTS.md" in result.stdout
assert (workspace_dir / "AGENTS.md").exists()
def _strip_ansi(text):
"""Remove ANSI escape codes from text."""
ansi_escape = re.compile(r'\x1b\[[0-9;]*m')
return ansi_escape.sub('', text)
def test_onboard_help_shows_workspace_and_config_options():
result = runner.invoke(app, ["onboard", "--help"])
assert result.exit_code == 0
stripped_output = _strip_ansi(result.stdout)
assert "--workspace" in stripped_output
assert "-w" in stripped_output
assert "--config" in stripped_output
assert "-c" in stripped_output
assert "--wizard" in stripped_output
assert "--dir" not in stripped_output
def test_onboard_interactive_discard_does_not_save_or_create_workspace(mock_paths, monkeypatch):
config_file, workspace_dir, _ = mock_paths
from nanobot.cli.onboard import OnboardResult
monkeypatch.setattr(
"nanobot.cli.onboard.run_onboard",
lambda initial_config: OnboardResult(config=initial_config, should_save=False),
)
result = runner.invoke(app, ["onboard", "--wizard"])
assert result.exit_code == 0
assert "No changes were saved" in result.stdout
assert not config_file.exists()
assert not workspace_dir.exists()
def test_onboard_uses_explicit_config_and_workspace_paths(tmp_path, monkeypatch):
config_path = tmp_path / "instance" / "config.json"
workspace_path = tmp_path / "workspace"
monkeypatch.setattr("nanobot.channels.registry.discover_all", lambda: {})
result = runner.invoke(
app,
["onboard", "--config", str(config_path), "--workspace", str(workspace_path)],
)
assert result.exit_code == 0
saved = Config.model_validate(json.loads(config_path.read_text(encoding="utf-8")))
assert saved.workspace_path == workspace_path
assert (workspace_path / "AGENTS.md").exists()
stripped_output = _strip_ansi(result.stdout)
compact_output = stripped_output.replace("\n", "")
resolved_config = str(config_path.resolve())
assert resolved_config in compact_output
assert f"--config {resolved_config}" in compact_output
def test_onboard_wizard_preserves_explicit_config_in_next_steps(tmp_path, monkeypatch):
config_path = tmp_path / "instance" / "config.json"
workspace_path = tmp_path / "workspace"
from nanobot.cli.onboard import OnboardResult
monkeypatch.setattr(
"nanobot.cli.onboard.run_onboard",
lambda initial_config: OnboardResult(config=initial_config, should_save=True),
)
monkeypatch.setattr("nanobot.channels.registry.discover_all", lambda: {})
result = runner.invoke(
app,
["onboard", "--wizard", "--config", str(config_path), "--workspace", str(workspace_path)],
)
assert result.exit_code == 0
stripped_output = _strip_ansi(result.stdout)
compact_output = stripped_output.replace("\n", "")
resolved_config = str(config_path.resolve())
assert f'nanobot agent -m "Hello!" --config {resolved_config}' in compact_output
assert f"nanobot gateway --config {resolved_config}" in compact_output
def test_config_matches_github_copilot_codex_with_hyphen_prefix():
config = Config()
config.agents.defaults.model = "github-copilot/gpt-5.3-codex"
assert config.get_provider_name() == "github_copilot"
def test_config_matches_openai_codex_with_hyphen_prefix():
config = Config()
config.agents.defaults.model = "openai-codex/gpt-5.1-codex"
assert config.get_provider_name() == "openai_codex"
def test_config_dump_excludes_oauth_provider_blocks():
config = Config()
providers = config.model_dump(by_alias=True)["providers"]
assert "openaiCodex" not in providers
assert "githubCopilot" not in providers
def test_config_matches_explicit_ollama_prefix_without_api_key():
config = Config()
config.agents.defaults.model = "ollama/llama3.2"
assert config.get_provider_name() == "ollama"
assert config.get_api_base() == "http://localhost:11434/v1"
def test_config_explicit_ollama_provider_uses_default_localhost_api_base():
config = Config()
config.agents.defaults.provider = "ollama"
config.agents.defaults.model = "llama3.2"
assert config.get_provider_name() == "ollama"
assert config.get_api_base() == "http://localhost:11434/v1"
def test_config_accepts_camel_case_explicit_provider_name_for_coding_plan():
config = Config.model_validate(
{
"agents": {
"defaults": {
"provider": "volcengineCodingPlan",
"model": "doubao-1-5-pro",
}
},
"providers": {
"volcengineCodingPlan": {
"apiKey": "test-key",
}
},
}
)
assert config.get_provider_name() == "volcengine_coding_plan"
assert config.get_api_base() == "https://ark.cn-beijing.volces.com/api/coding/v3"
def test_find_by_name_accepts_camel_case_and_hyphen_aliases():
assert find_by_name("volcengineCodingPlan") is not None
assert find_by_name("volcengineCodingPlan").name == "volcengine_coding_plan"
assert find_by_name("github-copilot") is not None
assert find_by_name("github-copilot").name == "github_copilot"
def test_config_auto_detects_ollama_from_local_api_base():
config = Config.model_validate(
{
"agents": {"defaults": {"provider": "auto", "model": "llama3.2"}},
"providers": {"ollama": {"apiBase": "http://localhost:11434/v1"}},
}
)
assert config.get_provider_name() == "ollama"
assert config.get_api_base() == "http://localhost:11434/v1"
def test_config_prefers_ollama_over_vllm_when_both_local_providers_configured():
config = Config.model_validate(
{
"agents": {"defaults": {"provider": "auto", "model": "llama3.2"}},
"providers": {
"vllm": {"apiBase": "http://localhost:8000"},
"ollama": {"apiBase": "http://localhost:11434/v1"},
},
}
)
assert config.get_provider_name() == "ollama"
assert config.get_api_base() == "http://localhost:11434/v1"
def test_config_falls_back_to_vllm_when_ollama_not_configured():
config = Config.model_validate(
{
"agents": {"defaults": {"provider": "auto", "model": "llama3.2"}},
"providers": {
"vllm": {"apiBase": "http://localhost:8000"},
},
}
)
assert config.get_provider_name() == "vllm"
assert config.get_api_base() == "http://localhost:8000"
def test_openai_compat_provider_passes_model_through():
from nanobot.providers.openai_compat_provider import OpenAICompatProvider
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
provider = OpenAICompatProvider(default_model="github-copilot/gpt-5.3-codex")
assert provider.get_default_model() == "github-copilot/gpt-5.3-codex"
def test_openai_codex_strip_prefix_supports_hyphen_and_underscore():
assert _strip_model_prefix("openai-codex/gpt-5.1-codex") == "gpt-5.1-codex"
assert _strip_model_prefix("openai_codex/gpt-5.1-codex") == "gpt-5.1-codex"
def test_make_provider_passes_extra_headers_to_custom_provider():
config = Config.model_validate(
{
"agents": {"defaults": {"provider": "custom", "model": "gpt-4o-mini"}},
"providers": {
"custom": {
"apiKey": "test-key",
"apiBase": "https://example.com/v1",
"extraHeaders": {
"APP-Code": "demo-app",
"x-session-affinity": "sticky-session",
},
}
},
}
)
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI") as mock_async_openai:
_make_provider(config)
kwargs = mock_async_openai.call_args.kwargs
assert kwargs["api_key"] == "test-key"
assert kwargs["base_url"] == "https://example.com/v1"
assert kwargs["default_headers"]["APP-Code"] == "demo-app"
assert kwargs["default_headers"]["x-session-affinity"] == "sticky-session"
@pytest.fixture
def mock_agent_runtime(tmp_path):
"""Mock agent command dependencies for focused CLI tests."""
config = Config()
config.agents.defaults.workspace = str(tmp_path / "default-workspace")
with patch("nanobot.config.loader.load_config", return_value=config) as mock_load_config, \
patch("nanobot.cli.commands.sync_workspace_templates") as mock_sync_templates, \
patch("nanobot.cli.commands._make_provider", return_value=object()), \
patch("nanobot.cli.commands._print_agent_response") as mock_print_response, \
patch("nanobot.bus.queue.MessageBus"), \
patch("nanobot.cron.service.CronService"), \
patch("nanobot.agent.loop.AgentLoop") as mock_agent_loop_cls:
agent_loop = MagicMock()
agent_loop.channels_config = None
agent_loop.process_direct = AsyncMock(
return_value=OutboundMessage(channel="cli", chat_id="direct", content="mock-response"),
)
agent_loop.close_mcp = AsyncMock(return_value=None)
mock_agent_loop_cls.return_value = agent_loop
yield {
"config": config,
"load_config": mock_load_config,
"sync_templates": mock_sync_templates,
"agent_loop_cls": mock_agent_loop_cls,
"agent_loop": agent_loop,
"print_response": mock_print_response,
}
def test_agent_help_shows_workspace_and_config_options():
result = runner.invoke(app, ["agent", "--help"])
assert result.exit_code == 0
stripped_output = _strip_ansi(result.stdout)
assert "--workspace" in stripped_output
assert "-w" in stripped_output
assert "--config" in stripped_output
assert "-c" in stripped_output
def test_agent_uses_default_config_when_no_workspace_or_config_flags(mock_agent_runtime):
result = runner.invoke(app, ["agent", "-m", "hello"])
assert result.exit_code == 0
assert mock_agent_runtime["load_config"].call_args.args == (None,)
assert mock_agent_runtime["sync_templates"].call_args.args == (
mock_agent_runtime["config"].workspace_path,
)
assert mock_agent_runtime["agent_loop_cls"].call_args.kwargs["workspace"] == (
mock_agent_runtime["config"].workspace_path
)
mock_agent_runtime["agent_loop"].process_direct.assert_awaited_once()
mock_agent_runtime["print_response"].assert_called_once_with(
"mock-response", render_markdown=True, metadata={},
)
def test_agent_uses_explicit_config_path(mock_agent_runtime, tmp_path: Path):
config_path = tmp_path / "agent-config.json"
config_path.write_text("{}")
result = runner.invoke(app, ["agent", "-m", "hello", "-c", str(config_path)])
assert result.exit_code == 0
assert mock_agent_runtime["load_config"].call_args.args == (config_path.resolve(),)
def test_agent_config_sets_active_path(monkeypatch, tmp_path: Path) -> None:
config_file = tmp_path / "instance" / "config.json"
config_file.parent.mkdir(parents=True)
config_file.write_text("{}")
config = Config()
seen: dict[str, Path] = {}
monkeypatch.setattr(
"nanobot.config.loader.set_config_path",
lambda path: seen.__setitem__("config_path", path),
)
monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None)
monkeypatch.setattr("nanobot.cli.commands._make_provider", lambda _config: object())
monkeypatch.setattr("nanobot.bus.queue.MessageBus", lambda: object())
monkeypatch.setattr("nanobot.cron.service.CronService", lambda _store: object())
class _FakeAgentLoop:
def __init__(self, *args, **kwargs) -> None:
pass
async def process_direct(self, *_args, **_kwargs):
return OutboundMessage(channel="cli", chat_id="direct", content="ok")
async def close_mcp(self) -> None:
return None
monkeypatch.setattr("nanobot.agent.loop.AgentLoop", _FakeAgentLoop)
monkeypatch.setattr("nanobot.cli.commands._print_agent_response", lambda *_args, **_kwargs: None)
result = runner.invoke(app, ["agent", "-m", "hello", "-c", str(config_file)])
assert result.exit_code == 0
assert seen["config_path"] == config_file.resolve()
def test_agent_uses_workspace_directory_for_cron_store(monkeypatch, tmp_path: Path) -> None:
config_file = tmp_path / "instance" / "config.json"
config_file.parent.mkdir(parents=True)
config_file.write_text("{}")
config = Config()
config.agents.defaults.workspace = str(tmp_path / "agent-workspace")
seen: dict[str, Path] = {}
monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None)
monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None)
monkeypatch.setattr("nanobot.cli.commands._make_provider", lambda _config: object())
monkeypatch.setattr("nanobot.bus.queue.MessageBus", lambda: object())
class _FakeCron:
def __init__(self, store_path: Path) -> None:
seen["cron_store"] = store_path
class _FakeAgentLoop:
def __init__(self, *args, **kwargs) -> None:
pass
async def process_direct(self, *_args, **_kwargs):
return OutboundMessage(channel="cli", chat_id="direct", content="ok")
async def close_mcp(self) -> None:
return None
monkeypatch.setattr("nanobot.cron.service.CronService", _FakeCron)
monkeypatch.setattr("nanobot.agent.loop.AgentLoop", _FakeAgentLoop)
monkeypatch.setattr("nanobot.cli.commands._print_agent_response", lambda *_args, **_kwargs: None)
result = runner.invoke(app, ["agent", "-m", "hello", "-c", str(config_file)])
assert result.exit_code == 0
assert seen["cron_store"] == config.workspace_path / "cron" / "jobs.json"
def test_agent_workspace_override_does_not_migrate_legacy_cron(
monkeypatch, tmp_path: Path
) -> None:
config_file = tmp_path / "instance" / "config.json"
config_file.parent.mkdir(parents=True)
config_file.write_text("{}")
legacy_dir = tmp_path / "global" / "cron"
legacy_dir.mkdir(parents=True)
legacy_file = legacy_dir / "jobs.json"
legacy_file.write_text('{"jobs": []}')
override = tmp_path / "override-workspace"
config = Config()
seen: dict[str, Path] = {}
monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None)
monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None)
monkeypatch.setattr("nanobot.cli.commands._make_provider", lambda _config: object())
monkeypatch.setattr("nanobot.bus.queue.MessageBus", lambda: object())
monkeypatch.setattr("nanobot.config.paths.get_cron_dir", lambda: legacy_dir)
class _FakeCron:
def __init__(self, store_path: Path) -> None:
seen["cron_store"] = store_path
class _FakeAgentLoop:
def __init__(self, *args, **kwargs) -> None:
pass
async def process_direct(self, *_args, **_kwargs):
return OutboundMessage(channel="cli", chat_id="direct", content="ok")
async def close_mcp(self) -> None:
return None
monkeypatch.setattr("nanobot.cron.service.CronService", _FakeCron)
monkeypatch.setattr("nanobot.agent.loop.AgentLoop", _FakeAgentLoop)
monkeypatch.setattr("nanobot.cli.commands._print_agent_response", lambda *_args, **_kwargs: None)
result = runner.invoke(
app,
["agent", "-m", "hello", "-c", str(config_file), "-w", str(override)],
)
assert result.exit_code == 0
assert seen["cron_store"] == override / "cron" / "jobs.json"
assert legacy_file.exists()
assert not (override / "cron" / "jobs.json").exists()
def test_agent_custom_config_workspace_does_not_migrate_legacy_cron(
monkeypatch, tmp_path: Path
) -> None:
config_file = tmp_path / "instance" / "config.json"
config_file.parent.mkdir(parents=True)
config_file.write_text("{}")
legacy_dir = tmp_path / "global" / "cron"
legacy_dir.mkdir(parents=True)
legacy_file = legacy_dir / "jobs.json"
legacy_file.write_text('{"jobs": []}')
custom_workspace = tmp_path / "custom-workspace"
config = Config()
config.agents.defaults.workspace = str(custom_workspace)
seen: dict[str, Path] = {}
monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None)
monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None)
monkeypatch.setattr("nanobot.cli.commands._make_provider", lambda _config: object())
monkeypatch.setattr("nanobot.bus.queue.MessageBus", lambda: object())
monkeypatch.setattr("nanobot.config.paths.get_cron_dir", lambda: legacy_dir)
class _FakeCron:
def __init__(self, store_path: Path) -> None:
seen["cron_store"] = store_path
class _FakeAgentLoop:
def __init__(self, *args, **kwargs) -> None:
pass
async def process_direct(self, *_args, **_kwargs):
return OutboundMessage(channel="cli", chat_id="direct", content="ok")
async def close_mcp(self) -> None:
return None
monkeypatch.setattr("nanobot.cron.service.CronService", _FakeCron)
monkeypatch.setattr("nanobot.agent.loop.AgentLoop", _FakeAgentLoop)
monkeypatch.setattr("nanobot.cli.commands._print_agent_response", lambda *_args, **_kwargs: None)
result = runner.invoke(app, ["agent", "-m", "hello", "-c", str(config_file)])
assert result.exit_code == 0
assert seen["cron_store"] == custom_workspace / "cron" / "jobs.json"
assert legacy_file.exists()
assert not (custom_workspace / "cron" / "jobs.json").exists()
def test_agent_overrides_workspace_path(mock_agent_runtime):
workspace_path = Path("/tmp/agent-workspace")
result = runner.invoke(app, ["agent", "-m", "hello", "-w", str(workspace_path)])
assert result.exit_code == 0
assert mock_agent_runtime["config"].agents.defaults.workspace == str(workspace_path)
assert mock_agent_runtime["sync_templates"].call_args.args == (workspace_path,)
assert mock_agent_runtime["agent_loop_cls"].call_args.kwargs["workspace"] == workspace_path
def test_agent_workspace_override_wins_over_config_workspace(mock_agent_runtime, tmp_path: Path):
config_path = tmp_path / "agent-config.json"
config_path.write_text("{}")
workspace_path = Path("/tmp/agent-workspace")
result = runner.invoke(
app,
["agent", "-m", "hello", "-c", str(config_path), "-w", str(workspace_path)],
)
assert result.exit_code == 0
assert mock_agent_runtime["load_config"].call_args.args == (config_path.resolve(),)
assert mock_agent_runtime["config"].agents.defaults.workspace == str(workspace_path)
assert mock_agent_runtime["sync_templates"].call_args.args == (workspace_path,)
assert mock_agent_runtime["agent_loop_cls"].call_args.kwargs["workspace"] == workspace_path
def test_agent_hints_about_deprecated_memory_window(mock_agent_runtime, tmp_path):
config_file = tmp_path / "config.json"
config_file.write_text(json.dumps({"agents": {"defaults": {"memoryWindow": 42}}}))
result = runner.invoke(app, ["agent", "-m", "hello", "-c", str(config_file)])
assert result.exit_code == 0
assert "memoryWindow" in result.stdout
assert "no longer used" in result.stdout
def test_heartbeat_retains_recent_messages_by_default():
config = Config()
assert config.gateway.heartbeat.keep_recent_messages == 8
def test_gateway_uses_workspace_from_config_by_default(monkeypatch, tmp_path: Path) -> None:
config_file = tmp_path / "instance" / "config.json"
config_file.parent.mkdir(parents=True)
config_file.write_text("{}")
config = Config()
config.agents.defaults.workspace = str(tmp_path / "config-workspace")
seen: dict[str, Path] = {}
monkeypatch.setattr(
"nanobot.config.loader.set_config_path",
lambda path: seen.__setitem__("config_path", path),
)
monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
monkeypatch.setattr(
"nanobot.cli.commands.sync_workspace_templates",
lambda path: seen.__setitem__("workspace", path),
)
monkeypatch.setattr(
"nanobot.cli.commands._make_provider",
lambda _config: (_ for _ in ()).throw(_StopGatewayError("stop")),
)
result = runner.invoke(app, ["gateway", "--config", str(config_file)])
assert isinstance(result.exception, _StopGatewayError)
assert seen["config_path"] == config_file.resolve()
assert seen["workspace"] == Path(config.agents.defaults.workspace)
def test_gateway_workspace_option_overrides_config(monkeypatch, tmp_path: Path) -> None:
config_file = tmp_path / "instance" / "config.json"
config_file.parent.mkdir(parents=True)
config_file.write_text("{}")
config = Config()
config.agents.defaults.workspace = str(tmp_path / "config-workspace")
override = tmp_path / "override-workspace"
seen: dict[str, Path] = {}
monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None)
monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
monkeypatch.setattr(
"nanobot.cli.commands.sync_workspace_templates",
lambda path: seen.__setitem__("workspace", path),
)
monkeypatch.setattr(
"nanobot.cli.commands._make_provider",
lambda _config: (_ for _ in ()).throw(_StopGatewayError("stop")),
)
result = runner.invoke(
app,
["gateway", "--config", str(config_file), "--workspace", str(override)],
)
assert isinstance(result.exception, _StopGatewayError)
assert seen["workspace"] == override
assert config.workspace_path == override
def test_gateway_uses_workspace_directory_for_cron_store(monkeypatch, tmp_path: Path) -> None:
config_file = tmp_path / "instance" / "config.json"
config_file.parent.mkdir(parents=True)
config_file.write_text("{}")
config = Config()
config.agents.defaults.workspace = str(tmp_path / "config-workspace")
seen: dict[str, Path] = {}
monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None)
monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None)
monkeypatch.setattr("nanobot.cli.commands._make_provider", lambda _config: object())
monkeypatch.setattr("nanobot.bus.queue.MessageBus", lambda: object())
monkeypatch.setattr("nanobot.session.manager.SessionManager", lambda _workspace: object())
class _StopCron:
def __init__(self, store_path: Path) -> None:
seen["cron_store"] = store_path
raise _StopGatewayError("stop")
monkeypatch.setattr("nanobot.cron.service.CronService", _StopCron)
result = runner.invoke(app, ["gateway", "--config", str(config_file)])
assert isinstance(result.exception, _StopGatewayError)
assert seen["cron_store"] == config.workspace_path / "cron" / "jobs.json"
def test_gateway_workspace_override_does_not_migrate_legacy_cron(
monkeypatch, tmp_path: Path
) -> None:
config_file = tmp_path / "instance" / "config.json"
config_file.parent.mkdir(parents=True)
config_file.write_text("{}")
legacy_dir = tmp_path / "global" / "cron"
legacy_dir.mkdir(parents=True)
legacy_file = legacy_dir / "jobs.json"
legacy_file.write_text('{"jobs": []}')
override = tmp_path / "override-workspace"
config = Config()
seen: dict[str, Path] = {}
monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None)
monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None)
monkeypatch.setattr("nanobot.cli.commands._make_provider", lambda _config: object())
monkeypatch.setattr("nanobot.bus.queue.MessageBus", lambda: object())
monkeypatch.setattr("nanobot.session.manager.SessionManager", lambda _workspace: object())
monkeypatch.setattr("nanobot.config.paths.get_cron_dir", lambda: legacy_dir)
class _StopCron:
def __init__(self, store_path: Path) -> None:
seen["cron_store"] = store_path
raise _StopGatewayError("stop")
monkeypatch.setattr("nanobot.cron.service.CronService", _StopCron)
result = runner.invoke(
app,
["gateway", "--config", str(config_file), "--workspace", str(override)],
)
assert isinstance(result.exception, _StopGatewayError)
assert seen["cron_store"] == override / "cron" / "jobs.json"
assert legacy_file.exists()
assert not (override / "cron" / "jobs.json").exists()
def test_gateway_custom_config_workspace_does_not_migrate_legacy_cron(
monkeypatch, tmp_path: Path
) -> None:
config_file = tmp_path / "instance" / "config.json"
config_file.parent.mkdir(parents=True)
config_file.write_text("{}")
legacy_dir = tmp_path / "global" / "cron"
legacy_dir.mkdir(parents=True)
legacy_file = legacy_dir / "jobs.json"
legacy_file.write_text('{"jobs": []}')
custom_workspace = tmp_path / "custom-workspace"
config = Config()
config.agents.defaults.workspace = str(custom_workspace)
seen: dict[str, Path] = {}
monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None)
monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None)
monkeypatch.setattr("nanobot.cli.commands._make_provider", lambda _config: object())
monkeypatch.setattr("nanobot.bus.queue.MessageBus", lambda: object())
monkeypatch.setattr("nanobot.session.manager.SessionManager", lambda _workspace: object())
monkeypatch.setattr("nanobot.config.paths.get_cron_dir", lambda: legacy_dir)
class _StopCron:
def __init__(self, store_path: Path) -> None:
seen["cron_store"] = store_path
raise _StopGatewayError("stop")
monkeypatch.setattr("nanobot.cron.service.CronService", _StopCron)
result = runner.invoke(app, ["gateway", "--config", str(config_file)])
assert isinstance(result.exception, _StopGatewayError)
assert seen["cron_store"] == custom_workspace / "cron" / "jobs.json"
assert legacy_file.exists()
assert not (custom_workspace / "cron" / "jobs.json").exists()
def test_migrate_cron_store_moves_legacy_file(tmp_path: Path) -> None:
"""Legacy global jobs.json is moved into the workspace on first run."""
from nanobot.cli.commands import _migrate_cron_store
legacy_dir = tmp_path / "global" / "cron"
legacy_dir.mkdir(parents=True)
legacy_file = legacy_dir / "jobs.json"
legacy_file.write_text('{"jobs": []}')
config = Config()
config.agents.defaults.workspace = str(tmp_path / "workspace")
workspace_cron = config.workspace_path / "cron" / "jobs.json"
with patch("nanobot.config.paths.get_cron_dir", return_value=legacy_dir):
_migrate_cron_store(config)
assert workspace_cron.exists()
assert workspace_cron.read_text() == '{"jobs": []}'
assert not legacy_file.exists()
def test_migrate_cron_store_skips_when_workspace_file_exists(tmp_path: Path) -> None:
"""Migration does not overwrite an existing workspace cron store."""
from nanobot.cli.commands import _migrate_cron_store
legacy_dir = tmp_path / "global" / "cron"
legacy_dir.mkdir(parents=True)
(legacy_dir / "jobs.json").write_text('{"old": true}')
config = Config()
config.agents.defaults.workspace = str(tmp_path / "workspace")
workspace_cron = config.workspace_path / "cron" / "jobs.json"
workspace_cron.parent.mkdir(parents=True)
workspace_cron.write_text('{"new": true}')
with patch("nanobot.config.paths.get_cron_dir", return_value=legacy_dir):
_migrate_cron_store(config)
assert workspace_cron.read_text() == '{"new": true}'
def test_gateway_uses_configured_port_when_cli_flag_is_missing(monkeypatch, tmp_path: Path) -> None:
config_file = tmp_path / "instance" / "config.json"
config_file.parent.mkdir(parents=True)
config_file.write_text("{}")
config = Config()
config.gateway.port = 18791
monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None)
monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None)
monkeypatch.setattr(
"nanobot.cli.commands._make_provider",
lambda _config: (_ for _ in ()).throw(_StopGatewayError("stop")),
)
result = runner.invoke(app, ["gateway", "--config", str(config_file)])
assert isinstance(result.exception, _StopGatewayError)
assert "port 18791" in result.stdout
def test_gateway_cli_port_overrides_configured_port(monkeypatch, tmp_path: Path) -> None:
config_file = tmp_path / "instance" / "config.json"
config_file.parent.mkdir(parents=True)
config_file.write_text("{}")
config = Config()
config.gateway.port = 18791
monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None)
monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None)
monkeypatch.setattr(
"nanobot.cli.commands._make_provider",
lambda _config: (_ for _ in ()).throw(_StopGatewayError("stop")),
)
result = runner.invoke(app, ["gateway", "--config", str(config_file), "--port", "18792"])
assert isinstance(result.exception, _StopGatewayError)
assert "port 18792" in result.stdout
def test_channels_login_requires_channel_name() -> None:
result = runner.invoke(app, ["channels", "login"])
assert result.exit_code == 2
@@ -0,0 +1,190 @@
"""Tests for /restart slash command."""
from __future__ import annotations
import asyncio
import time
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from nanobot.bus.events import InboundMessage, OutboundMessage
from nanobot.providers.base import LLMResponse
def _make_loop():
"""Create a minimal AgentLoop with mocked dependencies."""
from nanobot.agent.loop import AgentLoop
from nanobot.bus.queue import MessageBus
bus = MessageBus()
provider = MagicMock()
provider.get_default_model.return_value = "test-model"
workspace = MagicMock()
workspace.__truediv__ = MagicMock(return_value=MagicMock())
with patch("nanobot.agent.loop.ContextBuilder"), \
patch("nanobot.agent.loop.SessionManager"), \
patch("nanobot.agent.loop.SubagentManager"):
loop = AgentLoop(bus=bus, provider=provider, workspace=workspace)
return loop, bus
class TestRestartCommand:
@pytest.mark.asyncio
async def test_restart_sends_message_and_calls_execv(self):
from nanobot.command.builtin import cmd_restart
from nanobot.command.router import CommandContext
loop, bus = _make_loop()
msg = InboundMessage(channel="cli", sender_id="user", chat_id="direct", content="/restart")
ctx = CommandContext(msg=msg, session=None, key=msg.session_key, raw="/restart", loop=loop)
with patch("nanobot.command.builtin.os.execv") as mock_execv:
out = await cmd_restart(ctx)
assert "Restarting" in out.content
await asyncio.sleep(1.5)
mock_execv.assert_called_once()
@pytest.mark.asyncio
async def test_restart_intercepted_in_run_loop(self):
"""Verify /restart is handled at the run-loop level, not inside _dispatch."""
loop, bus = _make_loop()
msg = InboundMessage(channel="telegram", sender_id="u1", chat_id="c1", content="/restart")
with patch.object(loop, "_dispatch", new_callable=AsyncMock) as mock_dispatch, \
patch("nanobot.command.builtin.os.execv"):
await bus.publish_inbound(msg)
loop._running = True
run_task = asyncio.create_task(loop.run())
await asyncio.sleep(0.1)
loop._running = False
run_task.cancel()
try:
await run_task
except asyncio.CancelledError:
pass
mock_dispatch.assert_not_called()
out = await asyncio.wait_for(bus.consume_outbound(), timeout=1.0)
assert "Restarting" in out.content
@pytest.mark.asyncio
async def test_status_intercepted_in_run_loop(self):
"""Verify /status is handled at the run-loop level for immediate replies."""
loop, bus = _make_loop()
msg = InboundMessage(channel="telegram", sender_id="u1", chat_id="c1", content="/status")
with patch.object(loop, "_dispatch", new_callable=AsyncMock) as mock_dispatch:
await bus.publish_inbound(msg)
loop._running = True
run_task = asyncio.create_task(loop.run())
await asyncio.sleep(0.1)
loop._running = False
run_task.cancel()
try:
await run_task
except asyncio.CancelledError:
pass
mock_dispatch.assert_not_called()
out = await asyncio.wait_for(bus.consume_outbound(), timeout=1.0)
assert "nanobot" in out.content.lower() or "Model" in out.content
@pytest.mark.asyncio
async def test_run_propagates_external_cancellation(self):
"""External task cancellation should not be swallowed by the inbound wait loop."""
loop, _bus = _make_loop()
run_task = asyncio.create_task(loop.run())
await asyncio.sleep(0.1)
run_task.cancel()
with pytest.raises(asyncio.CancelledError):
await asyncio.wait_for(run_task, timeout=1.0)
@pytest.mark.asyncio
async def test_help_includes_restart(self):
loop, bus = _make_loop()
msg = InboundMessage(channel="telegram", sender_id="u1", chat_id="c1", content="/help")
response = await loop._process_message(msg)
assert response is not None
assert "/restart" in response.content
assert "/status" in response.content
assert response.metadata == {"render_as": "text"}
@pytest.mark.asyncio
async def test_status_reports_runtime_info(self):
loop, _bus = _make_loop()
session = MagicMock()
session.get_history.return_value = [{"role": "user"}] * 3
loop.sessions.get_or_create.return_value = session
loop._start_time = time.time() - 125
loop._last_usage = {"prompt_tokens": 0, "completion_tokens": 0}
loop.memory_consolidator.estimate_session_prompt_tokens = MagicMock(
return_value=(20500, "tiktoken")
)
msg = InboundMessage(channel="telegram", sender_id="u1", chat_id="c1", content="/status")
response = await loop._process_message(msg)
assert response is not None
assert "Model: test-model" in response.content
assert "Tokens: 0 in / 0 out" in response.content
assert "Context: 20k/64k (31%)" in response.content
assert "Session: 3 messages" in response.content
assert "Uptime: 2m 5s" in response.content
assert response.metadata == {"render_as": "text"}
@pytest.mark.asyncio
async def test_run_agent_loop_resets_usage_when_provider_omits_it(self):
loop, _bus = _make_loop()
loop.provider.chat_with_retry = AsyncMock(side_effect=[
LLMResponse(content="first", usage={"prompt_tokens": 9, "completion_tokens": 4}),
LLMResponse(content="second", usage={}),
])
await loop._run_agent_loop([])
assert loop._last_usage == {"prompt_tokens": 9, "completion_tokens": 4}
await loop._run_agent_loop([])
assert loop._last_usage == {"prompt_tokens": 0, "completion_tokens": 0}
@pytest.mark.asyncio
async def test_status_falls_back_to_last_usage_when_context_estimate_missing(self):
loop, _bus = _make_loop()
session = MagicMock()
session.get_history.return_value = [{"role": "user"}]
loop.sessions.get_or_create.return_value = session
loop._last_usage = {"prompt_tokens": 1200, "completion_tokens": 34}
loop.memory_consolidator.estimate_session_prompt_tokens = MagicMock(
return_value=(0, "none")
)
response = await loop._process_message(
InboundMessage(channel="telegram", sender_id="u1", chat_id="c1", content="/status")
)
assert response is not None
assert "Tokens: 1200 in / 34 out" in response.content
assert "Context: 1k/64k (1%)" in response.content
@pytest.mark.asyncio
async def test_process_direct_preserves_render_metadata(self):
loop, _bus = _make_loop()
session = MagicMock()
session.get_history.return_value = []
loop.sessions.get_or_create.return_value = session
loop.subagents.get_running_count.return_value = 0
response = await loop.process_direct("/status", session_key="cli:test")
assert response is not None
assert response.metadata == {"render_as": "text"}
@@ -0,0 +1,128 @@
import json
from nanobot.config.loader import load_config, save_config
def test_load_config_keeps_max_tokens_and_ignores_legacy_memory_window(tmp_path) -> None:
config_path = tmp_path / "config.json"
config_path.write_text(
json.dumps(
{
"agents": {
"defaults": {
"maxTokens": 1234,
"memoryWindow": 42,
}
}
}
),
encoding="utf-8",
)
config = load_config(config_path)
assert config.agents.defaults.max_tokens == 1234
assert config.agents.defaults.context_window_tokens == 65_536
assert not hasattr(config.agents.defaults, "memory_window")
def test_save_config_writes_context_window_tokens_but_not_memory_window(tmp_path) -> None:
config_path = tmp_path / "config.json"
config_path.write_text(
json.dumps(
{
"agents": {
"defaults": {
"maxTokens": 2222,
"memoryWindow": 30,
}
}
}
),
encoding="utf-8",
)
config = load_config(config_path)
save_config(config, config_path)
saved = json.loads(config_path.read_text(encoding="utf-8"))
defaults = saved["agents"]["defaults"]
assert defaults["maxTokens"] == 2222
assert defaults["contextWindowTokens"] == 65_536
assert "memoryWindow" not in defaults
def test_onboard_does_not_crash_with_legacy_memory_window(tmp_path, monkeypatch) -> None:
config_path = tmp_path / "config.json"
workspace = tmp_path / "workspace"
config_path.write_text(
json.dumps(
{
"agents": {
"defaults": {
"maxTokens": 3333,
"memoryWindow": 50,
}
}
}
),
encoding="utf-8",
)
monkeypatch.setattr("nanobot.config.loader.get_config_path", lambda: config_path)
monkeypatch.setattr("nanobot.cli.commands.get_workspace_path", lambda _workspace=None: workspace)
from typer.testing import CliRunner
from nanobot.cli.commands import app
runner = CliRunner()
result = runner.invoke(app, ["onboard"], input="n\n")
assert result.exit_code == 0
def test_onboard_refresh_backfills_missing_channel_fields(tmp_path, monkeypatch) -> None:
from types import SimpleNamespace
config_path = tmp_path / "config.json"
workspace = tmp_path / "workspace"
config_path.write_text(
json.dumps(
{
"channels": {
"qq": {
"enabled": False,
"appId": "",
"secret": "",
"allowFrom": [],
}
}
}
),
encoding="utf-8",
)
monkeypatch.setattr("nanobot.config.loader.get_config_path", lambda: config_path)
monkeypatch.setattr("nanobot.cli.commands.get_workspace_path", lambda _workspace=None: workspace)
monkeypatch.setattr(
"nanobot.channels.registry.discover_all",
lambda: {
"qq": SimpleNamespace(
default_config=lambda: {
"enabled": False,
"appId": "",
"secret": "",
"allowFrom": [],
"msgFormat": "plain",
}
)
},
)
from typer.testing import CliRunner
from nanobot.cli.commands import app
runner = CliRunner()
result = runner.invoke(app, ["onboard"], input="n\n")
assert result.exit_code == 0
saved = json.loads(config_path.read_text(encoding="utf-8"))
assert saved["channels"]["qq"]["msgFormat"] == "plain"
@@ -0,0 +1,49 @@
from pathlib import Path
from nanobot.config.paths import (
get_bridge_install_dir,
get_cli_history_path,
get_cron_dir,
get_data_dir,
get_legacy_sessions_dir,
get_logs_dir,
get_media_dir,
get_runtime_subdir,
get_workspace_path,
is_default_workspace,
)
def test_runtime_dirs_follow_config_path(monkeypatch, tmp_path: Path) -> None:
config_file = tmp_path / "instance-a" / "config.json"
monkeypatch.setattr("nanobot.config.paths.get_config_path", lambda: config_file)
assert get_data_dir() == config_file.parent
assert get_runtime_subdir("cron") == config_file.parent / "cron"
assert get_cron_dir() == config_file.parent / "cron"
assert get_logs_dir() == config_file.parent / "logs"
def test_media_dir_supports_channel_namespace(monkeypatch, tmp_path: Path) -> None:
config_file = tmp_path / "instance-b" / "config.json"
monkeypatch.setattr("nanobot.config.paths.get_config_path", lambda: config_file)
assert get_media_dir() == config_file.parent / "media"
assert get_media_dir("telegram") == config_file.parent / "media" / "telegram"
def test_shared_and_legacy_paths_remain_global() -> None:
assert get_cli_history_path() == Path.home() / ".nanobot" / "history" / "cli_history"
assert get_bridge_install_dir() == Path.home() / ".nanobot" / "bridge"
assert get_legacy_sessions_dir() == Path.home() / ".nanobot" / "sessions"
def test_workspace_path_is_explicitly_resolved() -> None:
assert get_workspace_path() == Path.home() / ".nanobot" / "workspace"
assert get_workspace_path("~/custom-workspace") == Path.home() / "custom-workspace"
def test_is_default_workspace_distinguishes_default_and_custom_paths() -> None:
assert is_default_workspace(None) is True
assert is_default_workspace(Path.home() / ".nanobot" / "workspace") is True
assert is_default_workspace("~/custom-workspace") is False
+143
View File
@@ -0,0 +1,143 @@
import asyncio
import json
import pytest
from nanobot.cron.service import CronService
from nanobot.cron.types import CronSchedule
def test_add_job_rejects_unknown_timezone(tmp_path) -> None:
service = CronService(tmp_path / "cron" / "jobs.json")
with pytest.raises(ValueError, match="unknown timezone 'America/Vancovuer'"):
service.add_job(
name="tz typo",
schedule=CronSchedule(kind="cron", expr="0 9 * * *", tz="America/Vancovuer"),
message="hello",
)
assert service.list_jobs(include_disabled=True) == []
def test_add_job_accepts_valid_timezone(tmp_path) -> None:
service = CronService(tmp_path / "cron" / "jobs.json")
job = service.add_job(
name="tz ok",
schedule=CronSchedule(kind="cron", expr="0 9 * * *", tz="America/Vancouver"),
message="hello",
)
assert job.schedule.tz == "America/Vancouver"
assert job.state.next_run_at_ms is not None
@pytest.mark.asyncio
async def test_execute_job_records_run_history(tmp_path) -> None:
store_path = tmp_path / "cron" / "jobs.json"
service = CronService(store_path, on_job=lambda _: asyncio.sleep(0))
job = service.add_job(
name="hist",
schedule=CronSchedule(kind="every", every_ms=60_000),
message="hello",
)
await service.run_job(job.id)
loaded = service.get_job(job.id)
assert loaded is not None
assert len(loaded.state.run_history) == 1
rec = loaded.state.run_history[0]
assert rec.status == "ok"
assert rec.duration_ms >= 0
assert rec.error is None
@pytest.mark.asyncio
async def test_run_history_records_errors(tmp_path) -> None:
store_path = tmp_path / "cron" / "jobs.json"
async def fail(_):
raise RuntimeError("boom")
service = CronService(store_path, on_job=fail)
job = service.add_job(
name="fail",
schedule=CronSchedule(kind="every", every_ms=60_000),
message="hello",
)
await service.run_job(job.id)
loaded = service.get_job(job.id)
assert len(loaded.state.run_history) == 1
assert loaded.state.run_history[0].status == "error"
assert loaded.state.run_history[0].error == "boom"
@pytest.mark.asyncio
async def test_run_history_trimmed_to_max(tmp_path) -> None:
store_path = tmp_path / "cron" / "jobs.json"
service = CronService(store_path, on_job=lambda _: asyncio.sleep(0))
job = service.add_job(
name="trim",
schedule=CronSchedule(kind="every", every_ms=60_000),
message="hello",
)
for _ in range(25):
await service.run_job(job.id)
loaded = service.get_job(job.id)
assert len(loaded.state.run_history) == CronService._MAX_RUN_HISTORY
@pytest.mark.asyncio
async def test_run_history_persisted_to_disk(tmp_path) -> None:
store_path = tmp_path / "cron" / "jobs.json"
service = CronService(store_path, on_job=lambda _: asyncio.sleep(0))
job = service.add_job(
name="persist",
schedule=CronSchedule(kind="every", every_ms=60_000),
message="hello",
)
await service.run_job(job.id)
raw = json.loads(store_path.read_text())
history = raw["jobs"][0]["state"]["runHistory"]
assert len(history) == 1
assert history[0]["status"] == "ok"
assert "runAtMs" in history[0]
assert "durationMs" in history[0]
fresh = CronService(store_path)
loaded = fresh.get_job(job.id)
assert len(loaded.state.run_history) == 1
assert loaded.state.run_history[0].status == "ok"
@pytest.mark.asyncio
async def test_running_service_honors_external_disable(tmp_path) -> None:
store_path = tmp_path / "cron" / "jobs.json"
called: list[str] = []
async def on_job(job) -> None:
called.append(job.id)
service = CronService(store_path, on_job=on_job)
job = service.add_job(
name="external-disable",
schedule=CronSchedule(kind="every", every_ms=200),
message="hello",
)
await service.start()
try:
# Wait slightly to ensure file mtime is definitively different
await asyncio.sleep(0.05)
external = CronService(store_path)
updated = external.enable_job(job.id, enabled=False)
assert updated is not None
assert updated.enabled is False
await asyncio.sleep(0.35)
assert called == []
finally:
service.stop()
@@ -0,0 +1,299 @@
"""Tests for CronTool._list_jobs() output formatting."""
from datetime import datetime, timezone
from nanobot.agent.tools.cron import CronTool
from nanobot.cron.service import CronService
from nanobot.cron.types import CronJobState, CronSchedule
def _make_tool(tmp_path) -> CronTool:
service = CronService(tmp_path / "cron" / "jobs.json")
return CronTool(service)
def _make_tool_with_tz(tmp_path, tz: str) -> CronTool:
service = CronService(tmp_path / "cron" / "jobs.json")
return CronTool(service, default_timezone=tz)
# -- _format_timing tests --
def test_format_timing_cron_with_tz(tmp_path) -> None:
tool = _make_tool(tmp_path)
s = CronSchedule(kind="cron", expr="0 9 * * 1-5", tz="America/Denver")
assert tool._format_timing(s) == "cron: 0 9 * * 1-5 (America/Denver)"
def test_format_timing_cron_without_tz(tmp_path) -> None:
tool = _make_tool(tmp_path)
s = CronSchedule(kind="cron", expr="*/5 * * * *")
assert tool._format_timing(s) == "cron: */5 * * * *"
def test_format_timing_every_hours(tmp_path) -> None:
tool = _make_tool(tmp_path)
s = CronSchedule(kind="every", every_ms=7_200_000)
assert tool._format_timing(s) == "every 2h"
def test_format_timing_every_minutes(tmp_path) -> None:
tool = _make_tool(tmp_path)
s = CronSchedule(kind="every", every_ms=1_800_000)
assert tool._format_timing(s) == "every 30m"
def test_format_timing_every_seconds(tmp_path) -> None:
tool = _make_tool(tmp_path)
s = CronSchedule(kind="every", every_ms=30_000)
assert tool._format_timing(s) == "every 30s"
def test_format_timing_every_non_minute_seconds(tmp_path) -> None:
tool = _make_tool(tmp_path)
s = CronSchedule(kind="every", every_ms=90_000)
assert tool._format_timing(s) == "every 90s"
def test_format_timing_every_milliseconds(tmp_path) -> None:
tool = _make_tool(tmp_path)
s = CronSchedule(kind="every", every_ms=200)
assert tool._format_timing(s) == "every 200ms"
def test_format_timing_at(tmp_path) -> None:
tool = _make_tool_with_tz(tmp_path, "Asia/Shanghai")
s = CronSchedule(kind="at", at_ms=1773684000000)
result = tool._format_timing(s)
assert "Asia/Shanghai" in result
assert result.startswith("at 2026-")
def test_format_timing_fallback(tmp_path) -> None:
tool = _make_tool(tmp_path)
s = CronSchedule(kind="every") # no every_ms
assert tool._format_timing(s) == "every"
# -- _format_state tests --
def test_format_state_empty(tmp_path) -> None:
tool = _make_tool(tmp_path)
state = CronJobState()
assert tool._format_state(state, CronSchedule(kind="every")) == []
def test_format_state_last_run_ok(tmp_path) -> None:
tool = _make_tool(tmp_path)
state = CronJobState(last_run_at_ms=1773673200000, last_status="ok")
lines = tool._format_state(state, CronSchedule(kind="cron", expr="0 9 * * *", tz="UTC"))
assert len(lines) == 1
assert "Last run:" in lines[0]
assert "ok" in lines[0]
def test_format_state_last_run_with_error(tmp_path) -> None:
tool = _make_tool(tmp_path)
state = CronJobState(last_run_at_ms=1773673200000, last_status="error", last_error="timeout")
lines = tool._format_state(state, CronSchedule(kind="cron", expr="0 9 * * *", tz="UTC"))
assert len(lines) == 1
assert "error" in lines[0]
assert "timeout" in lines[0]
def test_format_state_next_run_only(tmp_path) -> None:
tool = _make_tool(tmp_path)
state = CronJobState(next_run_at_ms=1773684000000)
lines = tool._format_state(state, CronSchedule(kind="cron", expr="0 9 * * *", tz="UTC"))
assert len(lines) == 1
assert "Next run:" in lines[0]
def test_format_state_both(tmp_path) -> None:
tool = _make_tool(tmp_path)
state = CronJobState(
last_run_at_ms=1773673200000, last_status="ok", next_run_at_ms=1773684000000
)
lines = tool._format_state(state, CronSchedule(kind="cron", expr="0 9 * * *", tz="UTC"))
assert len(lines) == 2
assert "Last run:" in lines[0]
assert "Next run:" in lines[1]
def test_format_state_unknown_status(tmp_path) -> None:
tool = _make_tool(tmp_path)
state = CronJobState(last_run_at_ms=1773673200000, last_status=None)
lines = tool._format_state(state, CronSchedule(kind="cron", expr="0 9 * * *", tz="UTC"))
assert "unknown" in lines[0]
# -- _list_jobs integration tests --
def test_list_empty(tmp_path) -> None:
tool = _make_tool(tmp_path)
assert tool._list_jobs() == "No scheduled jobs."
def test_list_cron_job_shows_expression_and_timezone(tmp_path) -> None:
tool = _make_tool(tmp_path)
tool._cron.add_job(
name="Morning scan",
schedule=CronSchedule(kind="cron", expr="0 9 * * 1-5", tz="America/Denver"),
message="scan",
)
result = tool._list_jobs()
assert "cron: 0 9 * * 1-5 (America/Denver)" in result
def test_list_every_job_shows_human_interval(tmp_path) -> None:
tool = _make_tool(tmp_path)
tool._cron.add_job(
name="Frequent check",
schedule=CronSchedule(kind="every", every_ms=1_800_000),
message="check",
)
result = tool._list_jobs()
assert "every 30m" in result
def test_list_every_job_hours(tmp_path) -> None:
tool = _make_tool(tmp_path)
tool._cron.add_job(
name="Hourly check",
schedule=CronSchedule(kind="every", every_ms=7_200_000),
message="check",
)
result = tool._list_jobs()
assert "every 2h" in result
def test_list_every_job_seconds(tmp_path) -> None:
tool = _make_tool(tmp_path)
tool._cron.add_job(
name="Fast check",
schedule=CronSchedule(kind="every", every_ms=30_000),
message="check",
)
result = tool._list_jobs()
assert "every 30s" in result
def test_list_every_job_non_minute_seconds(tmp_path) -> None:
tool = _make_tool(tmp_path)
tool._cron.add_job(
name="Ninety-second check",
schedule=CronSchedule(kind="every", every_ms=90_000),
message="check",
)
result = tool._list_jobs()
assert "every 90s" in result
def test_list_every_job_milliseconds(tmp_path) -> None:
tool = _make_tool(tmp_path)
tool._cron.add_job(
name="Sub-second check",
schedule=CronSchedule(kind="every", every_ms=200),
message="check",
)
result = tool._list_jobs()
assert "every 200ms" in result
def test_list_at_job_shows_iso_timestamp(tmp_path) -> None:
tool = _make_tool_with_tz(tmp_path, "Asia/Shanghai")
tool._cron.add_job(
name="One-shot",
schedule=CronSchedule(kind="at", at_ms=1773684000000),
message="fire",
)
result = tool._list_jobs()
assert "at 2026-" in result
assert "Asia/Shanghai" in result
def test_list_shows_last_run_state(tmp_path) -> None:
tool = _make_tool(tmp_path)
job = tool._cron.add_job(
name="Stateful job",
schedule=CronSchedule(kind="cron", expr="0 9 * * *", tz="UTC"),
message="test",
)
# Simulate a completed run by updating state in the store
job.state.last_run_at_ms = 1773673200000
job.state.last_status = "ok"
tool._cron._save_store()
result = tool._list_jobs()
assert "Last run:" in result
assert "ok" in result
assert "(UTC)" in result
def test_list_shows_error_message(tmp_path) -> None:
tool = _make_tool(tmp_path)
job = tool._cron.add_job(
name="Failed job",
schedule=CronSchedule(kind="cron", expr="0 9 * * *", tz="UTC"),
message="test",
)
job.state.last_run_at_ms = 1773673200000
job.state.last_status = "error"
job.state.last_error = "timeout"
tool._cron._save_store()
result = tool._list_jobs()
assert "error" in result
assert "timeout" in result
def test_list_shows_next_run(tmp_path) -> None:
tool = _make_tool(tmp_path)
tool._cron.add_job(
name="Upcoming job",
schedule=CronSchedule(kind="cron", expr="0 9 * * *", tz="UTC"),
message="test",
)
result = tool._list_jobs()
assert "Next run:" in result
assert "(UTC)" in result
def test_add_cron_job_defaults_to_tool_timezone(tmp_path) -> None:
tool = _make_tool_with_tz(tmp_path, "Asia/Shanghai")
tool.set_context("telegram", "chat-1")
result = tool._add_job("Morning standup", None, "0 8 * * *", None, None)
assert result.startswith("Created job")
job = tool._cron.list_jobs()[0]
assert job.schedule.tz == "Asia/Shanghai"
def test_add_at_job_uses_default_timezone_for_naive_datetime(tmp_path) -> None:
tool = _make_tool_with_tz(tmp_path, "Asia/Shanghai")
tool.set_context("telegram", "chat-1")
result = tool._add_job("Morning reminder", None, None, None, "2026-03-25T08:00:00")
assert result.startswith("Created job")
job = tool._cron.list_jobs()[0]
expected = int(datetime(2026, 3, 25, 0, 0, 0, tzinfo=timezone.utc).timestamp() * 1000)
assert job.schedule.at_ms == expected
def test_list_excludes_disabled_jobs(tmp_path) -> None:
tool = _make_tool(tmp_path)
job = tool._cron.add_job(
name="Paused job",
schedule=CronSchedule(kind="cron", expr="0 9 * * *", tz="UTC"),
message="test",
)
tool._cron.enable_job(job.id, enabled=False)
result = tool._list_jobs()
assert "Paused job" not in result
assert result == "No scheduled jobs."
@@ -0,0 +1,399 @@
"""Test Azure OpenAI provider implementation (updated for model-based deployment names)."""
from unittest.mock import AsyncMock, Mock, patch
import pytest
from nanobot.providers.azure_openai_provider import AzureOpenAIProvider
from nanobot.providers.base import LLMResponse
def test_azure_openai_provider_init():
"""Test AzureOpenAIProvider initialization without deployment_name."""
provider = AzureOpenAIProvider(
api_key="test-key",
api_base="https://test-resource.openai.azure.com",
default_model="gpt-4o-deployment",
)
assert provider.api_key == "test-key"
assert provider.api_base == "https://test-resource.openai.azure.com/"
assert provider.default_model == "gpt-4o-deployment"
assert provider.api_version == "2024-10-21"
def test_azure_openai_provider_init_validation():
"""Test AzureOpenAIProvider initialization validation."""
# Missing api_key
with pytest.raises(ValueError, match="Azure OpenAI api_key is required"):
AzureOpenAIProvider(api_key="", api_base="https://test.com")
# Missing api_base
with pytest.raises(ValueError, match="Azure OpenAI api_base is required"):
AzureOpenAIProvider(api_key="test", api_base="")
def test_build_chat_url():
"""Test Azure OpenAI URL building with different deployment names."""
provider = AzureOpenAIProvider(
api_key="test-key",
api_base="https://test-resource.openai.azure.com",
default_model="gpt-4o",
)
# Test various deployment names
test_cases = [
("gpt-4o-deployment", "https://test-resource.openai.azure.com/openai/deployments/gpt-4o-deployment/chat/completions?api-version=2024-10-21"),
("gpt-35-turbo", "https://test-resource.openai.azure.com/openai/deployments/gpt-35-turbo/chat/completions?api-version=2024-10-21"),
("custom-model", "https://test-resource.openai.azure.com/openai/deployments/custom-model/chat/completions?api-version=2024-10-21"),
]
for deployment_name, expected_url in test_cases:
url = provider._build_chat_url(deployment_name)
assert url == expected_url
def test_build_chat_url_api_base_without_slash():
"""Test URL building when api_base doesn't end with slash."""
provider = AzureOpenAIProvider(
api_key="test-key",
api_base="https://test-resource.openai.azure.com", # No trailing slash
default_model="gpt-4o",
)
url = provider._build_chat_url("test-deployment")
expected = "https://test-resource.openai.azure.com/openai/deployments/test-deployment/chat/completions?api-version=2024-10-21"
assert url == expected
def test_build_headers():
"""Test Azure OpenAI header building with api-key authentication."""
provider = AzureOpenAIProvider(
api_key="test-api-key-123",
api_base="https://test-resource.openai.azure.com",
default_model="gpt-4o",
)
headers = provider._build_headers()
assert headers["Content-Type"] == "application/json"
assert headers["api-key"] == "test-api-key-123" # Azure OpenAI specific header
assert "x-session-affinity" in headers
def test_prepare_request_payload():
"""Test request payload preparation with Azure OpenAI 2024-10-21 compliance."""
provider = AzureOpenAIProvider(
api_key="test-key",
api_base="https://test-resource.openai.azure.com",
default_model="gpt-4o",
)
messages = [{"role": "user", "content": "Hello"}]
payload = provider._prepare_request_payload("gpt-4o", messages, max_tokens=1500, temperature=0.8)
assert payload["messages"] == messages
assert payload["max_completion_tokens"] == 1500 # Azure API 2024-10-21 uses max_completion_tokens
assert payload["temperature"] == 0.8
assert "tools" not in payload
# Test with tools
tools = [{"type": "function", "function": {"name": "get_weather", "parameters": {}}}]
payload_with_tools = provider._prepare_request_payload("gpt-4o", messages, tools=tools)
assert payload_with_tools["tools"] == tools
assert payload_with_tools["tool_choice"] == "auto"
# Test with reasoning_effort
payload_with_reasoning = provider._prepare_request_payload(
"gpt-5-chat", messages, reasoning_effort="medium"
)
assert payload_with_reasoning["reasoning_effort"] == "medium"
assert "temperature" not in payload_with_reasoning
def test_prepare_request_payload_sanitizes_messages():
"""Test Azure payload strips non-standard message keys before sending."""
provider = AzureOpenAIProvider(
api_key="test-key",
api_base="https://test-resource.openai.azure.com",
default_model="gpt-4o",
)
messages = [
{
"role": "assistant",
"tool_calls": [{"id": "call_123", "type": "function", "function": {"name": "x"}}],
"reasoning_content": "hidden chain-of-thought",
},
{
"role": "tool",
"tool_call_id": "call_123",
"name": "x",
"content": "ok",
"extra_field": "should be removed",
},
]
payload = provider._prepare_request_payload("gpt-4o", messages)
assert payload["messages"] == [
{
"role": "assistant",
"content": None,
"tool_calls": [{"id": "call_123", "type": "function", "function": {"name": "x"}}],
},
{
"role": "tool",
"tool_call_id": "call_123",
"name": "x",
"content": "ok",
},
]
@pytest.mark.asyncio
async def test_chat_success():
"""Test successful chat request using model as deployment name."""
provider = AzureOpenAIProvider(
api_key="test-key",
api_base="https://test-resource.openai.azure.com",
default_model="gpt-4o-deployment",
)
# Mock response data
mock_response_data = {
"choices": [{
"message": {
"content": "Hello! How can I help you today?",
"role": "assistant"
},
"finish_reason": "stop"
}],
"usage": {
"prompt_tokens": 12,
"completion_tokens": 18,
"total_tokens": 30
}
}
with patch("httpx.AsyncClient") as mock_client:
mock_response = AsyncMock()
mock_response.status_code = 200
mock_response.json = Mock(return_value=mock_response_data)
mock_context = AsyncMock()
mock_context.post = AsyncMock(return_value=mock_response)
mock_client.return_value.__aenter__.return_value = mock_context
# Test with specific model (deployment name)
messages = [{"role": "user", "content": "Hello"}]
result = await provider.chat(messages, model="custom-deployment")
assert isinstance(result, LLMResponse)
assert result.content == "Hello! How can I help you today?"
assert result.finish_reason == "stop"
assert result.usage["prompt_tokens"] == 12
assert result.usage["completion_tokens"] == 18
assert result.usage["total_tokens"] == 30
# Verify URL was built with the provided model as deployment name
call_args = mock_context.post.call_args
expected_url = "https://test-resource.openai.azure.com/openai/deployments/custom-deployment/chat/completions?api-version=2024-10-21"
assert call_args[0][0] == expected_url
@pytest.mark.asyncio
async def test_chat_uses_default_model_when_no_model_provided():
"""Test that chat uses default_model when no model is specified."""
provider = AzureOpenAIProvider(
api_key="test-key",
api_base="https://test-resource.openai.azure.com",
default_model="default-deployment",
)
mock_response_data = {
"choices": [{
"message": {"content": "Response", "role": "assistant"},
"finish_reason": "stop"
}],
"usage": {"prompt_tokens": 5, "completion_tokens": 5, "total_tokens": 10}
}
with patch("httpx.AsyncClient") as mock_client:
mock_response = AsyncMock()
mock_response.status_code = 200
mock_response.json = Mock(return_value=mock_response_data)
mock_context = AsyncMock()
mock_context.post = AsyncMock(return_value=mock_response)
mock_client.return_value.__aenter__.return_value = mock_context
messages = [{"role": "user", "content": "Test"}]
await provider.chat(messages) # No model specified
# Verify URL was built with default model as deployment name
call_args = mock_context.post.call_args
expected_url = "https://test-resource.openai.azure.com/openai/deployments/default-deployment/chat/completions?api-version=2024-10-21"
assert call_args[0][0] == expected_url
@pytest.mark.asyncio
async def test_chat_with_tool_calls():
"""Test chat request with tool calls in response."""
provider = AzureOpenAIProvider(
api_key="test-key",
api_base="https://test-resource.openai.azure.com",
default_model="gpt-4o",
)
# Mock response with tool calls
mock_response_data = {
"choices": [{
"message": {
"content": None,
"role": "assistant",
"tool_calls": [{
"id": "call_12345",
"function": {
"name": "get_weather",
"arguments": '{"location": "San Francisco"}'
}
}]
},
"finish_reason": "tool_calls"
}],
"usage": {
"prompt_tokens": 20,
"completion_tokens": 15,
"total_tokens": 35
}
}
with patch("httpx.AsyncClient") as mock_client:
mock_response = AsyncMock()
mock_response.status_code = 200
mock_response.json = Mock(return_value=mock_response_data)
mock_context = AsyncMock()
mock_context.post = AsyncMock(return_value=mock_response)
mock_client.return_value.__aenter__.return_value = mock_context
messages = [{"role": "user", "content": "What's the weather?"}]
tools = [{"type": "function", "function": {"name": "get_weather", "parameters": {}}}]
result = await provider.chat(messages, tools=tools, model="weather-model")
assert isinstance(result, LLMResponse)
assert result.content is None
assert result.finish_reason == "tool_calls"
assert len(result.tool_calls) == 1
assert result.tool_calls[0].name == "get_weather"
assert result.tool_calls[0].arguments == {"location": "San Francisco"}
@pytest.mark.asyncio
async def test_chat_api_error():
"""Test chat request API error handling."""
provider = AzureOpenAIProvider(
api_key="test-key",
api_base="https://test-resource.openai.azure.com",
default_model="gpt-4o",
)
with patch("httpx.AsyncClient") as mock_client:
mock_response = AsyncMock()
mock_response.status_code = 401
mock_response.text = "Invalid authentication credentials"
mock_context = AsyncMock()
mock_context.post = AsyncMock(return_value=mock_response)
mock_client.return_value.__aenter__.return_value = mock_context
messages = [{"role": "user", "content": "Hello"}]
result = await provider.chat(messages)
assert isinstance(result, LLMResponse)
assert "Azure OpenAI API Error 401" in result.content
assert "Invalid authentication credentials" in result.content
assert result.finish_reason == "error"
@pytest.mark.asyncio
async def test_chat_connection_error():
"""Test chat request connection error handling."""
provider = AzureOpenAIProvider(
api_key="test-key",
api_base="https://test-resource.openai.azure.com",
default_model="gpt-4o",
)
with patch("httpx.AsyncClient") as mock_client:
mock_context = AsyncMock()
mock_context.post = AsyncMock(side_effect=Exception("Connection failed"))
mock_client.return_value.__aenter__.return_value = mock_context
messages = [{"role": "user", "content": "Hello"}]
result = await provider.chat(messages)
assert isinstance(result, LLMResponse)
assert "Error calling Azure OpenAI: Exception('Connection failed')" in result.content
assert result.finish_reason == "error"
def test_parse_response_malformed():
"""Test response parsing with malformed data."""
provider = AzureOpenAIProvider(
api_key="test-key",
api_base="https://test-resource.openai.azure.com",
default_model="gpt-4o",
)
# Test with missing choices
malformed_response = {"usage": {"prompt_tokens": 10}}
result = provider._parse_response(malformed_response)
assert isinstance(result, LLMResponse)
assert "Error parsing Azure OpenAI response" in result.content
assert result.finish_reason == "error"
def test_get_default_model():
"""Test get_default_model method."""
provider = AzureOpenAIProvider(
api_key="test-key",
api_base="https://test-resource.openai.azure.com",
default_model="my-custom-deployment",
)
assert provider.get_default_model() == "my-custom-deployment"
if __name__ == "__main__":
# Run basic tests
print("Running basic Azure OpenAI provider tests...")
# Test initialization
provider = AzureOpenAIProvider(
api_key="test-key",
api_base="https://test-resource.openai.azure.com",
default_model="gpt-4o-deployment",
)
print("✅ Provider initialization successful")
# Test URL building
url = provider._build_chat_url("my-deployment")
expected = "https://test-resource.openai.azure.com/openai/deployments/my-deployment/chat/completions?api-version=2024-10-21"
assert url == expected
print("✅ URL building works correctly")
# Test headers
headers = provider._build_headers()
assert headers["api-key"] == "test-key"
assert headers["Content-Type"] == "application/json"
print("✅ Header building works correctly")
# Test payload preparation
messages = [{"role": "user", "content": "Test"}]
payload = provider._prepare_request_payload("gpt-4o-deployment", messages, max_tokens=1000)
assert payload["max_completion_tokens"] == 1000 # Azure 2024-10-21 format
print("✅ Payload preparation works correctly")
print("✅ All basic tests passed! Updated test file is working correctly.")
@@ -0,0 +1,55 @@
"""Tests for OpenAICompatProvider handling custom/direct endpoints."""
from types import SimpleNamespace
from unittest.mock import patch
from nanobot.providers.openai_compat_provider import OpenAICompatProvider
def test_custom_provider_parse_handles_empty_choices() -> None:
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
provider = OpenAICompatProvider()
response = SimpleNamespace(choices=[])
result = provider._parse(response)
assert result.finish_reason == "error"
assert "empty choices" in result.content
def test_custom_provider_parse_accepts_plain_string_response() -> None:
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
provider = OpenAICompatProvider()
result = provider._parse("hello from backend")
assert result.finish_reason == "stop"
assert result.content == "hello from backend"
def test_custom_provider_parse_accepts_dict_response() -> None:
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
provider = OpenAICompatProvider()
result = provider._parse({
"choices": [{
"message": {"content": "hello from dict"},
"finish_reason": "stop",
}],
"usage": {
"prompt_tokens": 1,
"completion_tokens": 2,
"total_tokens": 3,
},
})
assert result.finish_reason == "stop"
assert result.content == "hello from dict"
assert result.usage["total_tokens"] == 3
def test_custom_provider_parse_chunks_accepts_plain_text_chunks() -> None:
result = OpenAICompatProvider._parse_chunks(["hello ", "world"])
assert result.finish_reason == "stop"
assert result.content == "hello world"
@@ -0,0 +1,216 @@
"""Tests for OpenAICompatProvider spec-driven behavior.
Validates that:
- OpenRouter (no strip) keeps model names intact.
- AiHubMix (strip_model_prefix=True) strips provider prefixes.
- Standard providers pass model names through as-is.
"""
from __future__ import annotations
from types import SimpleNamespace
from unittest.mock import AsyncMock, patch
import pytest
from nanobot.providers.openai_compat_provider import OpenAICompatProvider
from nanobot.providers.registry import find_by_name
def _fake_chat_response(content: str = "ok") -> SimpleNamespace:
"""Build a minimal OpenAI chat completion response."""
message = SimpleNamespace(
content=content,
tool_calls=None,
reasoning_content=None,
)
choice = SimpleNamespace(message=message, finish_reason="stop")
usage = SimpleNamespace(prompt_tokens=10, completion_tokens=5, total_tokens=15)
return SimpleNamespace(choices=[choice], usage=usage)
def _fake_tool_call_response() -> SimpleNamespace:
"""Build a minimal chat response that includes Gemini-style extra_content."""
function = SimpleNamespace(
name="exec",
arguments='{"cmd":"ls"}',
provider_specific_fields={"inner": "value"},
)
tool_call = SimpleNamespace(
id="call_123",
index=0,
type="function",
function=function,
extra_content={"google": {"thought_signature": "signed-token"}},
)
message = SimpleNamespace(
content=None,
tool_calls=[tool_call],
reasoning_content=None,
)
choice = SimpleNamespace(message=message, finish_reason="tool_calls")
usage = SimpleNamespace(prompt_tokens=10, completion_tokens=5, total_tokens=15)
return SimpleNamespace(choices=[choice], usage=usage)
def test_openrouter_spec_is_gateway() -> None:
spec = find_by_name("openrouter")
assert spec is not None
assert spec.is_gateway is True
assert spec.default_api_base == "https://openrouter.ai/api/v1"
def test_openrouter_sets_default_attribution_headers() -> None:
spec = find_by_name("openrouter")
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI") as MockClient:
OpenAICompatProvider(
api_key="sk-or-test-key",
api_base="https://openrouter.ai/api/v1",
default_model="anthropic/claude-sonnet-4-5",
spec=spec,
)
headers = MockClient.call_args.kwargs["default_headers"]
assert headers["HTTP-Referer"] == "https://github.com/HKUDS/nanobot"
assert headers["X-OpenRouter-Title"] == "nanobot"
assert headers["X-OpenRouter-Categories"] == "cli-agent,personal-agent"
assert "x-session-affinity" in headers
def test_openrouter_user_headers_override_default_attribution() -> None:
spec = find_by_name("openrouter")
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI") as MockClient:
OpenAICompatProvider(
api_key="sk-or-test-key",
api_base="https://openrouter.ai/api/v1",
default_model="anthropic/claude-sonnet-4-5",
extra_headers={
"HTTP-Referer": "https://nanobot.ai",
"X-OpenRouter-Title": "Nanobot Pro",
"X-Custom-App": "enabled",
},
spec=spec,
)
headers = MockClient.call_args.kwargs["default_headers"]
assert headers["HTTP-Referer"] == "https://nanobot.ai"
assert headers["X-OpenRouter-Title"] == "Nanobot Pro"
assert headers["X-OpenRouter-Categories"] == "cli-agent,personal-agent"
assert headers["X-Custom-App"] == "enabled"
@pytest.mark.asyncio
async def test_openrouter_keeps_model_name_intact() -> None:
"""OpenRouter gateway keeps the full model name (gateway does its own routing)."""
mock_create = AsyncMock(return_value=_fake_chat_response())
spec = find_by_name("openrouter")
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI") as MockClient:
client_instance = MockClient.return_value
client_instance.chat.completions.create = mock_create
provider = OpenAICompatProvider(
api_key="sk-or-test-key",
api_base="https://openrouter.ai/api/v1",
default_model="anthropic/claude-sonnet-4-5",
spec=spec,
)
await provider.chat(
messages=[{"role": "user", "content": "hello"}],
model="anthropic/claude-sonnet-4-5",
)
call_kwargs = mock_create.call_args.kwargs
assert call_kwargs["model"] == "anthropic/claude-sonnet-4-5"
@pytest.mark.asyncio
async def test_aihubmix_strips_model_prefix() -> None:
"""AiHubMix strips the provider prefix (strip_model_prefix=True)."""
mock_create = AsyncMock(return_value=_fake_chat_response())
spec = find_by_name("aihubmix")
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI") as MockClient:
client_instance = MockClient.return_value
client_instance.chat.completions.create = mock_create
provider = OpenAICompatProvider(
api_key="sk-aihub-test-key",
api_base="https://aihubmix.com/v1",
default_model="claude-sonnet-4-5",
spec=spec,
)
await provider.chat(
messages=[{"role": "user", "content": "hello"}],
model="anthropic/claude-sonnet-4-5",
)
call_kwargs = mock_create.call_args.kwargs
assert call_kwargs["model"] == "claude-sonnet-4-5"
@pytest.mark.asyncio
async def test_standard_provider_passes_model_through() -> None:
"""Standard provider (e.g. deepseek) passes model name through as-is."""
mock_create = AsyncMock(return_value=_fake_chat_response())
spec = find_by_name("deepseek")
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI") as MockClient:
client_instance = MockClient.return_value
client_instance.chat.completions.create = mock_create
provider = OpenAICompatProvider(
api_key="sk-deepseek-test-key",
default_model="deepseek-chat",
spec=spec,
)
await provider.chat(
messages=[{"role": "user", "content": "hello"}],
model="deepseek-chat",
)
call_kwargs = mock_create.call_args.kwargs
assert call_kwargs["model"] == "deepseek-chat"
@pytest.mark.asyncio
async def test_openai_compat_preserves_extra_content_on_tool_calls() -> None:
"""Gemini extra_content (thought signatures) must survive parse→serialize round-trip."""
mock_create = AsyncMock(return_value=_fake_tool_call_response())
spec = find_by_name("gemini")
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI") as MockClient:
client_instance = MockClient.return_value
client_instance.chat.completions.create = mock_create
provider = OpenAICompatProvider(
api_key="test-key",
api_base="https://generativelanguage.googleapis.com/v1beta/openai/",
default_model="google/gemini-3.1-pro-preview",
spec=spec,
)
result = await provider.chat(
messages=[{"role": "user", "content": "run exec"}],
model="google/gemini-3.1-pro-preview",
)
assert len(result.tool_calls) == 1
tool_call = result.tool_calls[0]
assert tool_call.extra_content == {"google": {"thought_signature": "signed-token"}}
assert tool_call.function_provider_specific_fields == {"inner": "value"}
serialized = tool_call.to_openai_tool_call()
assert serialized["extra_content"] == {"google": {"thought_signature": "signed-token"}}
assert serialized["function"]["provider_specific_fields"] == {"inner": "value"}
def test_openai_model_passthrough() -> None:
"""OpenAI models pass through unchanged."""
spec = find_by_name("openai")
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
provider = OpenAICompatProvider(
api_key="sk-test-key",
default_model="gpt-4o",
spec=spec,
)
assert provider.get_default_model() == "gpt-4o"
@@ -0,0 +1,20 @@
"""Tests for the Mistral provider registration."""
from nanobot.config.schema import ProvidersConfig
from nanobot.providers.registry import PROVIDERS
def test_mistral_config_field_exists():
"""ProvidersConfig should have a mistral field."""
config = ProvidersConfig()
assert hasattr(config, "mistral")
def test_mistral_provider_in_registry():
"""Mistral should be registered in the provider registry."""
specs = {s.name: s for s in PROVIDERS}
assert "mistral" in specs
mistral = specs["mistral"]
assert mistral.env_key == "MISTRAL_API_KEY"
assert mistral.default_api_base == "https://api.mistral.ai/v1"
@@ -0,0 +1,213 @@
import asyncio
import pytest
from nanobot.providers.base import GenerationSettings, LLMProvider, LLMResponse
class ScriptedProvider(LLMProvider):
def __init__(self, responses):
super().__init__()
self._responses = list(responses)
self.calls = 0
self.last_kwargs: dict = {}
async def chat(self, *args, **kwargs) -> LLMResponse:
self.calls += 1
self.last_kwargs = kwargs
response = self._responses.pop(0)
if isinstance(response, BaseException):
raise response
return response
def get_default_model(self) -> str:
return "test-model"
@pytest.mark.asyncio
async def test_chat_with_retry_retries_transient_error_then_succeeds(monkeypatch) -> None:
provider = ScriptedProvider([
LLMResponse(content="429 rate limit", finish_reason="error"),
LLMResponse(content="ok"),
])
delays: list[int] = []
async def _fake_sleep(delay: int) -> None:
delays.append(delay)
monkeypatch.setattr("nanobot.providers.base.asyncio.sleep", _fake_sleep)
response = await provider.chat_with_retry(messages=[{"role": "user", "content": "hello"}])
assert response.finish_reason == "stop"
assert response.content == "ok"
assert provider.calls == 2
assert delays == [1]
@pytest.mark.asyncio
async def test_chat_with_retry_does_not_retry_non_transient_error(monkeypatch) -> None:
provider = ScriptedProvider([
LLMResponse(content="401 unauthorized", finish_reason="error"),
])
delays: list[int] = []
async def _fake_sleep(delay: int) -> None:
delays.append(delay)
monkeypatch.setattr("nanobot.providers.base.asyncio.sleep", _fake_sleep)
response = await provider.chat_with_retry(messages=[{"role": "user", "content": "hello"}])
assert response.content == "401 unauthorized"
assert provider.calls == 1
assert delays == []
@pytest.mark.asyncio
async def test_chat_with_retry_returns_final_error_after_retries(monkeypatch) -> None:
provider = ScriptedProvider([
LLMResponse(content="429 rate limit a", finish_reason="error"),
LLMResponse(content="429 rate limit b", finish_reason="error"),
LLMResponse(content="429 rate limit c", finish_reason="error"),
LLMResponse(content="503 final server error", finish_reason="error"),
])
delays: list[int] = []
async def _fake_sleep(delay: int) -> None:
delays.append(delay)
monkeypatch.setattr("nanobot.providers.base.asyncio.sleep", _fake_sleep)
response = await provider.chat_with_retry(messages=[{"role": "user", "content": "hello"}])
assert response.content == "503 final server error"
assert provider.calls == 4
assert delays == [1, 2, 4]
@pytest.mark.asyncio
async def test_chat_with_retry_preserves_cancelled_error() -> None:
provider = ScriptedProvider([asyncio.CancelledError()])
with pytest.raises(asyncio.CancelledError):
await provider.chat_with_retry(messages=[{"role": "user", "content": "hello"}])
@pytest.mark.asyncio
async def test_chat_with_retry_uses_provider_generation_defaults() -> None:
"""When callers omit generation params, provider.generation defaults are used."""
provider = ScriptedProvider([LLMResponse(content="ok")])
provider.generation = GenerationSettings(temperature=0.2, max_tokens=321, reasoning_effort="high")
await provider.chat_with_retry(messages=[{"role": "user", "content": "hello"}])
assert provider.last_kwargs["temperature"] == 0.2
assert provider.last_kwargs["max_tokens"] == 321
assert provider.last_kwargs["reasoning_effort"] == "high"
@pytest.mark.asyncio
async def test_chat_with_retry_explicit_override_beats_defaults() -> None:
"""Explicit kwargs should override provider.generation defaults."""
provider = ScriptedProvider([LLMResponse(content="ok")])
provider.generation = GenerationSettings(temperature=0.2, max_tokens=321, reasoning_effort="high")
await provider.chat_with_retry(
messages=[{"role": "user", "content": "hello"}],
temperature=0.9,
max_tokens=9999,
reasoning_effort="low",
)
assert provider.last_kwargs["temperature"] == 0.9
assert provider.last_kwargs["max_tokens"] == 9999
assert provider.last_kwargs["reasoning_effort"] == "low"
# ---------------------------------------------------------------------------
# Image fallback tests
# ---------------------------------------------------------------------------
_IMAGE_MSG = [
{"role": "user", "content": [
{"type": "text", "text": "describe this"},
{"type": "image_url", "image_url": {"url": "data:image/png;base64,abc"}, "_meta": {"path": "/media/test.png"}},
]},
]
_IMAGE_MSG_NO_META = [
{"role": "user", "content": [
{"type": "text", "text": "describe this"},
{"type": "image_url", "image_url": {"url": "data:image/png;base64,abc"}},
]},
]
@pytest.mark.asyncio
async def test_non_transient_error_with_images_retries_without_images() -> None:
"""Any non-transient error retries once with images stripped when images are present."""
provider = ScriptedProvider([
LLMResponse(content="API调用参数有误,请检查文档", finish_reason="error"),
LLMResponse(content="ok, no image"),
])
response = await provider.chat_with_retry(messages=_IMAGE_MSG)
assert response.content == "ok, no image"
assert provider.calls == 2
msgs_on_retry = provider.last_kwargs["messages"]
for msg in msgs_on_retry:
content = msg.get("content")
if isinstance(content, list):
assert all(b.get("type") != "image_url" for b in content)
assert any("[image: /media/test.png]" in (b.get("text") or "") for b in content)
@pytest.mark.asyncio
async def test_non_transient_error_without_images_no_retry() -> None:
"""Non-transient errors without image content are returned immediately."""
provider = ScriptedProvider([
LLMResponse(content="401 unauthorized", finish_reason="error"),
])
response = await provider.chat_with_retry(
messages=[{"role": "user", "content": "hello"}],
)
assert provider.calls == 1
assert response.finish_reason == "error"
@pytest.mark.asyncio
async def test_image_fallback_returns_error_on_second_failure() -> None:
"""If the image-stripped retry also fails, return that error."""
provider = ScriptedProvider([
LLMResponse(content="some model error", finish_reason="error"),
LLMResponse(content="still failing", finish_reason="error"),
])
response = await provider.chat_with_retry(messages=_IMAGE_MSG)
assert provider.calls == 2
assert response.content == "still failing"
assert response.finish_reason == "error"
@pytest.mark.asyncio
async def test_image_fallback_without_meta_uses_default_placeholder() -> None:
"""When _meta is absent, fallback placeholder is '[image omitted]'."""
provider = ScriptedProvider([
LLMResponse(content="error", finish_reason="error"),
LLMResponse(content="ok"),
])
response = await provider.chat_with_retry(messages=_IMAGE_MSG_NO_META)
assert response.content == "ok"
assert provider.calls == 2
msgs_on_retry = provider.last_kwargs["messages"]
for msg in msgs_on_retry:
content = msg.get("content")
if isinstance(content, list):
assert any("[image omitted]" in (b.get("text") or "") for b in content)
@@ -0,0 +1,40 @@
"""Tests for lazy provider exports from nanobot.providers."""
from __future__ import annotations
import importlib
import sys
def test_importing_providers_package_is_lazy(monkeypatch) -> None:
monkeypatch.delitem(sys.modules, "nanobot.providers", raising=False)
monkeypatch.delitem(sys.modules, "nanobot.providers.anthropic_provider", raising=False)
monkeypatch.delitem(sys.modules, "nanobot.providers.openai_compat_provider", raising=False)
monkeypatch.delitem(sys.modules, "nanobot.providers.openai_codex_provider", raising=False)
monkeypatch.delitem(sys.modules, "nanobot.providers.azure_openai_provider", raising=False)
providers = importlib.import_module("nanobot.providers")
assert "nanobot.providers.anthropic_provider" not in sys.modules
assert "nanobot.providers.openai_compat_provider" not in sys.modules
assert "nanobot.providers.openai_codex_provider" not in sys.modules
assert "nanobot.providers.azure_openai_provider" not in sys.modules
assert providers.__all__ == [
"LLMProvider",
"LLMResponse",
"AnthropicProvider",
"OpenAICompatProvider",
"OpenAICodexProvider",
"AzureOpenAIProvider",
]
def test_explicit_provider_import_still_works(monkeypatch) -> None:
monkeypatch.delitem(sys.modules, "nanobot.providers", raising=False)
monkeypatch.delitem(sys.modules, "nanobot.providers.anthropic_provider", raising=False)
namespace: dict[str, object] = {}
exec("from nanobot.providers import AnthropicProvider", namespace)
assert namespace["AnthropicProvider"].__name__ == "AnthropicProvider"
assert "nanobot.providers.anthropic_provider" in sys.modules
@@ -0,0 +1,101 @@
"""Tests for nanobot.security.network — SSRF protection and internal URL detection."""
from __future__ import annotations
import socket
from unittest.mock import patch
import pytest
from nanobot.security.network import contains_internal_url, validate_url_target
def _fake_resolve(host: str, results: list[str]):
"""Return a getaddrinfo mock that maps the given host to fake IP results."""
def _resolver(hostname, port, family=0, type_=0):
if hostname == host:
return [(socket.AF_INET, socket.SOCK_STREAM, 0, "", (ip, 0)) for ip in results]
raise socket.gaierror(f"cannot resolve {hostname}")
return _resolver
# ---------------------------------------------------------------------------
# validate_url_target — scheme / domain basics
# ---------------------------------------------------------------------------
def test_rejects_non_http_scheme():
ok, err = validate_url_target("ftp://example.com/file")
assert not ok
assert "http" in err.lower()
def test_rejects_missing_domain():
ok, err = validate_url_target("http://")
assert not ok
# ---------------------------------------------------------------------------
# validate_url_target — blocked private/internal IPs
# ---------------------------------------------------------------------------
@pytest.mark.parametrize("ip,label", [
("127.0.0.1", "loopback"),
("127.0.0.2", "loopback_alt"),
("10.0.0.1", "rfc1918_10"),
("172.16.5.1", "rfc1918_172"),
("192.168.1.1", "rfc1918_192"),
("169.254.169.254", "metadata"),
("0.0.0.0", "zero"),
])
def test_blocks_private_ipv4(ip: str, label: str):
with patch("nanobot.security.network.socket.getaddrinfo", _fake_resolve("evil.com", [ip])):
ok, err = validate_url_target(f"http://evil.com/path")
assert not ok, f"Should block {label} ({ip})"
assert "private" in err.lower() or "blocked" in err.lower()
def test_blocks_ipv6_loopback():
def _resolver(hostname, port, family=0, type_=0):
return [(socket.AF_INET6, socket.SOCK_STREAM, 0, "", ("::1", 0, 0, 0))]
with patch("nanobot.security.network.socket.getaddrinfo", _resolver):
ok, err = validate_url_target("http://evil.com/")
assert not ok
# ---------------------------------------------------------------------------
# validate_url_target — allows public IPs
# ---------------------------------------------------------------------------
def test_allows_public_ip():
with patch("nanobot.security.network.socket.getaddrinfo", _fake_resolve("example.com", ["93.184.216.34"])):
ok, err = validate_url_target("http://example.com/page")
assert ok, f"Should allow public IP, got: {err}"
def test_allows_normal_https():
with patch("nanobot.security.network.socket.getaddrinfo", _fake_resolve("github.com", ["140.82.121.3"])):
ok, err = validate_url_target("https://github.com/HKUDS/nanobot")
assert ok
# ---------------------------------------------------------------------------
# contains_internal_url — shell command scanning
# ---------------------------------------------------------------------------
def test_detects_curl_metadata():
with patch("nanobot.security.network.socket.getaddrinfo", _fake_resolve("169.254.169.254", ["169.254.169.254"])):
assert contains_internal_url('curl -s http://169.254.169.254/computeMetadata/v1/')
def test_detects_wget_localhost():
with patch("nanobot.security.network.socket.getaddrinfo", _fake_resolve("localhost", ["127.0.0.1"])):
assert contains_internal_url("wget http://localhost:8080/secret")
def test_allows_normal_curl():
with patch("nanobot.security.network.socket.getaddrinfo", _fake_resolve("example.com", ["93.184.216.34"])):
assert not contains_internal_url("curl https://example.com/api/data")
def test_no_urls_returns_false():
assert not contains_internal_url("echo hello && ls -la")
+56
View File
@@ -0,0 +1,56 @@
#!/usr/bin/env bash
set -euo pipefail
cd "$(dirname "$0")/.." || exit 1
IMAGE_NAME="nanobot-test"
echo "=== Building Docker image ==="
docker build -t "$IMAGE_NAME" .
echo ""
echo "=== Running 'nanobot onboard' ==="
docker run --name nanobot-test-run "$IMAGE_NAME" onboard
echo ""
echo "=== Running 'nanobot status' ==="
STATUS_OUTPUT=$(docker commit nanobot-test-run nanobot-test-onboarded > /dev/null && \
docker run --rm nanobot-test-onboarded status 2>&1) || true
echo "$STATUS_OUTPUT"
echo ""
echo "=== Validating output ==="
PASS=true
check() {
if echo "$STATUS_OUTPUT" | grep -q "$1"; then
echo " PASS: found '$1'"
else
echo " FAIL: missing '$1'"
PASS=false
fi
}
check "nanobot Status"
check "Config:"
check "Workspace:"
check "Model:"
check "OpenRouter API:"
check "Anthropic API:"
check "OpenAI API:"
echo ""
if $PASS; then
echo "=== All checks passed ==="
else
echo "=== Some checks FAILED ==="
exit 1
fi
# Cleanup
echo ""
echo "=== Cleanup ==="
docker rm -f nanobot-test-run 2>/dev/null || true
docker rmi -f nanobot-test-onboarded 2>/dev/null || true
docker rmi -f "$IMAGE_NAME" 2>/dev/null || true
echo "Done."
@@ -0,0 +1,69 @@
"""Tests for exec tool internal URL blocking."""
from __future__ import annotations
import socket
from unittest.mock import patch
import pytest
from nanobot.agent.tools.shell import ExecTool
def _fake_resolve_private(hostname, port, family=0, type_=0):
return [(socket.AF_INET, socket.SOCK_STREAM, 0, "", ("169.254.169.254", 0))]
def _fake_resolve_localhost(hostname, port, family=0, type_=0):
return [(socket.AF_INET, socket.SOCK_STREAM, 0, "", ("127.0.0.1", 0))]
def _fake_resolve_public(hostname, port, family=0, type_=0):
return [(socket.AF_INET, socket.SOCK_STREAM, 0, "", ("93.184.216.34", 0))]
@pytest.mark.asyncio
async def test_exec_blocks_curl_metadata():
tool = ExecTool()
with patch("nanobot.security.network.socket.getaddrinfo", _fake_resolve_private):
result = await tool.execute(
command='curl -s -H "Metadata-Flavor: Google" http://169.254.169.254/computeMetadata/v1/'
)
assert "Error" in result
assert "internal" in result.lower() or "private" in result.lower()
@pytest.mark.asyncio
async def test_exec_blocks_wget_localhost():
tool = ExecTool()
with patch("nanobot.security.network.socket.getaddrinfo", _fake_resolve_localhost):
result = await tool.execute(command="wget http://localhost:8080/secret -O /tmp/out")
assert "Error" in result
@pytest.mark.asyncio
async def test_exec_allows_normal_commands():
tool = ExecTool(timeout=5)
result = await tool.execute(command="echo hello")
assert "hello" in result
assert "Error" not in result.split("\n")[0]
@pytest.mark.asyncio
async def test_exec_allows_curl_to_public_url():
"""Commands with public URLs should not be blocked by the internal URL check."""
tool = ExecTool()
with patch("nanobot.security.network.socket.getaddrinfo", _fake_resolve_public):
guard_result = tool._guard_command("curl https://example.com/api", "/tmp")
assert guard_result is None
@pytest.mark.asyncio
async def test_exec_blocks_chained_internal_url():
"""Internal URLs buried in chained commands should still be caught."""
tool = ExecTool()
with patch("nanobot.security.network.socket.getaddrinfo", _fake_resolve_private):
result = await tool.execute(
command="echo start && curl http://169.254.169.254/latest/meta-data/ && echo done"
)
assert "Error" in result
@@ -0,0 +1,394 @@
"""Tests for enhanced filesystem tools: ReadFileTool, EditFileTool, ListDirTool."""
import pytest
from nanobot.agent.tools.filesystem import (
EditFileTool,
ListDirTool,
ReadFileTool,
_find_match,
)
# ---------------------------------------------------------------------------
# ReadFileTool
# ---------------------------------------------------------------------------
class TestReadFileTool:
@pytest.fixture()
def tool(self, tmp_path):
return ReadFileTool(workspace=tmp_path)
@pytest.fixture()
def sample_file(self, tmp_path):
f = tmp_path / "sample.txt"
f.write_text("\n".join(f"line {i}" for i in range(1, 21)), encoding="utf-8")
return f
@pytest.mark.asyncio
async def test_basic_read_has_line_numbers(self, tool, sample_file):
result = await tool.execute(path=str(sample_file))
assert "1| line 1" in result
assert "20| line 20" in result
@pytest.mark.asyncio
async def test_offset_and_limit(self, tool, sample_file):
result = await tool.execute(path=str(sample_file), offset=5, limit=3)
assert "5| line 5" in result
assert "7| line 7" in result
assert "8| line 8" not in result
assert "Use offset=8 to continue" in result
@pytest.mark.asyncio
async def test_offset_beyond_end(self, tool, sample_file):
result = await tool.execute(path=str(sample_file), offset=999)
assert "Error" in result
assert "beyond end" in result
@pytest.mark.asyncio
async def test_end_of_file_marker(self, tool, sample_file):
result = await tool.execute(path=str(sample_file), offset=1, limit=9999)
assert "End of file" in result
@pytest.mark.asyncio
async def test_empty_file(self, tool, tmp_path):
f = tmp_path / "empty.txt"
f.write_text("", encoding="utf-8")
result = await tool.execute(path=str(f))
assert "Empty file" in result
@pytest.mark.asyncio
async def test_image_file_returns_multimodal_blocks(self, tool, tmp_path):
f = tmp_path / "pixel.png"
f.write_bytes(b"\x89PNG\r\n\x1a\nfake-png-data")
result = await tool.execute(path=str(f))
assert isinstance(result, list)
assert result[0]["type"] == "image_url"
assert result[0]["image_url"]["url"].startswith("data:image/png;base64,")
assert result[0]["_meta"]["path"] == str(f)
assert result[1] == {"type": "text", "text": f"(Image file: {f})"}
@pytest.mark.asyncio
async def test_file_not_found(self, tool, tmp_path):
result = await tool.execute(path=str(tmp_path / "nope.txt"))
assert "Error" in result
assert "not found" in result
@pytest.mark.asyncio
async def test_missing_path_returns_clear_error(self, tool):
result = await tool.execute()
assert result == "Error reading file: Unknown path"
@pytest.mark.asyncio
async def test_char_budget_trims(self, tool, tmp_path):
"""When the selected slice exceeds _MAX_CHARS the output is trimmed."""
f = tmp_path / "big.txt"
# Each line is ~110 chars, 2000 lines ≈ 220 KB > 128 KB limit
f.write_text("\n".join("x" * 110 for _ in range(2000)), encoding="utf-8")
result = await tool.execute(path=str(f))
assert len(result) <= ReadFileTool._MAX_CHARS + 500 # small margin for footer
assert "Use offset=" in result
# ---------------------------------------------------------------------------
# _find_match (unit tests for the helper)
# ---------------------------------------------------------------------------
class TestFindMatch:
def test_exact_match(self):
match, count = _find_match("hello world", "world")
assert match == "world"
assert count == 1
def test_exact_no_match(self):
match, count = _find_match("hello world", "xyz")
assert match is None
assert count == 0
def test_crlf_normalisation(self):
# Caller normalises CRLF before calling _find_match, so test with
# pre-normalised content to verify exact match still works.
content = "line1\nline2\nline3"
old_text = "line1\nline2\nline3"
match, count = _find_match(content, old_text)
assert match is not None
assert count == 1
def test_line_trim_fallback(self):
content = " def foo():\n pass\n"
old_text = "def foo():\n pass"
match, count = _find_match(content, old_text)
assert match is not None
assert count == 1
# The returned match should be the *original* indented text
assert " def foo():" in match
def test_line_trim_multiple_candidates(self):
content = " a\n b\n a\n b\n"
old_text = "a\nb"
match, count = _find_match(content, old_text)
assert count == 2
def test_empty_old_text(self):
match, count = _find_match("hello", "")
# Empty string is always "in" any string via exact match
assert match == ""
# ---------------------------------------------------------------------------
# EditFileTool
# ---------------------------------------------------------------------------
class TestEditFileTool:
@pytest.fixture()
def tool(self, tmp_path):
return EditFileTool(workspace=tmp_path)
@pytest.mark.asyncio
async def test_exact_match(self, tool, tmp_path):
f = tmp_path / "a.py"
f.write_text("hello world", encoding="utf-8")
result = await tool.execute(path=str(f), old_text="world", new_text="earth")
assert "Successfully" in result
assert f.read_text() == "hello earth"
@pytest.mark.asyncio
async def test_crlf_normalisation(self, tool, tmp_path):
f = tmp_path / "crlf.py"
f.write_bytes(b"line1\r\nline2\r\nline3")
result = await tool.execute(
path=str(f), old_text="line1\nline2", new_text="LINE1\nLINE2",
)
assert "Successfully" in result
raw = f.read_bytes()
assert b"LINE1" in raw
# CRLF line endings should be preserved throughout the file
assert b"\r\n" in raw
@pytest.mark.asyncio
async def test_trim_fallback(self, tool, tmp_path):
f = tmp_path / "indent.py"
f.write_text(" def foo():\n pass\n", encoding="utf-8")
result = await tool.execute(
path=str(f), old_text="def foo():\n pass", new_text="def bar():\n return 1",
)
assert "Successfully" in result
assert "bar" in f.read_text()
@pytest.mark.asyncio
async def test_ambiguous_match(self, tool, tmp_path):
f = tmp_path / "dup.py"
f.write_text("aaa\nbbb\naaa\nbbb\n", encoding="utf-8")
result = await tool.execute(path=str(f), old_text="aaa\nbbb", new_text="xxx")
assert "appears" in result.lower() or "Warning" in result
@pytest.mark.asyncio
async def test_replace_all(self, tool, tmp_path):
f = tmp_path / "multi.py"
f.write_text("foo bar foo bar foo", encoding="utf-8")
result = await tool.execute(
path=str(f), old_text="foo", new_text="baz", replace_all=True,
)
assert "Successfully" in result
assert f.read_text() == "baz bar baz bar baz"
@pytest.mark.asyncio
async def test_not_found(self, tool, tmp_path):
f = tmp_path / "nf.py"
f.write_text("hello", encoding="utf-8")
result = await tool.execute(path=str(f), old_text="xyz", new_text="abc")
assert "Error" in result
assert "not found" in result
@pytest.mark.asyncio
async def test_missing_new_text_returns_clear_error(self, tool, tmp_path):
f = tmp_path / "a.py"
f.write_text("hello", encoding="utf-8")
result = await tool.execute(path=str(f), old_text="hello")
assert result == "Error editing file: Unknown new_text"
# ---------------------------------------------------------------------------
# ListDirTool
# ---------------------------------------------------------------------------
class TestListDirTool:
@pytest.fixture()
def tool(self, tmp_path):
return ListDirTool(workspace=tmp_path)
@pytest.fixture()
def populated_dir(self, tmp_path):
(tmp_path / "src").mkdir()
(tmp_path / "src" / "main.py").write_text("pass")
(tmp_path / "src" / "utils.py").write_text("pass")
(tmp_path / "README.md").write_text("hi")
(tmp_path / ".git").mkdir()
(tmp_path / ".git" / "config").write_text("x")
(tmp_path / "node_modules").mkdir()
(tmp_path / "node_modules" / "pkg").mkdir()
return tmp_path
@pytest.mark.asyncio
async def test_basic_list(self, tool, populated_dir):
result = await tool.execute(path=str(populated_dir))
assert "README.md" in result
assert "src" in result
# .git and node_modules should be ignored
assert ".git" not in result
assert "node_modules" not in result
@pytest.mark.asyncio
async def test_recursive(self, tool, populated_dir):
result = await tool.execute(path=str(populated_dir), recursive=True)
# Normalize path separators for cross-platform compatibility
normalized = result.replace("\\", "/")
assert "src/main.py" in normalized
assert "src/utils.py" in normalized
assert "README.md" in result
# Ignored dirs should not appear
assert ".git" not in result
assert "node_modules" not in result
@pytest.mark.asyncio
async def test_max_entries_truncation(self, tool, tmp_path):
for i in range(10):
(tmp_path / f"file_{i}.txt").write_text("x")
result = await tool.execute(path=str(tmp_path), max_entries=3)
assert "truncated" in result
assert "3 of 10" in result
@pytest.mark.asyncio
async def test_empty_dir(self, tool, tmp_path):
d = tmp_path / "empty"
d.mkdir()
result = await tool.execute(path=str(d))
assert "empty" in result.lower()
@pytest.mark.asyncio
async def test_not_found(self, tool, tmp_path):
result = await tool.execute(path=str(tmp_path / "nope"))
assert "Error" in result
assert "not found" in result
@pytest.mark.asyncio
async def test_missing_path_returns_clear_error(self, tool):
result = await tool.execute()
assert result == "Error listing directory: Unknown path"
# ---------------------------------------------------------------------------
# Workspace restriction + extra_allowed_dirs
# ---------------------------------------------------------------------------
class TestWorkspaceRestriction:
@pytest.mark.asyncio
async def test_read_blocked_outside_workspace(self, tmp_path):
workspace = tmp_path / "ws"
workspace.mkdir()
outside = tmp_path / "outside"
outside.mkdir()
secret = outside / "secret.txt"
secret.write_text("top secret")
tool = ReadFileTool(workspace=workspace, allowed_dir=workspace)
result = await tool.execute(path=str(secret))
assert "Error" in result
assert "outside" in result.lower()
@pytest.mark.asyncio
async def test_read_allowed_with_extra_dir(self, tmp_path):
workspace = tmp_path / "ws"
workspace.mkdir()
skills_dir = tmp_path / "skills"
skills_dir.mkdir()
skill_file = skills_dir / "test_skill" / "SKILL.md"
skill_file.parent.mkdir()
skill_file.write_text("# Test Skill\nDo something.")
tool = ReadFileTool(
workspace=workspace, allowed_dir=workspace,
extra_allowed_dirs=[skills_dir],
)
result = await tool.execute(path=str(skill_file))
assert "Test Skill" in result
assert "Error" not in result
@pytest.mark.asyncio
async def test_extra_dirs_does_not_widen_write(self, tmp_path):
from nanobot.agent.tools.filesystem import WriteFileTool
workspace = tmp_path / "ws"
workspace.mkdir()
outside = tmp_path / "outside"
outside.mkdir()
tool = WriteFileTool(workspace=workspace, allowed_dir=workspace)
result = await tool.execute(path=str(outside / "hack.txt"), content="pwned")
assert "Error" in result
assert "outside" in result.lower()
@pytest.mark.asyncio
async def test_read_still_blocked_for_unrelated_dir(self, tmp_path):
workspace = tmp_path / "ws"
workspace.mkdir()
skills_dir = tmp_path / "skills"
skills_dir.mkdir()
unrelated = tmp_path / "other"
unrelated.mkdir()
secret = unrelated / "secret.txt"
secret.write_text("nope")
tool = ReadFileTool(
workspace=workspace, allowed_dir=workspace,
extra_allowed_dirs=[skills_dir],
)
result = await tool.execute(path=str(secret))
assert "Error" in result
assert "outside" in result.lower()
@pytest.mark.asyncio
async def test_workspace_file_still_readable_with_extra_dirs(self, tmp_path):
"""Adding extra_allowed_dirs must not break normal workspace reads."""
workspace = tmp_path / "ws"
workspace.mkdir()
ws_file = workspace / "README.md"
ws_file.write_text("hello from workspace")
skills_dir = tmp_path / "skills"
skills_dir.mkdir()
tool = ReadFileTool(
workspace=workspace, allowed_dir=workspace,
extra_allowed_dirs=[skills_dir],
)
result = await tool.execute(path=str(ws_file))
assert "hello from workspace" in result
assert "Error" not in result
@pytest.mark.asyncio
async def test_edit_blocked_in_extra_dir(self, tmp_path):
"""edit_file must not be able to modify files in extra_allowed_dirs."""
workspace = tmp_path / "ws"
workspace.mkdir()
skills_dir = tmp_path / "skills"
skills_dir.mkdir()
skill_file = skills_dir / "weather" / "SKILL.md"
skill_file.parent.mkdir()
skill_file.write_text("# Weather\nOriginal content.")
tool = EditFileTool(workspace=workspace, allowed_dir=workspace)
result = await tool.execute(
path=str(skill_file),
old_text="Original content.",
new_text="Hacked content.",
)
assert "Error" in result
assert "outside" in result.lower()
assert skill_file.read_text() == "# Weather\nOriginal content."
+345
View File
@@ -0,0 +1,345 @@
from __future__ import annotations
import asyncio
from contextlib import AsyncExitStack, asynccontextmanager
import sys
from types import ModuleType, SimpleNamespace
import pytest
from nanobot.agent.tools.mcp import MCPToolWrapper, connect_mcp_servers
from nanobot.agent.tools.registry import ToolRegistry
from nanobot.config.schema import MCPServerConfig
class _FakeTextContent:
def __init__(self, text: str) -> None:
self.text = text
@pytest.fixture
def fake_mcp_runtime() -> dict[str, object | None]:
return {"session": None}
@pytest.fixture(autouse=True)
def _fake_mcp_module(
monkeypatch: pytest.MonkeyPatch, fake_mcp_runtime: dict[str, object | None]
) -> None:
mod = ModuleType("mcp")
mod.types = SimpleNamespace(TextContent=_FakeTextContent)
class _FakeStdioServerParameters:
def __init__(self, command: str, args: list[str], env: dict | None = None) -> None:
self.command = command
self.args = args
self.env = env
class _FakeClientSession:
def __init__(self, _read: object, _write: object) -> None:
self._session = fake_mcp_runtime["session"]
async def __aenter__(self) -> object:
return self._session
async def __aexit__(self, exc_type, exc, tb) -> bool:
return False
@asynccontextmanager
async def _fake_stdio_client(_params: object):
yield object(), object()
@asynccontextmanager
async def _fake_sse_client(_url: str, httpx_client_factory=None):
yield object(), object()
@asynccontextmanager
async def _fake_streamable_http_client(_url: str, http_client=None):
yield object(), object(), object()
mod.ClientSession = _FakeClientSession
mod.StdioServerParameters = _FakeStdioServerParameters
monkeypatch.setitem(sys.modules, "mcp", mod)
client_mod = ModuleType("mcp.client")
stdio_mod = ModuleType("mcp.client.stdio")
stdio_mod.stdio_client = _fake_stdio_client
sse_mod = ModuleType("mcp.client.sse")
sse_mod.sse_client = _fake_sse_client
streamable_http_mod = ModuleType("mcp.client.streamable_http")
streamable_http_mod.streamable_http_client = _fake_streamable_http_client
monkeypatch.setitem(sys.modules, "mcp.client", client_mod)
monkeypatch.setitem(sys.modules, "mcp.client.stdio", stdio_mod)
monkeypatch.setitem(sys.modules, "mcp.client.sse", sse_mod)
monkeypatch.setitem(sys.modules, "mcp.client.streamable_http", streamable_http_mod)
def _make_wrapper(session: object, *, timeout: float = 0.1) -> MCPToolWrapper:
tool_def = SimpleNamespace(
name="demo",
description="demo tool",
inputSchema={"type": "object", "properties": {}},
)
return MCPToolWrapper(session, "test", tool_def, tool_timeout=timeout)
def test_wrapper_preserves_non_nullable_unions() -> None:
tool_def = SimpleNamespace(
name="demo",
description="demo tool",
inputSchema={
"type": "object",
"properties": {
"value": {
"anyOf": [{"type": "string"}, {"type": "integer"}],
}
},
},
)
wrapper = MCPToolWrapper(SimpleNamespace(call_tool=None), "test", tool_def)
assert wrapper.parameters["properties"]["value"]["anyOf"] == [
{"type": "string"},
{"type": "integer"},
]
def test_wrapper_normalizes_nullable_property_type_union() -> None:
tool_def = SimpleNamespace(
name="demo",
description="demo tool",
inputSchema={
"type": "object",
"properties": {
"name": {"type": ["string", "null"]},
},
},
)
wrapper = MCPToolWrapper(SimpleNamespace(call_tool=None), "test", tool_def)
assert wrapper.parameters["properties"]["name"] == {"type": "string", "nullable": True}
def test_wrapper_normalizes_nullable_property_anyof() -> None:
tool_def = SimpleNamespace(
name="demo",
description="demo tool",
inputSchema={
"type": "object",
"properties": {
"name": {
"anyOf": [{"type": "string"}, {"type": "null"}],
"description": "optional name",
},
},
},
)
wrapper = MCPToolWrapper(SimpleNamespace(call_tool=None), "test", tool_def)
assert wrapper.parameters["properties"]["name"] == {
"type": "string",
"description": "optional name",
"nullable": True,
}
@pytest.mark.asyncio
async def test_execute_returns_text_blocks() -> None:
async def call_tool(_name: str, arguments: dict) -> object:
assert arguments == {"value": 1}
return SimpleNamespace(content=[_FakeTextContent("hello"), 42])
wrapper = _make_wrapper(SimpleNamespace(call_tool=call_tool))
result = await wrapper.execute(value=1)
assert result == "hello\n42"
@pytest.mark.asyncio
async def test_execute_returns_timeout_message() -> None:
async def call_tool(_name: str, arguments: dict) -> object:
await asyncio.sleep(1)
return SimpleNamespace(content=[])
wrapper = _make_wrapper(SimpleNamespace(call_tool=call_tool), timeout=0.01)
result = await wrapper.execute()
assert result == "(MCP tool call timed out after 0.01s)"
@pytest.mark.asyncio
async def test_execute_handles_server_cancelled_error() -> None:
async def call_tool(_name: str, arguments: dict) -> object:
raise asyncio.CancelledError()
wrapper = _make_wrapper(SimpleNamespace(call_tool=call_tool))
result = await wrapper.execute()
assert result == "(MCP tool call was cancelled)"
@pytest.mark.asyncio
async def test_execute_re_raises_external_cancellation() -> None:
started = asyncio.Event()
async def call_tool(_name: str, arguments: dict) -> object:
started.set()
await asyncio.sleep(60)
return SimpleNamespace(content=[])
wrapper = _make_wrapper(SimpleNamespace(call_tool=call_tool), timeout=10)
task = asyncio.create_task(wrapper.execute())
await started.wait()
task.cancel()
with pytest.raises(asyncio.CancelledError):
await task
@pytest.mark.asyncio
async def test_execute_handles_generic_exception() -> None:
async def call_tool(_name: str, arguments: dict) -> object:
raise RuntimeError("boom")
wrapper = _make_wrapper(SimpleNamespace(call_tool=call_tool))
result = await wrapper.execute()
assert result == "(MCP tool call failed: RuntimeError)"
def _make_tool_def(name: str) -> SimpleNamespace:
return SimpleNamespace(
name=name,
description=f"{name} tool",
inputSchema={"type": "object", "properties": {}},
)
def _make_fake_session(tool_names: list[str]) -> SimpleNamespace:
async def initialize() -> None:
return None
async def list_tools() -> SimpleNamespace:
return SimpleNamespace(tools=[_make_tool_def(name) for name in tool_names])
return SimpleNamespace(initialize=initialize, list_tools=list_tools)
@pytest.mark.asyncio
async def test_connect_mcp_servers_enabled_tools_supports_raw_names(
fake_mcp_runtime: dict[str, object | None],
) -> None:
fake_mcp_runtime["session"] = _make_fake_session(["demo", "other"])
registry = ToolRegistry()
stack = AsyncExitStack()
await stack.__aenter__()
try:
await connect_mcp_servers(
{"test": MCPServerConfig(command="fake", enabled_tools=["demo"])},
registry,
stack,
)
finally:
await stack.aclose()
assert registry.tool_names == ["mcp_test_demo"]
@pytest.mark.asyncio
async def test_connect_mcp_servers_enabled_tools_defaults_to_all(
fake_mcp_runtime: dict[str, object | None],
) -> None:
fake_mcp_runtime["session"] = _make_fake_session(["demo", "other"])
registry = ToolRegistry()
stack = AsyncExitStack()
await stack.__aenter__()
try:
await connect_mcp_servers(
{"test": MCPServerConfig(command="fake")},
registry,
stack,
)
finally:
await stack.aclose()
assert registry.tool_names == ["mcp_test_demo", "mcp_test_other"]
@pytest.mark.asyncio
async def test_connect_mcp_servers_enabled_tools_supports_wrapped_names(
fake_mcp_runtime: dict[str, object | None],
) -> None:
fake_mcp_runtime["session"] = _make_fake_session(["demo", "other"])
registry = ToolRegistry()
stack = AsyncExitStack()
await stack.__aenter__()
try:
await connect_mcp_servers(
{"test": MCPServerConfig(command="fake", enabled_tools=["mcp_test_demo"])},
registry,
stack,
)
finally:
await stack.aclose()
assert registry.tool_names == ["mcp_test_demo"]
@pytest.mark.asyncio
async def test_connect_mcp_servers_enabled_tools_empty_list_registers_none(
fake_mcp_runtime: dict[str, object | None],
) -> None:
fake_mcp_runtime["session"] = _make_fake_session(["demo", "other"])
registry = ToolRegistry()
stack = AsyncExitStack()
await stack.__aenter__()
try:
await connect_mcp_servers(
{"test": MCPServerConfig(command="fake", enabled_tools=[])},
registry,
stack,
)
finally:
await stack.aclose()
assert registry.tool_names == []
@pytest.mark.asyncio
async def test_connect_mcp_servers_enabled_tools_warns_on_unknown_entries(
fake_mcp_runtime: dict[str, object | None], monkeypatch: pytest.MonkeyPatch
) -> None:
fake_mcp_runtime["session"] = _make_fake_session(["demo"])
registry = ToolRegistry()
warnings: list[str] = []
def _warning(message: str, *args: object) -> None:
warnings.append(message.format(*args))
monkeypatch.setattr("nanobot.agent.tools.mcp.logger.warning", _warning)
stack = AsyncExitStack()
await stack.__aenter__()
try:
await connect_mcp_servers(
{"test": MCPServerConfig(command="fake", enabled_tools=["unknown"])},
registry,
stack,
)
finally:
await stack.aclose()
assert registry.tool_names == []
assert warnings
assert "enabledTools entries not found: unknown" in warnings[-1]
assert "Available raw names: demo" in warnings[-1]
assert "Available wrapped names: mcp_test_demo" in warnings[-1]
@@ -0,0 +1,10 @@
import pytest
from nanobot.agent.tools.message import MessageTool
@pytest.mark.asyncio
async def test_message_tool_returns_error_when_no_target_context() -> None:
tool = MessageTool()
result = await tool.execute(content="test")
assert result == "Error: No target channel/chat specified"
@@ -0,0 +1,132 @@
"""Test message tool suppress logic for final replies."""
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock
import pytest
from nanobot.agent.loop import AgentLoop
from nanobot.agent.tools.message import MessageTool
from nanobot.bus.events import InboundMessage, OutboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.providers.base import LLMResponse, ToolCallRequest
def _make_loop(tmp_path: Path) -> AgentLoop:
bus = MessageBus()
provider = MagicMock()
provider.get_default_model.return_value = "test-model"
return AgentLoop(bus=bus, provider=provider, workspace=tmp_path, model="test-model")
class TestMessageToolSuppressLogic:
"""Final reply suppressed only when message tool sends to the same target."""
@pytest.mark.asyncio
async def test_suppress_when_sent_to_same_target(self, tmp_path: Path) -> None:
loop = _make_loop(tmp_path)
tool_call = ToolCallRequest(
id="call1", name="message",
arguments={"content": "Hello", "channel": "feishu", "chat_id": "chat123"},
)
calls = iter([
LLMResponse(content="", tool_calls=[tool_call]),
LLMResponse(content="Done", tool_calls=[]),
])
loop.provider.chat_with_retry = AsyncMock(side_effect=lambda *a, **kw: next(calls))
loop.tools.get_definitions = MagicMock(return_value=[])
sent: list[OutboundMessage] = []
mt = loop.tools.get("message")
if isinstance(mt, MessageTool):
mt.set_send_callback(AsyncMock(side_effect=lambda m: sent.append(m)))
msg = InboundMessage(channel="feishu", sender_id="user1", chat_id="chat123", content="Send")
result = await loop._process_message(msg)
assert len(sent) == 1
assert result is None # suppressed
@pytest.mark.asyncio
async def test_not_suppress_when_sent_to_different_target(self, tmp_path: Path) -> None:
loop = _make_loop(tmp_path)
tool_call = ToolCallRequest(
id="call1", name="message",
arguments={"content": "Email content", "channel": "email", "chat_id": "user@example.com"},
)
calls = iter([
LLMResponse(content="", tool_calls=[tool_call]),
LLMResponse(content="I've sent the email.", tool_calls=[]),
])
loop.provider.chat_with_retry = AsyncMock(side_effect=lambda *a, **kw: next(calls))
loop.tools.get_definitions = MagicMock(return_value=[])
sent: list[OutboundMessage] = []
mt = loop.tools.get("message")
if isinstance(mt, MessageTool):
mt.set_send_callback(AsyncMock(side_effect=lambda m: sent.append(m)))
msg = InboundMessage(channel="feishu", sender_id="user1", chat_id="chat123", content="Send email")
result = await loop._process_message(msg)
assert len(sent) == 1
assert sent[0].channel == "email"
assert result is not None # not suppressed
assert result.channel == "feishu"
@pytest.mark.asyncio
async def test_not_suppress_when_no_message_tool_used(self, tmp_path: Path) -> None:
loop = _make_loop(tmp_path)
loop.provider.chat_with_retry = AsyncMock(return_value=LLMResponse(content="Hello!", tool_calls=[]))
loop.tools.get_definitions = MagicMock(return_value=[])
msg = InboundMessage(channel="feishu", sender_id="user1", chat_id="chat123", content="Hi")
result = await loop._process_message(msg)
assert result is not None
assert "Hello" in result.content
async def test_progress_hides_internal_reasoning(self, tmp_path: Path) -> None:
loop = _make_loop(tmp_path)
tool_call = ToolCallRequest(id="call1", name="read_file", arguments={"path": "foo.txt"})
calls = iter([
LLMResponse(
content="Visible<think>hidden</think>",
tool_calls=[tool_call],
reasoning_content="secret reasoning",
thinking_blocks=[{"signature": "sig", "thought": "secret thought"}],
),
LLMResponse(content="Done", tool_calls=[]),
])
loop.provider.chat_with_retry = AsyncMock(side_effect=lambda *a, **kw: next(calls))
loop.tools.get_definitions = MagicMock(return_value=[])
loop.tools.execute = AsyncMock(return_value="ok")
progress: list[tuple[str, bool]] = []
async def on_progress(content: str, *, tool_hint: bool = False) -> None:
progress.append((content, tool_hint))
final_content, _, _ = await loop._run_agent_loop([], on_progress=on_progress)
assert final_content == "Done"
assert progress == [
("Visible", False),
('read_file("foo.txt")', True),
]
class TestMessageToolTurnTracking:
def test_sent_in_turn_tracks_same_target(self) -> None:
tool = MessageTool()
tool.set_context("feishu", "chat1")
assert not tool._sent_in_turn
tool._sent_in_turn = True
assert tool._sent_in_turn
def test_start_turn_resets(self) -> None:
tool = MessageTool()
tool._sent_in_turn = True
tool.start_turn()
assert not tool._sent_in_turn
@@ -0,0 +1,481 @@
from typing import Any
from nanobot.agent.tools.base import Tool
from nanobot.agent.tools.registry import ToolRegistry
from nanobot.agent.tools.shell import ExecTool
class SampleTool(Tool):
@property
def name(self) -> str:
return "sample"
@property
def description(self) -> str:
return "sample tool"
@property
def parameters(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"query": {"type": "string", "minLength": 2},
"count": {"type": "integer", "minimum": 1, "maximum": 10},
"mode": {"type": "string", "enum": ["fast", "full"]},
"meta": {
"type": "object",
"properties": {
"tag": {"type": "string"},
"flags": {
"type": "array",
"items": {"type": "string"},
},
},
"required": ["tag"],
},
},
"required": ["query", "count"],
}
async def execute(self, **kwargs: Any) -> str:
return "ok"
def test_validate_params_missing_required() -> None:
tool = SampleTool()
errors = tool.validate_params({"query": "hi"})
assert "missing required count" in "; ".join(errors)
def test_validate_params_type_and_range() -> None:
tool = SampleTool()
errors = tool.validate_params({"query": "hi", "count": 0})
assert any("count must be >= 1" in e for e in errors)
errors = tool.validate_params({"query": "hi", "count": "2"})
assert any("count should be integer" in e for e in errors)
def test_validate_params_enum_and_min_length() -> None:
tool = SampleTool()
errors = tool.validate_params({"query": "h", "count": 2, "mode": "slow"})
assert any("query must be at least 2 chars" in e for e in errors)
assert any("mode must be one of" in e for e in errors)
def test_validate_params_nested_object_and_array() -> None:
tool = SampleTool()
errors = tool.validate_params(
{
"query": "hi",
"count": 2,
"meta": {"flags": [1, "ok"]},
}
)
assert any("missing required meta.tag" in e for e in errors)
assert any("meta.flags[0] should be string" in e for e in errors)
def test_validate_params_ignores_unknown_fields() -> None:
tool = SampleTool()
errors = tool.validate_params({"query": "hi", "count": 2, "extra": "x"})
assert errors == []
async def test_registry_returns_validation_error() -> None:
reg = ToolRegistry()
reg.register(SampleTool())
result = await reg.execute("sample", {"query": "hi"})
assert "Invalid parameters" in result
def test_exec_extract_absolute_paths_keeps_full_windows_path() -> None:
cmd = r"type C:\user\workspace\txt"
paths = ExecTool._extract_absolute_paths(cmd)
assert paths == [r"C:\user\workspace\txt"]
def test_exec_extract_absolute_paths_ignores_relative_posix_segments() -> None:
cmd = ".venv/bin/python script.py"
paths = ExecTool._extract_absolute_paths(cmd)
assert "/bin/python" not in paths
def test_exec_extract_absolute_paths_captures_posix_absolute_paths() -> None:
cmd = "cat /tmp/data.txt > /tmp/out.txt"
paths = ExecTool._extract_absolute_paths(cmd)
assert "/tmp/data.txt" in paths
assert "/tmp/out.txt" in paths
def test_exec_extract_absolute_paths_captures_home_paths() -> None:
cmd = "cat ~/.nanobot/config.json > ~/out.txt"
paths = ExecTool._extract_absolute_paths(cmd)
assert "~/.nanobot/config.json" in paths
assert "~/out.txt" in paths
def test_exec_extract_absolute_paths_captures_quoted_paths() -> None:
cmd = 'cat "/tmp/data.txt" "~/.nanobot/config.json"'
paths = ExecTool._extract_absolute_paths(cmd)
assert "/tmp/data.txt" in paths
assert "~/.nanobot/config.json" in paths
def test_exec_guard_blocks_home_path_outside_workspace(tmp_path) -> None:
tool = ExecTool(restrict_to_workspace=True)
error = tool._guard_command("cat ~/.nanobot/config.json", str(tmp_path))
assert error == "Error: Command blocked by safety guard (path outside working dir)"
def test_exec_guard_blocks_quoted_home_path_outside_workspace(tmp_path) -> None:
tool = ExecTool(restrict_to_workspace=True)
error = tool._guard_command('cat "~/.nanobot/config.json"', str(tmp_path))
assert error == "Error: Command blocked by safety guard (path outside working dir)"
# --- cast_params tests ---
class CastTestTool(Tool):
"""Minimal tool for testing cast_params."""
def __init__(self, schema: dict[str, Any]) -> None:
self._schema = schema
@property
def name(self) -> str:
return "cast_test"
@property
def description(self) -> str:
return "test tool for casting"
@property
def parameters(self) -> dict[str, Any]:
return self._schema
async def execute(self, **kwargs: Any) -> str:
return "ok"
def test_cast_params_string_to_int() -> None:
tool = CastTestTool(
{
"type": "object",
"properties": {"count": {"type": "integer"}},
}
)
result = tool.cast_params({"count": "42"})
assert result["count"] == 42
assert isinstance(result["count"], int)
def test_cast_params_string_to_number() -> None:
tool = CastTestTool(
{
"type": "object",
"properties": {"rate": {"type": "number"}},
}
)
result = tool.cast_params({"rate": "3.14"})
assert result["rate"] == 3.14
assert isinstance(result["rate"], float)
def test_cast_params_string_to_bool() -> None:
tool = CastTestTool(
{
"type": "object",
"properties": {"enabled": {"type": "boolean"}},
}
)
assert tool.cast_params({"enabled": "true"})["enabled"] is True
assert tool.cast_params({"enabled": "false"})["enabled"] is False
assert tool.cast_params({"enabled": "1"})["enabled"] is True
def test_cast_params_array_items() -> None:
tool = CastTestTool(
{
"type": "object",
"properties": {
"nums": {"type": "array", "items": {"type": "integer"}},
},
}
)
result = tool.cast_params({"nums": ["1", "2", "3"]})
assert result["nums"] == [1, 2, 3]
def test_cast_params_nested_object() -> None:
tool = CastTestTool(
{
"type": "object",
"properties": {
"config": {
"type": "object",
"properties": {
"port": {"type": "integer"},
"debug": {"type": "boolean"},
},
},
},
}
)
result = tool.cast_params({"config": {"port": "8080", "debug": "true"}})
assert result["config"]["port"] == 8080
assert result["config"]["debug"] is True
def test_cast_params_bool_not_cast_to_int() -> None:
"""Booleans should not be silently cast to integers."""
tool = CastTestTool(
{
"type": "object",
"properties": {"count": {"type": "integer"}},
}
)
result = tool.cast_params({"count": True})
assert result["count"] is True
errors = tool.validate_params(result)
assert any("count should be integer" in e for e in errors)
def test_cast_params_preserves_empty_string() -> None:
"""Empty strings should be preserved for string type."""
tool = CastTestTool(
{
"type": "object",
"properties": {"name": {"type": "string"}},
}
)
result = tool.cast_params({"name": ""})
assert result["name"] == ""
def test_cast_params_bool_string_false() -> None:
"""Test that 'false', '0', 'no' strings convert to False."""
tool = CastTestTool(
{
"type": "object",
"properties": {"flag": {"type": "boolean"}},
}
)
assert tool.cast_params({"flag": "false"})["flag"] is False
assert tool.cast_params({"flag": "False"})["flag"] is False
assert tool.cast_params({"flag": "0"})["flag"] is False
assert tool.cast_params({"flag": "no"})["flag"] is False
assert tool.cast_params({"flag": "NO"})["flag"] is False
def test_cast_params_bool_string_invalid() -> None:
"""Invalid boolean strings should not be cast."""
tool = CastTestTool(
{
"type": "object",
"properties": {"flag": {"type": "boolean"}},
}
)
# Invalid strings should be preserved (validation will catch them)
result = tool.cast_params({"flag": "random"})
assert result["flag"] == "random"
result = tool.cast_params({"flag": "maybe"})
assert result["flag"] == "maybe"
def test_cast_params_invalid_string_to_int() -> None:
"""Invalid strings should not be cast to integer."""
tool = CastTestTool(
{
"type": "object",
"properties": {"count": {"type": "integer"}},
}
)
result = tool.cast_params({"count": "abc"})
assert result["count"] == "abc" # Original value preserved
result = tool.cast_params({"count": "12.5.7"})
assert result["count"] == "12.5.7"
def test_cast_params_invalid_string_to_number() -> None:
"""Invalid strings should not be cast to number."""
tool = CastTestTool(
{
"type": "object",
"properties": {"rate": {"type": "number"}},
}
)
result = tool.cast_params({"rate": "not_a_number"})
assert result["rate"] == "not_a_number"
def test_validate_params_bool_not_accepted_as_number() -> None:
"""Booleans should not pass number validation."""
tool = CastTestTool(
{
"type": "object",
"properties": {"rate": {"type": "number"}},
}
)
errors = tool.validate_params({"rate": False})
assert any("rate should be number" in e for e in errors)
def test_cast_params_none_values() -> None:
"""Test None handling for different types."""
tool = CastTestTool(
{
"type": "object",
"properties": {
"name": {"type": "string"},
"count": {"type": "integer"},
"items": {"type": "array"},
"config": {"type": "object"},
},
}
)
result = tool.cast_params(
{
"name": None,
"count": None,
"items": None,
"config": None,
}
)
# None should be preserved for all types
assert result["name"] is None
assert result["count"] is None
assert result["items"] is None
assert result["config"] is None
def test_cast_params_single_value_not_auto_wrapped_to_array() -> None:
"""Single values should NOT be automatically wrapped into arrays."""
tool = CastTestTool(
{
"type": "object",
"properties": {"items": {"type": "array"}},
}
)
# Non-array values should be preserved (validation will catch them)
result = tool.cast_params({"items": 5})
assert result["items"] == 5 # Not wrapped to [5]
result = tool.cast_params({"items": "text"})
assert result["items"] == "text" # Not wrapped to ["text"]
# --- ExecTool enhancement tests ---
async def test_exec_always_returns_exit_code() -> None:
"""Exit code should appear in output even on success (exit 0)."""
tool = ExecTool()
result = await tool.execute(command="echo hello")
assert "Exit code: 0" in result
assert "hello" in result
async def test_exec_head_tail_truncation() -> None:
"""Long output should preserve both head and tail."""
tool = ExecTool()
# Generate output that exceeds _MAX_OUTPUT (10_000 chars)
# Use python to generate output to avoid command line length limits
result = await tool.execute(
command="python -c \"print('A' * 6000 + '\\n' + 'B' * 6000)\""
)
assert "chars truncated" in result
# Head portion should start with As
assert result.startswith("A")
# Tail portion should end with the exit code which comes after Bs
assert "Exit code:" in result
async def test_exec_timeout_parameter() -> None:
"""LLM-supplied timeout should override the constructor default."""
tool = ExecTool(timeout=60)
# A very short timeout should cause the command to be killed
result = await tool.execute(command="sleep 10", timeout=1)
assert "timed out" in result
assert "1 seconds" in result
async def test_exec_timeout_capped_at_max() -> None:
"""Timeout values above _MAX_TIMEOUT should be clamped."""
tool = ExecTool()
# Should not raise — just clamp to 600
result = await tool.execute(command="echo ok", timeout=9999)
assert "Exit code: 0" in result
# --- _resolve_type and nullable param tests ---
def test_resolve_type_simple_string() -> None:
"""Simple string type passes through unchanged."""
assert Tool._resolve_type("string") == "string"
def test_resolve_type_union_with_null() -> None:
"""Union type ['string', 'null'] resolves to 'string'."""
assert Tool._resolve_type(["string", "null"]) == "string"
def test_resolve_type_only_null() -> None:
"""Union type ['null'] resolves to None (no non-null type)."""
assert Tool._resolve_type(["null"]) is None
def test_resolve_type_none_input() -> None:
"""None input passes through as None."""
assert Tool._resolve_type(None) is None
def test_validate_nullable_param_accepts_string() -> None:
"""Nullable string param should accept a string value."""
tool = CastTestTool(
{
"type": "object",
"properties": {"name": {"type": ["string", "null"]}},
}
)
errors = tool.validate_params({"name": "hello"})
assert errors == []
def test_validate_nullable_param_accepts_none() -> None:
"""Nullable string param should accept None."""
tool = CastTestTool(
{
"type": "object",
"properties": {"name": {"type": ["string", "null"]}},
}
)
errors = tool.validate_params({"name": None})
assert errors == []
def test_validate_nullable_flag_accepts_none() -> None:
"""OpenAI-normalized nullable params should still accept None locally."""
tool = CastTestTool(
{
"type": "object",
"properties": {"name": {"type": "string", "nullable": True}},
}
)
errors = tool.validate_params({"name": None})
assert errors == []
def test_cast_nullable_param_no_crash() -> None:
"""cast_params should not crash on nullable type (the original bug)."""
tool = CastTestTool(
{
"type": "object",
"properties": {"name": {"type": ["string", "null"]}},
}
)
result = tool.cast_params({"name": "hello"})
assert result["name"] == "hello"
result = tool.cast_params({"name": None})
assert result["name"] is None
@@ -0,0 +1,113 @@
"""Tests for web_fetch SSRF protection and untrusted content marking."""
from __future__ import annotations
import json
import socket
from unittest.mock import patch
import pytest
from nanobot.agent.tools.web import WebFetchTool
def _fake_resolve_private(hostname, port, family=0, type_=0):
return [(socket.AF_INET, socket.SOCK_STREAM, 0, "", ("169.254.169.254", 0))]
def _fake_resolve_public(hostname, port, family=0, type_=0):
return [(socket.AF_INET, socket.SOCK_STREAM, 0, "", ("93.184.216.34", 0))]
@pytest.mark.asyncio
async def test_web_fetch_blocks_private_ip():
tool = WebFetchTool()
with patch("nanobot.security.network.socket.getaddrinfo", _fake_resolve_private):
result = await tool.execute(url="http://169.254.169.254/computeMetadata/v1/")
data = json.loads(result)
assert "error" in data
assert "private" in data["error"].lower() or "blocked" in data["error"].lower()
@pytest.mark.asyncio
async def test_web_fetch_blocks_localhost():
tool = WebFetchTool()
def _resolve_localhost(hostname, port, family=0, type_=0):
return [(socket.AF_INET, socket.SOCK_STREAM, 0, "", ("127.0.0.1", 0))]
with patch("nanobot.security.network.socket.getaddrinfo", _resolve_localhost):
result = await tool.execute(url="http://localhost/admin")
data = json.loads(result)
assert "error" in data
@pytest.mark.asyncio
async def test_web_fetch_result_contains_untrusted_flag():
"""When fetch succeeds, result JSON must include untrusted=True and the banner."""
tool = WebFetchTool()
fake_html = "<html><head><title>Test</title></head><body><p>Hello world</p></body></html>"
import httpx
class FakeResponse:
status_code = 200
url = "https://example.com/page"
text = fake_html
headers = {"content-type": "text/html"}
def raise_for_status(self): pass
def json(self): return {}
async def _fake_get(self, url, **kwargs):
return FakeResponse()
with patch("nanobot.security.network.socket.getaddrinfo", _fake_resolve_public), \
patch("httpx.AsyncClient.get", _fake_get):
result = await tool.execute(url="https://example.com/page")
data = json.loads(result)
assert data.get("untrusted") is True
assert "[External content" in data.get("text", "")
@pytest.mark.asyncio
async def test_web_fetch_blocks_private_redirect_before_returning_image(monkeypatch):
tool = WebFetchTool()
class FakeStreamResponse:
headers = {"content-type": "image/png"}
url = "http://127.0.0.1/secret.png"
content = b"\x89PNG\r\n\x1a\n"
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc, tb):
return False
async def aread(self):
return self.content
def raise_for_status(self):
return None
class FakeClient:
def __init__(self, *args, **kwargs):
pass
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc, tb):
return False
def stream(self, method, url, headers=None):
return FakeStreamResponse()
monkeypatch.setattr("nanobot.agent.tools.web.httpx.AsyncClient", FakeClient)
with patch("nanobot.security.network.socket.getaddrinfo", _fake_resolve_public):
result = await tool.execute(url="https://example.com/image.png")
data = json.loads(result)
assert "error" in data
assert "redirect blocked" in data["error"].lower()
@@ -0,0 +1,162 @@
"""Tests for multi-provider web search."""
import httpx
import pytest
from nanobot.agent.tools.web import WebSearchTool
from nanobot.config.schema import WebSearchConfig
def _tool(provider: str = "brave", api_key: str = "", base_url: str = "") -> WebSearchTool:
return WebSearchTool(config=WebSearchConfig(provider=provider, api_key=api_key, base_url=base_url))
def _response(status: int = 200, json: dict | None = None) -> httpx.Response:
"""Build a mock httpx.Response with a dummy request attached."""
r = httpx.Response(status, json=json)
r._request = httpx.Request("GET", "https://mock")
return r
@pytest.mark.asyncio
async def test_brave_search(monkeypatch):
async def mock_get(self, url, **kw):
assert "brave" in url
assert kw["headers"]["X-Subscription-Token"] == "brave-key"
return _response(json={
"web": {"results": [{"title": "NanoBot", "url": "https://example.com", "description": "AI assistant"}]}
})
monkeypatch.setattr(httpx.AsyncClient, "get", mock_get)
tool = _tool(provider="brave", api_key="brave-key")
result = await tool.execute(query="nanobot", count=1)
assert "NanoBot" in result
assert "https://example.com" in result
@pytest.mark.asyncio
async def test_tavily_search(monkeypatch):
async def mock_post(self, url, **kw):
assert "tavily" in url
assert kw["headers"]["Authorization"] == "Bearer tavily-key"
return _response(json={
"results": [{"title": "OpenClaw", "url": "https://openclaw.io", "content": "Framework"}]
})
monkeypatch.setattr(httpx.AsyncClient, "post", mock_post)
tool = _tool(provider="tavily", api_key="tavily-key")
result = await tool.execute(query="openclaw")
assert "OpenClaw" in result
assert "https://openclaw.io" in result
@pytest.mark.asyncio
async def test_searxng_search(monkeypatch):
async def mock_get(self, url, **kw):
assert "searx.example" in url
return _response(json={
"results": [{"title": "Result", "url": "https://example.com", "content": "SearXNG result"}]
})
monkeypatch.setattr(httpx.AsyncClient, "get", mock_get)
tool = _tool(provider="searxng", base_url="https://searx.example")
result = await tool.execute(query="test")
assert "Result" in result
@pytest.mark.asyncio
async def test_duckduckgo_search(monkeypatch):
class MockDDGS:
def __init__(self, **kw):
pass
def text(self, query, max_results=5):
return [{"title": "DDG Result", "href": "https://ddg.example", "body": "From DuckDuckGo"}]
monkeypatch.setattr("nanobot.agent.tools.web.DDGS", MockDDGS, raising=False)
import nanobot.agent.tools.web as web_mod
monkeypatch.setattr(web_mod, "DDGS", MockDDGS, raising=False)
from ddgs import DDGS
monkeypatch.setattr("ddgs.DDGS", MockDDGS)
tool = _tool(provider="duckduckgo")
result = await tool.execute(query="hello")
assert "DDG Result" in result
@pytest.mark.asyncio
async def test_brave_fallback_to_duckduckgo_when_no_key(monkeypatch):
class MockDDGS:
def __init__(self, **kw):
pass
def text(self, query, max_results=5):
return [{"title": "Fallback", "href": "https://ddg.example", "body": "DuckDuckGo fallback"}]
monkeypatch.setattr("ddgs.DDGS", MockDDGS)
monkeypatch.delenv("BRAVE_API_KEY", raising=False)
tool = _tool(provider="brave", api_key="")
result = await tool.execute(query="test")
assert "Fallback" in result
@pytest.mark.asyncio
async def test_jina_search(monkeypatch):
async def mock_get(self, url, **kw):
assert "s.jina.ai" in str(url)
assert kw["headers"]["Authorization"] == "Bearer jina-key"
return _response(json={
"data": [{"title": "Jina Result", "url": "https://jina.ai", "content": "AI search"}]
})
monkeypatch.setattr(httpx.AsyncClient, "get", mock_get)
tool = _tool(provider="jina", api_key="jina-key")
result = await tool.execute(query="test")
assert "Jina Result" in result
assert "https://jina.ai" in result
@pytest.mark.asyncio
async def test_unknown_provider():
tool = _tool(provider="unknown")
result = await tool.execute(query="test")
assert "unknown" in result
assert "Error" in result
@pytest.mark.asyncio
async def test_default_provider_is_brave(monkeypatch):
async def mock_get(self, url, **kw):
assert "brave" in url
return _response(json={"web": {"results": []}})
monkeypatch.setattr(httpx.AsyncClient, "get", mock_get)
tool = _tool(provider="", api_key="test-key")
result = await tool.execute(query="test")
assert "No results" in result
@pytest.mark.asyncio
async def test_searxng_no_base_url_falls_back(monkeypatch):
class MockDDGS:
def __init__(self, **kw):
pass
def text(self, query, max_results=5):
return [{"title": "Fallback", "href": "https://ddg.example", "body": "fallback"}]
monkeypatch.setattr("ddgs.DDGS", MockDDGS)
monkeypatch.delenv("SEARXNG_BASE_URL", raising=False)
tool = _tool(provider="searxng", base_url="")
result = await tool.execute(query="test")
assert "Fallback" in result
@pytest.mark.asyncio
async def test_searxng_invalid_url():
tool = _tool(provider="searxng", base_url="not-a-url")
result = await tool.execute(query="test")
assert "Error" in result