Update 2026-05-13 16:43:53
This commit is contained in:
@@ -0,0 +1,25 @@
|
||||
from types import SimpleNamespace
|
||||
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.channels.base import BaseChannel
|
||||
|
||||
|
||||
class _DummyChannel(BaseChannel):
|
||||
name = "dummy"
|
||||
|
||||
async def start(self) -> None:
|
||||
return None
|
||||
|
||||
async def stop(self) -> None:
|
||||
return None
|
||||
|
||||
async def send(self, msg: OutboundMessage) -> None:
|
||||
return None
|
||||
|
||||
|
||||
def test_is_allowed_requires_exact_match() -> None:
|
||||
channel = _DummyChannel(SimpleNamespace(allow_from=["allow@email.com"]), MessageBus())
|
||||
|
||||
assert channel.is_allowed("allow@email.com") is True
|
||||
assert channel.is_allowed("attacker|allow@email.com") is False
|
||||
@@ -0,0 +1,298 @@
|
||||
"""Tests for ChannelManager delta coalescing to reduce streaming latency."""
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.channels.base import BaseChannel
|
||||
from nanobot.channels.manager import ChannelManager
|
||||
from nanobot.config.schema import Config
|
||||
|
||||
|
||||
class MockChannel(BaseChannel):
|
||||
"""Mock channel for testing."""
|
||||
|
||||
name = "mock"
|
||||
display_name = "Mock"
|
||||
|
||||
def __init__(self, config, bus):
|
||||
super().__init__(config, bus)
|
||||
self._send_delta_mock = AsyncMock()
|
||||
self._send_mock = AsyncMock()
|
||||
|
||||
async def start(self):
|
||||
pass
|
||||
|
||||
async def stop(self):
|
||||
pass
|
||||
|
||||
async def send(self, msg):
|
||||
"""Implement abstract method."""
|
||||
return await self._send_mock(msg)
|
||||
|
||||
async def send_delta(self, chat_id, delta, metadata=None):
|
||||
"""Override send_delta for testing."""
|
||||
return await self._send_delta_mock(chat_id, delta, metadata)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def config():
|
||||
"""Create a minimal config for testing."""
|
||||
return Config()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def bus():
|
||||
"""Create a message bus for testing."""
|
||||
return MessageBus()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def manager(config, bus):
|
||||
"""Create a channel manager with a mock channel."""
|
||||
manager = ChannelManager(config, bus)
|
||||
manager.channels["mock"] = MockChannel({}, bus)
|
||||
return manager
|
||||
|
||||
|
||||
class TestDeltaCoalescing:
|
||||
"""Tests for _stream_delta message coalescing."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_single_delta_not_coalesced(self, manager, bus):
|
||||
"""A single delta should be sent as-is."""
|
||||
msg = OutboundMessage(
|
||||
channel="mock",
|
||||
chat_id="chat1",
|
||||
content="Hello",
|
||||
metadata={"_stream_delta": True},
|
||||
)
|
||||
await bus.publish_outbound(msg)
|
||||
|
||||
# Process one message
|
||||
async def process_one():
|
||||
try:
|
||||
m = await asyncio.wait_for(bus.consume_outbound(), timeout=0.1)
|
||||
if m.metadata.get("_stream_delta"):
|
||||
m, pending = manager._coalesce_stream_deltas(m)
|
||||
# Put pending back (none expected)
|
||||
for p in pending:
|
||||
await bus.publish_outbound(p)
|
||||
channel = manager.channels.get(m.channel)
|
||||
if channel:
|
||||
await channel.send_delta(m.chat_id, m.content, m.metadata)
|
||||
except asyncio.TimeoutError:
|
||||
pass
|
||||
|
||||
await process_one()
|
||||
|
||||
manager.channels["mock"]._send_delta_mock.assert_called_once_with(
|
||||
"chat1", "Hello", {"_stream_delta": True}
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_deltas_coalesced(self, manager, bus):
|
||||
"""Multiple consecutive deltas for same chat should be merged."""
|
||||
# Put multiple deltas in queue
|
||||
for text in ["Hello", " ", "world", "!"]:
|
||||
await bus.publish_outbound(OutboundMessage(
|
||||
channel="mock",
|
||||
chat_id="chat1",
|
||||
content=text,
|
||||
metadata={"_stream_delta": True},
|
||||
))
|
||||
|
||||
# Process using coalescing logic
|
||||
first_msg = await bus.consume_outbound()
|
||||
merged, pending = manager._coalesce_stream_deltas(first_msg)
|
||||
|
||||
# Should have merged all deltas
|
||||
assert merged.content == "Hello world!"
|
||||
assert merged.metadata.get("_stream_delta") is True
|
||||
# No pending messages (all were coalesced)
|
||||
assert len(pending) == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_deltas_different_chats_not_coalesced(self, manager, bus):
|
||||
"""Deltas for different chats should not be merged."""
|
||||
# Put deltas for different chats
|
||||
await bus.publish_outbound(OutboundMessage(
|
||||
channel="mock",
|
||||
chat_id="chat1",
|
||||
content="Hello",
|
||||
metadata={"_stream_delta": True},
|
||||
))
|
||||
await bus.publish_outbound(OutboundMessage(
|
||||
channel="mock",
|
||||
chat_id="chat2",
|
||||
content="World",
|
||||
metadata={"_stream_delta": True},
|
||||
))
|
||||
|
||||
first_msg = await bus.consume_outbound()
|
||||
merged, pending = manager._coalesce_stream_deltas(first_msg)
|
||||
|
||||
# First chat should not include second chat's content
|
||||
assert merged.content == "Hello"
|
||||
assert merged.chat_id == "chat1"
|
||||
# Second chat should be in pending
|
||||
assert len(pending) == 1
|
||||
assert pending[0].chat_id == "chat2"
|
||||
assert pending[0].content == "World"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_end_terminates_coalescing(self, manager, bus):
|
||||
"""_stream_end should stop coalescing and be included in final message."""
|
||||
# Put deltas with stream_end at the end
|
||||
await bus.publish_outbound(OutboundMessage(
|
||||
channel="mock",
|
||||
chat_id="chat1",
|
||||
content="Hello",
|
||||
metadata={"_stream_delta": True},
|
||||
))
|
||||
await bus.publish_outbound(OutboundMessage(
|
||||
channel="mock",
|
||||
chat_id="chat1",
|
||||
content=" world",
|
||||
metadata={"_stream_delta": True, "_stream_end": True},
|
||||
))
|
||||
|
||||
first_msg = await bus.consume_outbound()
|
||||
merged, pending = manager._coalesce_stream_deltas(first_msg)
|
||||
|
||||
# Should have merged content
|
||||
assert merged.content == "Hello world"
|
||||
# Should have stream_end flag
|
||||
assert merged.metadata.get("_stream_end") is True
|
||||
# No pending
|
||||
assert len(pending) == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_coalescing_stops_at_first_non_matching_boundary(self, manager, bus):
|
||||
"""Only consecutive deltas should be merged; later deltas stay queued."""
|
||||
await bus.publish_outbound(OutboundMessage(
|
||||
channel="mock",
|
||||
chat_id="chat1",
|
||||
content="Hello",
|
||||
metadata={"_stream_delta": True, "_stream_id": "seg-1"},
|
||||
))
|
||||
await bus.publish_outbound(OutboundMessage(
|
||||
channel="mock",
|
||||
chat_id="chat1",
|
||||
content="",
|
||||
metadata={"_stream_end": True, "_stream_id": "seg-1"},
|
||||
))
|
||||
await bus.publish_outbound(OutboundMessage(
|
||||
channel="mock",
|
||||
chat_id="chat1",
|
||||
content="world",
|
||||
metadata={"_stream_delta": True, "_stream_id": "seg-2"},
|
||||
))
|
||||
|
||||
first_msg = await bus.consume_outbound()
|
||||
merged, pending = manager._coalesce_stream_deltas(first_msg)
|
||||
|
||||
assert merged.content == "Hello"
|
||||
assert merged.metadata.get("_stream_end") is None
|
||||
assert len(pending) == 1
|
||||
assert pending[0].metadata.get("_stream_end") is True
|
||||
assert pending[0].metadata.get("_stream_id") == "seg-1"
|
||||
|
||||
# The next stream segment must remain in queue order for later dispatch.
|
||||
remaining = await bus.consume_outbound()
|
||||
assert remaining.content == "world"
|
||||
assert remaining.metadata.get("_stream_id") == "seg-2"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_delta_message_preserved(self, manager, bus):
|
||||
"""Non-delta messages should be preserved in pending list."""
|
||||
await bus.publish_outbound(OutboundMessage(
|
||||
channel="mock",
|
||||
chat_id="chat1",
|
||||
content="Delta",
|
||||
metadata={"_stream_delta": True},
|
||||
))
|
||||
await bus.publish_outbound(OutboundMessage(
|
||||
channel="mock",
|
||||
chat_id="chat1",
|
||||
content="Final message",
|
||||
metadata={}, # Not a delta
|
||||
))
|
||||
|
||||
first_msg = await bus.consume_outbound()
|
||||
merged, pending = manager._coalesce_stream_deltas(first_msg)
|
||||
|
||||
assert merged.content == "Delta"
|
||||
assert len(pending) == 1
|
||||
assert pending[0].content == "Final message"
|
||||
assert pending[0].metadata.get("_stream_delta") is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_queue_stops_coalescing(self, manager, bus):
|
||||
"""Coalescing should stop when queue is empty."""
|
||||
await bus.publish_outbound(OutboundMessage(
|
||||
channel="mock",
|
||||
chat_id="chat1",
|
||||
content="Only message",
|
||||
metadata={"_stream_delta": True},
|
||||
))
|
||||
|
||||
first_msg = await bus.consume_outbound()
|
||||
merged, pending = manager._coalesce_stream_deltas(first_msg)
|
||||
|
||||
assert merged.content == "Only message"
|
||||
assert len(pending) == 0
|
||||
|
||||
|
||||
class TestDispatchOutboundWithCoalescing:
|
||||
"""Tests for the full _dispatch_outbound flow with coalescing."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dispatch_coalesces_and_processes_pending(self, manager, bus):
|
||||
"""_dispatch_outbound should coalesce deltas and process pending messages."""
|
||||
# Put multiple deltas followed by a regular message
|
||||
await bus.publish_outbound(OutboundMessage(
|
||||
channel="mock",
|
||||
chat_id="chat1",
|
||||
content="A",
|
||||
metadata={"_stream_delta": True},
|
||||
))
|
||||
await bus.publish_outbound(OutboundMessage(
|
||||
channel="mock",
|
||||
chat_id="chat1",
|
||||
content="B",
|
||||
metadata={"_stream_delta": True},
|
||||
))
|
||||
await bus.publish_outbound(OutboundMessage(
|
||||
channel="mock",
|
||||
chat_id="chat1",
|
||||
content="Final",
|
||||
metadata={}, # Regular message
|
||||
))
|
||||
|
||||
# Run one iteration of dispatch logic manually
|
||||
pending = []
|
||||
processed = []
|
||||
|
||||
# First iteration: should coalesce A+B
|
||||
if pending:
|
||||
msg = pending.pop(0)
|
||||
else:
|
||||
msg = await bus.consume_outbound()
|
||||
|
||||
if msg.metadata.get("_stream_delta") and not msg.metadata.get("_stream_end"):
|
||||
msg, extra_pending = manager._coalesce_stream_deltas(msg)
|
||||
pending.extend(extra_pending)
|
||||
|
||||
channel = manager.channels.get(msg.channel)
|
||||
if channel:
|
||||
await channel.send_delta(msg.chat_id, msg.content, msg.metadata)
|
||||
processed.append(("delta", msg.content))
|
||||
|
||||
# Should have sent coalesced delta
|
||||
assert processed == [("delta", "AB")]
|
||||
# Should have pending regular message
|
||||
assert len(pending) == 1
|
||||
assert pending[0].content == "Final"
|
||||
@@ -0,0 +1,880 @@
|
||||
"""Tests for channel plugin discovery, merging, and config compatibility."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.channels.base import BaseChannel
|
||||
from nanobot.channels.manager import ChannelManager
|
||||
from nanobot.config.schema import ChannelsConfig
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class _FakePlugin(BaseChannel):
|
||||
name = "fakeplugin"
|
||||
display_name = "Fake Plugin"
|
||||
|
||||
def __init__(self, config, bus):
|
||||
super().__init__(config, bus)
|
||||
self.login_calls: list[bool] = []
|
||||
|
||||
async def start(self) -> None:
|
||||
pass
|
||||
|
||||
async def stop(self) -> None:
|
||||
pass
|
||||
|
||||
async def send(self, msg: OutboundMessage) -> None:
|
||||
pass
|
||||
|
||||
async def login(self, force: bool = False) -> bool:
|
||||
self.login_calls.append(force)
|
||||
return True
|
||||
|
||||
|
||||
class _FakeTelegram(BaseChannel):
|
||||
"""Plugin that tries to shadow built-in telegram."""
|
||||
name = "telegram"
|
||||
display_name = "Fake Telegram"
|
||||
|
||||
async def start(self) -> None:
|
||||
pass
|
||||
|
||||
async def stop(self) -> None:
|
||||
pass
|
||||
|
||||
async def send(self, msg: OutboundMessage) -> None:
|
||||
pass
|
||||
|
||||
|
||||
def _make_entry_point(name: str, cls: type):
|
||||
"""Create a mock entry point that returns *cls* on load()."""
|
||||
ep = SimpleNamespace(name=name, load=lambda _cls=cls: _cls)
|
||||
return ep
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ChannelsConfig extra="allow"
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_channels_config_accepts_unknown_keys():
|
||||
cfg = ChannelsConfig.model_validate({
|
||||
"myplugin": {"enabled": True, "token": "abc"},
|
||||
})
|
||||
extra = cfg.model_extra
|
||||
assert extra is not None
|
||||
assert extra["myplugin"]["enabled"] is True
|
||||
assert extra["myplugin"]["token"] == "abc"
|
||||
|
||||
|
||||
def test_channels_config_getattr_returns_extra():
|
||||
cfg = ChannelsConfig.model_validate({"myplugin": {"enabled": True}})
|
||||
section = getattr(cfg, "myplugin", None)
|
||||
assert isinstance(section, dict)
|
||||
assert section["enabled"] is True
|
||||
|
||||
|
||||
def test_channels_config_builtin_fields_removed():
|
||||
"""After decoupling, ChannelsConfig has no explicit channel fields."""
|
||||
cfg = ChannelsConfig()
|
||||
assert not hasattr(cfg, "telegram")
|
||||
assert cfg.send_progress is True
|
||||
assert cfg.send_tool_hints is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# discover_plugins
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_EP_TARGET = "importlib.metadata.entry_points"
|
||||
|
||||
|
||||
def test_discover_plugins_loads_entry_points():
|
||||
from nanobot.channels.registry import discover_plugins
|
||||
|
||||
ep = _make_entry_point("line", _FakePlugin)
|
||||
with patch(_EP_TARGET, return_value=[ep]):
|
||||
result = discover_plugins()
|
||||
|
||||
assert "line" in result
|
||||
assert result["line"] is _FakePlugin
|
||||
|
||||
|
||||
def test_discover_plugins_handles_load_error():
|
||||
from nanobot.channels.registry import discover_plugins
|
||||
|
||||
def _boom():
|
||||
raise RuntimeError("broken")
|
||||
|
||||
ep = SimpleNamespace(name="broken", load=_boom)
|
||||
with patch(_EP_TARGET, return_value=[ep]):
|
||||
result = discover_plugins()
|
||||
|
||||
assert "broken" not in result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# discover_all — merge & priority
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_discover_all_includes_builtins():
|
||||
from nanobot.channels.registry import discover_all, discover_channel_names
|
||||
|
||||
with patch(_EP_TARGET, return_value=[]):
|
||||
result = discover_all()
|
||||
|
||||
# discover_all() only returns channels that are actually available (dependencies installed)
|
||||
# discover_channel_names() returns all built-in channel names
|
||||
# So we check that all actually loaded channels are in the result
|
||||
for name in result:
|
||||
assert name in discover_channel_names()
|
||||
|
||||
|
||||
def test_discover_all_includes_external_plugin():
|
||||
from nanobot.channels.registry import discover_all
|
||||
|
||||
ep = _make_entry_point("line", _FakePlugin)
|
||||
with patch(_EP_TARGET, return_value=[ep]):
|
||||
result = discover_all()
|
||||
|
||||
assert "line" in result
|
||||
assert result["line"] is _FakePlugin
|
||||
|
||||
|
||||
def test_discover_all_builtin_shadows_plugin():
|
||||
from nanobot.channels.registry import discover_all
|
||||
|
||||
ep = _make_entry_point("telegram", _FakeTelegram)
|
||||
with patch(_EP_TARGET, return_value=[ep]):
|
||||
result = discover_all()
|
||||
|
||||
assert "telegram" in result
|
||||
assert result["telegram"] is not _FakeTelegram
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Manager _init_channels with dict config (plugin scenario)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_manager_loads_plugin_from_dict_config():
|
||||
"""ChannelManager should instantiate a plugin channel from a raw dict config."""
|
||||
from nanobot.channels.manager import ChannelManager
|
||||
|
||||
fake_config = SimpleNamespace(
|
||||
channels=ChannelsConfig.model_validate({
|
||||
"fakeplugin": {"enabled": True, "allowFrom": ["*"]},
|
||||
}),
|
||||
providers=SimpleNamespace(groq=SimpleNamespace(api_key="")),
|
||||
)
|
||||
|
||||
with patch(
|
||||
"nanobot.channels.registry.discover_all",
|
||||
return_value={"fakeplugin": _FakePlugin},
|
||||
):
|
||||
mgr = ChannelManager.__new__(ChannelManager)
|
||||
mgr.config = fake_config
|
||||
mgr.bus = MessageBus()
|
||||
mgr.channels = {}
|
||||
mgr._dispatch_task = None
|
||||
mgr._init_channels()
|
||||
|
||||
assert "fakeplugin" in mgr.channels
|
||||
assert isinstance(mgr.channels["fakeplugin"], _FakePlugin)
|
||||
|
||||
|
||||
def test_channels_login_uses_discovered_plugin_class(monkeypatch):
|
||||
from nanobot.cli.commands import app
|
||||
from nanobot.config.schema import Config
|
||||
from typer.testing import CliRunner
|
||||
|
||||
runner = CliRunner()
|
||||
seen: dict[str, object] = {}
|
||||
|
||||
class _LoginPlugin(_FakePlugin):
|
||||
display_name = "Login Plugin"
|
||||
|
||||
async def login(self, force: bool = False) -> bool:
|
||||
seen["force"] = force
|
||||
seen["config"] = self.config
|
||||
return True
|
||||
|
||||
monkeypatch.setattr("nanobot.config.loader.load_config", lambda: Config())
|
||||
monkeypatch.setattr(
|
||||
"nanobot.channels.registry.discover_all",
|
||||
lambda: {"fakeplugin": _LoginPlugin},
|
||||
)
|
||||
|
||||
result = runner.invoke(app, ["channels", "login", "fakeplugin", "--force"])
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert seen["force"] is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_manager_skips_disabled_plugin():
|
||||
fake_config = SimpleNamespace(
|
||||
channels=ChannelsConfig.model_validate({
|
||||
"fakeplugin": {"enabled": False},
|
||||
}),
|
||||
providers=SimpleNamespace(groq=SimpleNamespace(api_key="")),
|
||||
)
|
||||
|
||||
with patch(
|
||||
"nanobot.channels.registry.discover_all",
|
||||
return_value={"fakeplugin": _FakePlugin},
|
||||
):
|
||||
mgr = ChannelManager.__new__(ChannelManager)
|
||||
mgr.config = fake_config
|
||||
mgr.bus = MessageBus()
|
||||
mgr.channels = {}
|
||||
mgr._dispatch_task = None
|
||||
mgr._init_channels()
|
||||
|
||||
assert "fakeplugin" not in mgr.channels
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Built-in channel default_config() and dict->Pydantic conversion
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_builtin_channel_default_config():
|
||||
"""Built-in channels expose default_config() returning a dict with 'enabled': False."""
|
||||
from nanobot.channels.telegram import TelegramChannel
|
||||
cfg = TelegramChannel.default_config()
|
||||
assert isinstance(cfg, dict)
|
||||
assert cfg["enabled"] is False
|
||||
assert "token" in cfg
|
||||
|
||||
|
||||
def test_builtin_channel_init_from_dict():
|
||||
"""Built-in channels accept a raw dict and convert to Pydantic internally."""
|
||||
from nanobot.channels.telegram import TelegramChannel
|
||||
bus = MessageBus()
|
||||
ch = TelegramChannel({"enabled": False, "token": "test-tok", "allowFrom": ["*"]}, bus)
|
||||
assert ch.config.token == "test-tok"
|
||||
assert ch.config.allow_from == ["*"]
|
||||
|
||||
|
||||
def test_channels_config_send_max_retries_default():
|
||||
"""ChannelsConfig should have send_max_retries with default value of 3."""
|
||||
cfg = ChannelsConfig()
|
||||
assert hasattr(cfg, 'send_max_retries')
|
||||
assert cfg.send_max_retries == 3
|
||||
|
||||
|
||||
def test_channels_config_send_max_retries_upper_bound():
|
||||
"""send_max_retries should be bounded to prevent resource exhaustion."""
|
||||
from pydantic import ValidationError
|
||||
|
||||
# Value too high should be rejected
|
||||
with pytest.raises(ValidationError):
|
||||
ChannelsConfig(send_max_retries=100)
|
||||
|
||||
# Negative should be rejected
|
||||
with pytest.raises(ValidationError):
|
||||
ChannelsConfig(send_max_retries=-1)
|
||||
|
||||
# Boundary values should be allowed
|
||||
cfg_min = ChannelsConfig(send_max_retries=0)
|
||||
assert cfg_min.send_max_retries == 0
|
||||
|
||||
cfg_max = ChannelsConfig(send_max_retries=10)
|
||||
assert cfg_max.send_max_retries == 10
|
||||
|
||||
# Value above upper bound should be rejected
|
||||
with pytest.raises(ValidationError):
|
||||
ChannelsConfig(send_max_retries=11)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _send_with_retry
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_with_retry_succeeds_first_try():
|
||||
"""_send_with_retry should succeed on first try and not retry."""
|
||||
call_count = 0
|
||||
|
||||
class _FailingChannel(BaseChannel):
|
||||
name = "failing"
|
||||
display_name = "Failing"
|
||||
|
||||
async def start(self) -> None:
|
||||
pass
|
||||
|
||||
async def stop(self) -> None:
|
||||
pass
|
||||
|
||||
async def send(self, msg: OutboundMessage) -> None:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
# Succeeds on first try
|
||||
|
||||
fake_config = SimpleNamespace(
|
||||
channels=ChannelsConfig(send_max_retries=3),
|
||||
providers=SimpleNamespace(groq=SimpleNamespace(api_key="")),
|
||||
)
|
||||
|
||||
mgr = ChannelManager.__new__(ChannelManager)
|
||||
mgr.config = fake_config
|
||||
mgr.bus = MessageBus()
|
||||
mgr.channels = {"failing": _FailingChannel(fake_config, mgr.bus)}
|
||||
mgr._dispatch_task = None
|
||||
|
||||
msg = OutboundMessage(channel="failing", chat_id="123", content="test")
|
||||
await mgr._send_with_retry(mgr.channels["failing"], msg)
|
||||
|
||||
assert call_count == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_with_retry_retries_on_failure():
|
||||
"""_send_with_retry should retry on failure up to max_retries times."""
|
||||
call_count = 0
|
||||
|
||||
class _FailingChannel(BaseChannel):
|
||||
name = "failing"
|
||||
display_name = "Failing"
|
||||
|
||||
async def start(self) -> None:
|
||||
pass
|
||||
|
||||
async def stop(self) -> None:
|
||||
pass
|
||||
|
||||
async def send(self, msg: OutboundMessage) -> None:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
raise RuntimeError("simulated failure")
|
||||
|
||||
fake_config = SimpleNamespace(
|
||||
channels=ChannelsConfig(send_max_retries=3),
|
||||
providers=SimpleNamespace(groq=SimpleNamespace(api_key="")),
|
||||
)
|
||||
|
||||
mgr = ChannelManager.__new__(ChannelManager)
|
||||
mgr.config = fake_config
|
||||
mgr.bus = MessageBus()
|
||||
mgr.channels = {"failing": _FailingChannel(fake_config, mgr.bus)}
|
||||
mgr._dispatch_task = None
|
||||
|
||||
msg = OutboundMessage(channel="failing", chat_id="123", content="test")
|
||||
|
||||
# Patch asyncio.sleep to avoid actual delays
|
||||
with patch("nanobot.channels.manager.asyncio.sleep", new_callable=AsyncMock) as mock_sleep:
|
||||
await mgr._send_with_retry(mgr.channels["failing"], msg)
|
||||
|
||||
assert call_count == 3 # 3 total attempts (initial + 2 retries)
|
||||
assert mock_sleep.call_count == 2 # 2 sleeps between retries
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_with_retry_no_retry_when_max_is_zero():
|
||||
"""_send_with_retry should not retry when send_max_retries is 0."""
|
||||
call_count = 0
|
||||
|
||||
class _FailingChannel(BaseChannel):
|
||||
name = "failing"
|
||||
display_name = "Failing"
|
||||
|
||||
async def start(self) -> None:
|
||||
pass
|
||||
|
||||
async def stop(self) -> None:
|
||||
pass
|
||||
|
||||
async def send(self, msg: OutboundMessage) -> None:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
raise RuntimeError("simulated failure")
|
||||
|
||||
fake_config = SimpleNamespace(
|
||||
channels=ChannelsConfig(send_max_retries=0),
|
||||
providers=SimpleNamespace(groq=SimpleNamespace(api_key="")),
|
||||
)
|
||||
|
||||
mgr = ChannelManager.__new__(ChannelManager)
|
||||
mgr.config = fake_config
|
||||
mgr.bus = MessageBus()
|
||||
mgr.channels = {"failing": _FailingChannel(fake_config, mgr.bus)}
|
||||
mgr._dispatch_task = None
|
||||
|
||||
msg = OutboundMessage(channel="failing", chat_id="123", content="test")
|
||||
|
||||
with patch("nanobot.channels.manager.asyncio.sleep", new_callable=AsyncMock):
|
||||
await mgr._send_with_retry(mgr.channels["failing"], msg)
|
||||
|
||||
assert call_count == 1 # Called once but no retry (max(0, 1) = 1)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_with_retry_calls_send_delta():
|
||||
"""_send_with_retry should call send_delta when metadata has _stream_delta."""
|
||||
send_delta_called = False
|
||||
|
||||
class _StreamingChannel(BaseChannel):
|
||||
name = "streaming"
|
||||
display_name = "Streaming"
|
||||
|
||||
async def start(self) -> None:
|
||||
pass
|
||||
|
||||
async def stop(self) -> None:
|
||||
pass
|
||||
|
||||
async def send(self, msg: OutboundMessage) -> None:
|
||||
pass # Should not be called
|
||||
|
||||
async def send_delta(self, chat_id: str, delta: str, metadata: dict | None = None) -> None:
|
||||
nonlocal send_delta_called
|
||||
send_delta_called = True
|
||||
|
||||
fake_config = SimpleNamespace(
|
||||
channels=ChannelsConfig(send_max_retries=3),
|
||||
providers=SimpleNamespace(groq=SimpleNamespace(api_key="")),
|
||||
)
|
||||
|
||||
mgr = ChannelManager.__new__(ChannelManager)
|
||||
mgr.config = fake_config
|
||||
mgr.bus = MessageBus()
|
||||
mgr.channels = {"streaming": _StreamingChannel(fake_config, mgr.bus)}
|
||||
mgr._dispatch_task = None
|
||||
|
||||
msg = OutboundMessage(
|
||||
channel="streaming", chat_id="123", content="test delta",
|
||||
metadata={"_stream_delta": True}
|
||||
)
|
||||
await mgr._send_with_retry(mgr.channels["streaming"], msg)
|
||||
|
||||
assert send_delta_called is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_with_retry_skips_send_when_streamed():
|
||||
"""_send_with_retry should not call send when metadata has _streamed flag."""
|
||||
send_called = False
|
||||
send_delta_called = False
|
||||
|
||||
class _StreamedChannel(BaseChannel):
|
||||
name = "streamed"
|
||||
display_name = "Streamed"
|
||||
|
||||
async def start(self) -> None:
|
||||
pass
|
||||
|
||||
async def stop(self) -> None:
|
||||
pass
|
||||
|
||||
async def send(self, msg: OutboundMessage) -> None:
|
||||
nonlocal send_called
|
||||
send_called = True
|
||||
|
||||
async def send_delta(self, chat_id: str, delta: str, metadata: dict | None = None) -> None:
|
||||
nonlocal send_delta_called
|
||||
send_delta_called = True
|
||||
|
||||
fake_config = SimpleNamespace(
|
||||
channels=ChannelsConfig(send_max_retries=3),
|
||||
providers=SimpleNamespace(groq=SimpleNamespace(api_key="")),
|
||||
)
|
||||
|
||||
mgr = ChannelManager.__new__(ChannelManager)
|
||||
mgr.config = fake_config
|
||||
mgr.bus = MessageBus()
|
||||
mgr.channels = {"streamed": _StreamedChannel(fake_config, mgr.bus)}
|
||||
mgr._dispatch_task = None
|
||||
|
||||
# _streamed means message was already sent via send_delta, so skip send
|
||||
msg = OutboundMessage(
|
||||
channel="streamed", chat_id="123", content="test",
|
||||
metadata={"_streamed": True}
|
||||
)
|
||||
await mgr._send_with_retry(mgr.channels["streamed"], msg)
|
||||
|
||||
assert send_called is False
|
||||
assert send_delta_called is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_with_retry_propagates_cancelled_error():
|
||||
"""_send_with_retry should re-raise CancelledError for graceful shutdown."""
|
||||
class _CancellingChannel(BaseChannel):
|
||||
name = "cancelling"
|
||||
display_name = "Cancelling"
|
||||
|
||||
async def start(self) -> None:
|
||||
pass
|
||||
|
||||
async def stop(self) -> None:
|
||||
pass
|
||||
|
||||
async def send(self, msg: OutboundMessage) -> None:
|
||||
raise asyncio.CancelledError("simulated cancellation")
|
||||
|
||||
fake_config = SimpleNamespace(
|
||||
channels=ChannelsConfig(send_max_retries=3),
|
||||
providers=SimpleNamespace(groq=SimpleNamespace(api_key="")),
|
||||
)
|
||||
|
||||
mgr = ChannelManager.__new__(ChannelManager)
|
||||
mgr.config = fake_config
|
||||
mgr.bus = MessageBus()
|
||||
mgr.channels = {"cancelling": _CancellingChannel(fake_config, mgr.bus)}
|
||||
mgr._dispatch_task = None
|
||||
|
||||
msg = OutboundMessage(channel="cancelling", chat_id="123", content="test")
|
||||
|
||||
with pytest.raises(asyncio.CancelledError):
|
||||
await mgr._send_with_retry(mgr.channels["cancelling"], msg)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_with_retry_propagates_cancelled_error_during_sleep():
|
||||
"""_send_with_retry should re-raise CancelledError during sleep."""
|
||||
call_count = 0
|
||||
|
||||
class _FailingChannel(BaseChannel):
|
||||
name = "failing"
|
||||
display_name = "Failing"
|
||||
|
||||
async def start(self) -> None:
|
||||
pass
|
||||
|
||||
async def stop(self) -> None:
|
||||
pass
|
||||
|
||||
async def send(self, msg: OutboundMessage) -> None:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
raise RuntimeError("simulated failure")
|
||||
|
||||
fake_config = SimpleNamespace(
|
||||
channels=ChannelsConfig(send_max_retries=3),
|
||||
providers=SimpleNamespace(groq=SimpleNamespace(api_key="")),
|
||||
)
|
||||
|
||||
mgr = ChannelManager.__new__(ChannelManager)
|
||||
mgr.config = fake_config
|
||||
mgr.bus = MessageBus()
|
||||
mgr.channels = {"failing": _FailingChannel(fake_config, mgr.bus)}
|
||||
mgr._dispatch_task = None
|
||||
|
||||
msg = OutboundMessage(channel="failing", chat_id="123", content="test")
|
||||
|
||||
# Mock sleep to raise CancelledError
|
||||
async def cancel_during_sleep(_):
|
||||
raise asyncio.CancelledError("cancelled during sleep")
|
||||
|
||||
with patch("nanobot.channels.manager.asyncio.sleep", side_effect=cancel_during_sleep):
|
||||
with pytest.raises(asyncio.CancelledError):
|
||||
await mgr._send_with_retry(mgr.channels["failing"], msg)
|
||||
|
||||
# Should have attempted once before sleep was cancelled
|
||||
assert call_count == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ChannelManager - lifecycle and getters
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class _ChannelWithAllowFrom(BaseChannel):
|
||||
"""Channel with configurable allow_from."""
|
||||
name = "withallow"
|
||||
display_name = "With Allow"
|
||||
|
||||
def __init__(self, config, bus, allow_from):
|
||||
super().__init__(config, bus)
|
||||
self.config.allow_from = allow_from
|
||||
|
||||
async def start(self) -> None:
|
||||
pass
|
||||
|
||||
async def stop(self) -> None:
|
||||
pass
|
||||
|
||||
async def send(self, msg: OutboundMessage) -> None:
|
||||
pass
|
||||
|
||||
|
||||
class _StartableChannel(BaseChannel):
|
||||
"""Channel that tracks start/stop calls."""
|
||||
name = "startable"
|
||||
display_name = "Startable"
|
||||
|
||||
def __init__(self, config, bus):
|
||||
super().__init__(config, bus)
|
||||
self.started = False
|
||||
self.stopped = False
|
||||
|
||||
async def start(self) -> None:
|
||||
self.started = True
|
||||
|
||||
async def stop(self) -> None:
|
||||
self.stopped = True
|
||||
|
||||
async def send(self, msg: OutboundMessage) -> None:
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validate_allow_from_raises_on_empty_list():
|
||||
"""_validate_allow_from should raise SystemExit when allow_from is empty list."""
|
||||
fake_config = SimpleNamespace(
|
||||
channels=ChannelsConfig(),
|
||||
providers=SimpleNamespace(groq=SimpleNamespace(api_key="")),
|
||||
)
|
||||
|
||||
mgr = ChannelManager.__new__(ChannelManager)
|
||||
mgr.config = fake_config
|
||||
mgr.channels = {"test": _ChannelWithAllowFrom(fake_config, None, [])}
|
||||
mgr._dispatch_task = None
|
||||
|
||||
with pytest.raises(SystemExit) as exc_info:
|
||||
mgr._validate_allow_from()
|
||||
|
||||
assert "empty allowFrom" in str(exc_info.value)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validate_allow_from_passes_with_asterisk():
|
||||
"""_validate_allow_from should not raise when allow_from contains '*'."""
|
||||
fake_config = SimpleNamespace(
|
||||
channels=ChannelsConfig(),
|
||||
providers=SimpleNamespace(groq=SimpleNamespace(api_key="")),
|
||||
)
|
||||
|
||||
mgr = ChannelManager.__new__(ChannelManager)
|
||||
mgr.config = fake_config
|
||||
mgr.channels = {"test": _ChannelWithAllowFrom(fake_config, None, ["*"])}
|
||||
mgr._dispatch_task = None
|
||||
|
||||
# Should not raise
|
||||
mgr._validate_allow_from()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_channel_returns_channel_if_exists():
|
||||
"""get_channel should return the channel if it exists."""
|
||||
fake_config = SimpleNamespace(
|
||||
channels=ChannelsConfig(),
|
||||
providers=SimpleNamespace(groq=SimpleNamespace(api_key="")),
|
||||
)
|
||||
|
||||
mgr = ChannelManager.__new__(ChannelManager)
|
||||
mgr.config = fake_config
|
||||
mgr.bus = MessageBus()
|
||||
mgr.channels = {"telegram": _StartableChannel(fake_config, mgr.bus)}
|
||||
mgr._dispatch_task = None
|
||||
|
||||
assert mgr.get_channel("telegram") is not None
|
||||
assert mgr.get_channel("nonexistent") is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_status_returns_running_state():
|
||||
"""get_status should return enabled and running state for each channel."""
|
||||
fake_config = SimpleNamespace(
|
||||
channels=ChannelsConfig(),
|
||||
providers=SimpleNamespace(groq=SimpleNamespace(api_key="")),
|
||||
)
|
||||
|
||||
mgr = ChannelManager.__new__(ChannelManager)
|
||||
mgr.config = fake_config
|
||||
mgr.bus = MessageBus()
|
||||
ch = _StartableChannel(fake_config, mgr.bus)
|
||||
mgr.channels = {"startable": ch}
|
||||
mgr._dispatch_task = None
|
||||
|
||||
status = mgr.get_status()
|
||||
|
||||
assert status["startable"]["enabled"] is True
|
||||
assert status["startable"]["running"] is False # Not started yet
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_enabled_channels_returns_channel_names():
|
||||
"""enabled_channels should return list of enabled channel names."""
|
||||
fake_config = SimpleNamespace(
|
||||
channels=ChannelsConfig(),
|
||||
providers=SimpleNamespace(groq=SimpleNamespace(api_key="")),
|
||||
)
|
||||
|
||||
mgr = ChannelManager.__new__(ChannelManager)
|
||||
mgr.config = fake_config
|
||||
mgr.bus = MessageBus()
|
||||
mgr.channels = {
|
||||
"telegram": _StartableChannel(fake_config, mgr.bus),
|
||||
"slack": _StartableChannel(fake_config, mgr.bus),
|
||||
}
|
||||
mgr._dispatch_task = None
|
||||
|
||||
enabled = mgr.enabled_channels
|
||||
|
||||
assert "telegram" in enabled
|
||||
assert "slack" in enabled
|
||||
assert len(enabled) == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stop_all_cancels_dispatcher_and_stops_channels():
|
||||
"""stop_all should cancel the dispatch task and stop all channels."""
|
||||
fake_config = SimpleNamespace(
|
||||
channels=ChannelsConfig(),
|
||||
providers=SimpleNamespace(groq=SimpleNamespace(api_key="")),
|
||||
)
|
||||
|
||||
mgr = ChannelManager.__new__(ChannelManager)
|
||||
mgr.config = fake_config
|
||||
mgr.bus = MessageBus()
|
||||
|
||||
ch = _StartableChannel(fake_config, mgr.bus)
|
||||
mgr.channels = {"startable": ch}
|
||||
|
||||
# Create a real cancelled task
|
||||
async def dummy_task():
|
||||
while True:
|
||||
await asyncio.sleep(1)
|
||||
|
||||
dispatch_task = asyncio.create_task(dummy_task())
|
||||
mgr._dispatch_task = dispatch_task
|
||||
|
||||
await mgr.stop_all()
|
||||
|
||||
# Task should be cancelled
|
||||
assert dispatch_task.cancelled()
|
||||
# Channel should be stopped
|
||||
assert ch.stopped is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_channel_logs_error_on_failure():
|
||||
"""_start_channel should log error when channel start fails."""
|
||||
class _FailingChannel(BaseChannel):
|
||||
name = "failing"
|
||||
display_name = "Failing"
|
||||
|
||||
async def start(self) -> None:
|
||||
raise RuntimeError("connection failed")
|
||||
|
||||
async def stop(self) -> None:
|
||||
pass
|
||||
|
||||
async def send(self, msg: OutboundMessage) -> None:
|
||||
pass
|
||||
|
||||
fake_config = SimpleNamespace(
|
||||
channels=ChannelsConfig(),
|
||||
providers=SimpleNamespace(groq=SimpleNamespace(api_key="")),
|
||||
)
|
||||
|
||||
mgr = ChannelManager.__new__(ChannelManager)
|
||||
mgr.config = fake_config
|
||||
mgr.bus = MessageBus()
|
||||
mgr.channels = {}
|
||||
mgr._dispatch_task = None
|
||||
|
||||
ch = _FailingChannel(fake_config, mgr.bus)
|
||||
|
||||
# Should not raise, just log error
|
||||
await mgr._start_channel("failing", ch)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stop_all_handles_channel_exception():
|
||||
"""stop_all should handle exceptions when stopping channels gracefully."""
|
||||
class _StopFailingChannel(BaseChannel):
|
||||
name = "stopfailing"
|
||||
display_name = "Stop Failing"
|
||||
|
||||
async def start(self) -> None:
|
||||
pass
|
||||
|
||||
async def stop(self) -> None:
|
||||
raise RuntimeError("stop failed")
|
||||
|
||||
async def send(self, msg: OutboundMessage) -> None:
|
||||
pass
|
||||
|
||||
fake_config = SimpleNamespace(
|
||||
channels=ChannelsConfig(),
|
||||
providers=SimpleNamespace(groq=SimpleNamespace(api_key="")),
|
||||
)
|
||||
|
||||
mgr = ChannelManager.__new__(ChannelManager)
|
||||
mgr.config = fake_config
|
||||
mgr.bus = MessageBus()
|
||||
mgr.channels = {"stopfailing": _StopFailingChannel(fake_config, mgr.bus)}
|
||||
mgr._dispatch_task = None
|
||||
|
||||
# Should not raise even if channel.stop() raises
|
||||
await mgr.stop_all()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_all_no_channels_logs_warning():
|
||||
"""start_all should log warning when no channels are enabled."""
|
||||
fake_config = SimpleNamespace(
|
||||
channels=ChannelsConfig(),
|
||||
providers=SimpleNamespace(groq=SimpleNamespace(api_key="")),
|
||||
)
|
||||
|
||||
mgr = ChannelManager.__new__(ChannelManager)
|
||||
mgr.config = fake_config
|
||||
mgr.bus = MessageBus()
|
||||
mgr.channels = {} # No channels
|
||||
mgr._dispatch_task = None
|
||||
|
||||
# Should return early without creating dispatch task
|
||||
await mgr.start_all()
|
||||
|
||||
assert mgr._dispatch_task is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_all_creates_dispatch_task():
|
||||
"""start_all should create the dispatch task when channels exist."""
|
||||
fake_config = SimpleNamespace(
|
||||
channels=ChannelsConfig(),
|
||||
providers=SimpleNamespace(groq=SimpleNamespace(api_key="")),
|
||||
)
|
||||
|
||||
mgr = ChannelManager.__new__(ChannelManager)
|
||||
mgr.config = fake_config
|
||||
mgr.bus = MessageBus()
|
||||
|
||||
ch = _StartableChannel(fake_config, mgr.bus)
|
||||
mgr.channels = {"startable": ch}
|
||||
mgr._dispatch_task = None
|
||||
|
||||
# Cancel immediately after start to avoid running forever
|
||||
async def cancel_after_start():
|
||||
await asyncio.sleep(0.01)
|
||||
if mgr._dispatch_task:
|
||||
mgr._dispatch_task.cancel()
|
||||
|
||||
cancel_task = asyncio.create_task(cancel_after_start())
|
||||
|
||||
try:
|
||||
await mgr.start_all()
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
finally:
|
||||
cancel_task.cancel()
|
||||
try:
|
||||
await cancel_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
# Dispatch task should have been created
|
||||
assert mgr._dispatch_task is not None
|
||||
|
||||
@@ -0,0 +1,223 @@
|
||||
import asyncio
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
# Check optional dingtalk dependencies before running tests
|
||||
try:
|
||||
from nanobot.channels import dingtalk
|
||||
DINGTALK_AVAILABLE = getattr(dingtalk, "DINGTALK_AVAILABLE", False)
|
||||
except ImportError:
|
||||
DINGTALK_AVAILABLE = False
|
||||
|
||||
if not DINGTALK_AVAILABLE:
|
||||
pytest.skip("DingTalk dependencies not installed (dingtalk-stream)", allow_module_level=True)
|
||||
|
||||
from nanobot.bus.queue import MessageBus
|
||||
import nanobot.channels.dingtalk as dingtalk_module
|
||||
from nanobot.channels.dingtalk import DingTalkChannel, NanobotDingTalkHandler
|
||||
from nanobot.channels.dingtalk import DingTalkConfig
|
||||
|
||||
|
||||
class _FakeResponse:
|
||||
def __init__(self, status_code: int = 200, json_body: dict | None = None) -> None:
|
||||
self.status_code = status_code
|
||||
self._json_body = json_body or {}
|
||||
self.text = "{}"
|
||||
self.content = b""
|
||||
self.headers = {"content-type": "application/json"}
|
||||
|
||||
def json(self) -> dict:
|
||||
return self._json_body
|
||||
|
||||
|
||||
class _FakeHttp:
|
||||
def __init__(self, responses: list[_FakeResponse] | None = None) -> None:
|
||||
self.calls: list[dict] = []
|
||||
self._responses = list(responses) if responses else []
|
||||
|
||||
def _next_response(self) -> _FakeResponse:
|
||||
if self._responses:
|
||||
return self._responses.pop(0)
|
||||
return _FakeResponse()
|
||||
|
||||
async def post(self, url: str, json=None, headers=None, **kwargs):
|
||||
self.calls.append({"method": "POST", "url": url, "json": json, "headers": headers})
|
||||
return self._next_response()
|
||||
|
||||
async def get(self, url: str, **kwargs):
|
||||
self.calls.append({"method": "GET", "url": url})
|
||||
return self._next_response()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_group_message_keeps_sender_id_and_routes_chat_id() -> None:
|
||||
config = DingTalkConfig(client_id="app", client_secret="secret", allow_from=["user1"])
|
||||
bus = MessageBus()
|
||||
channel = DingTalkChannel(config, bus)
|
||||
|
||||
await channel._on_message(
|
||||
"hello",
|
||||
sender_id="user1",
|
||||
sender_name="Alice",
|
||||
conversation_type="2",
|
||||
conversation_id="conv123",
|
||||
)
|
||||
|
||||
msg = await bus.consume_inbound()
|
||||
assert msg.sender_id == "user1"
|
||||
assert msg.chat_id == "group:conv123"
|
||||
assert msg.metadata["conversation_type"] == "2"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_group_send_uses_group_messages_api() -> None:
|
||||
config = DingTalkConfig(client_id="app", client_secret="secret", allow_from=["*"])
|
||||
channel = DingTalkChannel(config, MessageBus())
|
||||
channel._http = _FakeHttp()
|
||||
|
||||
ok = await channel._send_batch_message(
|
||||
"token",
|
||||
"group:conv123",
|
||||
"sampleMarkdown",
|
||||
{"text": "hello", "title": "Nanobot Reply"},
|
||||
)
|
||||
|
||||
assert ok is True
|
||||
call = channel._http.calls[0]
|
||||
assert call["url"] == "https://api.dingtalk.com/v1.0/robot/groupMessages/send"
|
||||
assert call["json"]["openConversationId"] == "conv123"
|
||||
assert call["json"]["msgKey"] == "sampleMarkdown"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handler_uses_voice_recognition_text_when_text_is_empty(monkeypatch) -> None:
|
||||
bus = MessageBus()
|
||||
channel = DingTalkChannel(
|
||||
DingTalkConfig(client_id="app", client_secret="secret", allow_from=["user1"]),
|
||||
bus,
|
||||
)
|
||||
handler = NanobotDingTalkHandler(channel)
|
||||
|
||||
class _FakeChatbotMessage:
|
||||
text = None
|
||||
extensions = {"content": {"recognition": "voice transcript"}}
|
||||
sender_staff_id = "user1"
|
||||
sender_id = "fallback-user"
|
||||
sender_nick = "Alice"
|
||||
message_type = "audio"
|
||||
|
||||
@staticmethod
|
||||
def from_dict(_data):
|
||||
return _FakeChatbotMessage()
|
||||
|
||||
monkeypatch.setattr(dingtalk_module, "ChatbotMessage", _FakeChatbotMessage)
|
||||
monkeypatch.setattr(dingtalk_module, "AckMessage", SimpleNamespace(STATUS_OK="OK"))
|
||||
|
||||
status, body = await handler.process(
|
||||
SimpleNamespace(
|
||||
data={
|
||||
"conversationType": "2",
|
||||
"conversationId": "conv123",
|
||||
"text": {"content": ""},
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
await asyncio.gather(*list(channel._background_tasks))
|
||||
msg = await bus.consume_inbound()
|
||||
|
||||
assert (status, body) == ("OK", "OK")
|
||||
assert msg.content == "voice transcript"
|
||||
assert msg.sender_id == "user1"
|
||||
assert msg.chat_id == "group:conv123"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handler_processes_file_message(monkeypatch) -> None:
|
||||
"""Test that file messages are handled and forwarded with downloaded path."""
|
||||
bus = MessageBus()
|
||||
channel = DingTalkChannel(
|
||||
DingTalkConfig(client_id="app", client_secret="secret", allow_from=["user1"]),
|
||||
bus,
|
||||
)
|
||||
handler = NanobotDingTalkHandler(channel)
|
||||
|
||||
class _FakeFileChatbotMessage:
|
||||
text = None
|
||||
extensions = {}
|
||||
image_content = None
|
||||
rich_text_content = None
|
||||
sender_staff_id = "user1"
|
||||
sender_id = "fallback-user"
|
||||
sender_nick = "Alice"
|
||||
message_type = "file"
|
||||
|
||||
@staticmethod
|
||||
def from_dict(_data):
|
||||
return _FakeFileChatbotMessage()
|
||||
|
||||
async def fake_download(download_code, filename, sender_id):
|
||||
return f"/tmp/nanobot_dingtalk/{sender_id}/{filename}"
|
||||
|
||||
monkeypatch.setattr(dingtalk_module, "ChatbotMessage", _FakeFileChatbotMessage)
|
||||
monkeypatch.setattr(dingtalk_module, "AckMessage", SimpleNamespace(STATUS_OK="OK"))
|
||||
monkeypatch.setattr(channel, "_download_dingtalk_file", fake_download)
|
||||
|
||||
status, body = await handler.process(
|
||||
SimpleNamespace(
|
||||
data={
|
||||
"conversationType": "1",
|
||||
"content": {"downloadCode": "abc123", "fileName": "report.xlsx"},
|
||||
"text": {"content": ""},
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
await asyncio.gather(*list(channel._background_tasks))
|
||||
msg = await bus.consume_inbound()
|
||||
|
||||
assert (status, body) == ("OK", "OK")
|
||||
assert "[File]" in msg.content
|
||||
assert "/tmp/nanobot_dingtalk/user1/report.xlsx" in msg.content
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_download_dingtalk_file(tmp_path, monkeypatch) -> None:
|
||||
"""Test the two-step file download flow (get URL then download content)."""
|
||||
channel = DingTalkChannel(
|
||||
DingTalkConfig(client_id="app", client_secret="secret", allow_from=["*"]),
|
||||
MessageBus(),
|
||||
)
|
||||
|
||||
# Mock access token
|
||||
async def fake_get_token():
|
||||
return "test-token"
|
||||
|
||||
monkeypatch.setattr(channel, "_get_access_token", fake_get_token)
|
||||
|
||||
# Mock HTTP: first POST returns downloadUrl, then GET returns file bytes
|
||||
file_content = b"fake file content"
|
||||
channel._http = _FakeHttp(responses=[
|
||||
_FakeResponse(200, {"downloadUrl": "https://example.com/tmpfile"}),
|
||||
_FakeResponse(200),
|
||||
])
|
||||
channel._http._responses[1].content = file_content
|
||||
|
||||
# Redirect media dir to tmp_path
|
||||
monkeypatch.setattr(
|
||||
"nanobot.config.paths.get_media_dir",
|
||||
lambda channel_name=None: tmp_path / channel_name if channel_name else tmp_path,
|
||||
)
|
||||
|
||||
result = await channel._download_dingtalk_file("code123", "test.xlsx", "user1")
|
||||
|
||||
assert result is not None
|
||||
assert result.endswith("test.xlsx")
|
||||
assert (tmp_path / "dingtalk" / "user1" / "test.xlsx").read_bytes() == file_content
|
||||
|
||||
# Verify API calls
|
||||
assert channel._http.calls[0]["method"] == "POST"
|
||||
assert "messageFiles/download" in channel._http.calls[0]["url"]
|
||||
assert channel._http.calls[0]["json"]["downloadCode"] == "code123"
|
||||
assert channel._http.calls[1]["method"] == "GET"
|
||||
@@ -0,0 +1,652 @@
|
||||
from email.message import EmailMessage
|
||||
from datetime import date
|
||||
import imaplib
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.channels.email import EmailChannel
|
||||
from nanobot.channels.email import EmailConfig
|
||||
|
||||
|
||||
def _make_config(**overrides) -> EmailConfig:
|
||||
defaults = dict(
|
||||
enabled=True,
|
||||
consent_granted=True,
|
||||
imap_host="imap.example.com",
|
||||
imap_port=993,
|
||||
imap_username="bot@example.com",
|
||||
imap_password="secret",
|
||||
smtp_host="smtp.example.com",
|
||||
smtp_port=587,
|
||||
smtp_username="bot@example.com",
|
||||
smtp_password="secret",
|
||||
mark_seen=True,
|
||||
# Disable auth verification by default so existing tests are unaffected
|
||||
verify_dkim=False,
|
||||
verify_spf=False,
|
||||
)
|
||||
defaults.update(overrides)
|
||||
return EmailConfig(**defaults)
|
||||
|
||||
|
||||
def _make_raw_email(
|
||||
from_addr: str = "alice@example.com",
|
||||
subject: str = "Hello",
|
||||
body: str = "This is the body.",
|
||||
auth_results: str | None = None,
|
||||
) -> bytes:
|
||||
msg = EmailMessage()
|
||||
msg["From"] = from_addr
|
||||
msg["To"] = "bot@example.com"
|
||||
msg["Subject"] = subject
|
||||
msg["Message-ID"] = "<m1@example.com>"
|
||||
if auth_results:
|
||||
msg["Authentication-Results"] = auth_results
|
||||
msg.set_content(body)
|
||||
return msg.as_bytes()
|
||||
|
||||
|
||||
def test_fetch_new_messages_parses_unseen_and_marks_seen(monkeypatch) -> None:
|
||||
raw = _make_raw_email(subject="Invoice", body="Please pay")
|
||||
|
||||
class FakeIMAP:
|
||||
def __init__(self) -> None:
|
||||
self.store_calls: list[tuple[bytes, str, str]] = []
|
||||
|
||||
def login(self, _user: str, _pw: str):
|
||||
return "OK", [b"logged in"]
|
||||
|
||||
def select(self, _mailbox: str):
|
||||
return "OK", [b"1"]
|
||||
|
||||
def search(self, *_args):
|
||||
return "OK", [b"1"]
|
||||
|
||||
def fetch(self, _imap_id: bytes, _parts: str):
|
||||
return "OK", [(b"1 (UID 123 BODY[] {200})", raw), b")"]
|
||||
|
||||
def store(self, imap_id: bytes, op: str, flags: str):
|
||||
self.store_calls.append((imap_id, op, flags))
|
||||
return "OK", [b""]
|
||||
|
||||
def logout(self):
|
||||
return "BYE", [b""]
|
||||
|
||||
fake = FakeIMAP()
|
||||
monkeypatch.setattr("nanobot.channels.email.imaplib.IMAP4_SSL", lambda _h, _p: fake)
|
||||
|
||||
channel = EmailChannel(_make_config(), MessageBus())
|
||||
items = channel._fetch_new_messages()
|
||||
|
||||
assert len(items) == 1
|
||||
assert items[0]["sender"] == "alice@example.com"
|
||||
assert items[0]["subject"] == "Invoice"
|
||||
assert "Please pay" in items[0]["content"]
|
||||
assert fake.store_calls == [(b"1", "+FLAGS", "\\Seen")]
|
||||
|
||||
# Same UID should be deduped in-process.
|
||||
items_again = channel._fetch_new_messages()
|
||||
assert items_again == []
|
||||
|
||||
|
||||
def test_fetch_new_messages_retries_once_when_imap_connection_goes_stale(monkeypatch) -> None:
|
||||
raw = _make_raw_email(subject="Invoice", body="Please pay")
|
||||
fail_once = {"pending": True}
|
||||
|
||||
class FlakyIMAP:
|
||||
def __init__(self) -> None:
|
||||
self.store_calls: list[tuple[bytes, str, str]] = []
|
||||
self.search_calls = 0
|
||||
|
||||
def login(self, _user: str, _pw: str):
|
||||
return "OK", [b"logged in"]
|
||||
|
||||
def select(self, _mailbox: str):
|
||||
return "OK", [b"1"]
|
||||
|
||||
def search(self, *_args):
|
||||
self.search_calls += 1
|
||||
if fail_once["pending"]:
|
||||
fail_once["pending"] = False
|
||||
raise imaplib.IMAP4.abort("socket error")
|
||||
return "OK", [b"1"]
|
||||
|
||||
def fetch(self, _imap_id: bytes, _parts: str):
|
||||
return "OK", [(b"1 (UID 123 BODY[] {200})", raw), b")"]
|
||||
|
||||
def store(self, imap_id: bytes, op: str, flags: str):
|
||||
self.store_calls.append((imap_id, op, flags))
|
||||
return "OK", [b""]
|
||||
|
||||
def logout(self):
|
||||
return "BYE", [b""]
|
||||
|
||||
fake_instances: list[FlakyIMAP] = []
|
||||
|
||||
def _factory(_host: str, _port: int):
|
||||
instance = FlakyIMAP()
|
||||
fake_instances.append(instance)
|
||||
return instance
|
||||
|
||||
monkeypatch.setattr("nanobot.channels.email.imaplib.IMAP4_SSL", _factory)
|
||||
|
||||
channel = EmailChannel(_make_config(), MessageBus())
|
||||
items = channel._fetch_new_messages()
|
||||
|
||||
assert len(items) == 1
|
||||
assert len(fake_instances) == 2
|
||||
assert fake_instances[0].search_calls == 1
|
||||
assert fake_instances[1].search_calls == 1
|
||||
|
||||
|
||||
def test_fetch_new_messages_keeps_messages_collected_before_stale_retry(monkeypatch) -> None:
|
||||
raw_first = _make_raw_email(subject="First", body="First body")
|
||||
raw_second = _make_raw_email(subject="Second", body="Second body")
|
||||
mailbox_state = {
|
||||
b"1": {"uid": b"123", "raw": raw_first, "seen": False},
|
||||
b"2": {"uid": b"124", "raw": raw_second, "seen": False},
|
||||
}
|
||||
fail_once = {"pending": True}
|
||||
|
||||
class FlakyIMAP:
|
||||
def login(self, _user: str, _pw: str):
|
||||
return "OK", [b"logged in"]
|
||||
|
||||
def select(self, _mailbox: str):
|
||||
return "OK", [b"2"]
|
||||
|
||||
def search(self, *_args):
|
||||
unseen_ids = [imap_id for imap_id, item in mailbox_state.items() if not item["seen"]]
|
||||
return "OK", [b" ".join(unseen_ids)]
|
||||
|
||||
def fetch(self, imap_id: bytes, _parts: str):
|
||||
if imap_id == b"2" and fail_once["pending"]:
|
||||
fail_once["pending"] = False
|
||||
raise imaplib.IMAP4.abort("socket error")
|
||||
item = mailbox_state[imap_id]
|
||||
header = b"%s (UID %s BODY[] {200})" % (imap_id, item["uid"])
|
||||
return "OK", [(header, item["raw"]), b")"]
|
||||
|
||||
def store(self, imap_id: bytes, _op: str, _flags: str):
|
||||
mailbox_state[imap_id]["seen"] = True
|
||||
return "OK", [b""]
|
||||
|
||||
def logout(self):
|
||||
return "BYE", [b""]
|
||||
|
||||
monkeypatch.setattr("nanobot.channels.email.imaplib.IMAP4_SSL", lambda _h, _p: FlakyIMAP())
|
||||
|
||||
channel = EmailChannel(_make_config(), MessageBus())
|
||||
items = channel._fetch_new_messages()
|
||||
|
||||
assert [item["subject"] for item in items] == ["First", "Second"]
|
||||
|
||||
|
||||
def test_fetch_new_messages_skips_missing_mailbox(monkeypatch) -> None:
|
||||
class MissingMailboxIMAP:
|
||||
def login(self, _user: str, _pw: str):
|
||||
return "OK", [b"logged in"]
|
||||
|
||||
def select(self, _mailbox: str):
|
||||
raise imaplib.IMAP4.error("Mailbox doesn't exist")
|
||||
|
||||
def logout(self):
|
||||
return "BYE", [b""]
|
||||
|
||||
monkeypatch.setattr(
|
||||
"nanobot.channels.email.imaplib.IMAP4_SSL",
|
||||
lambda _h, _p: MissingMailboxIMAP(),
|
||||
)
|
||||
|
||||
channel = EmailChannel(_make_config(), MessageBus())
|
||||
|
||||
assert channel._fetch_new_messages() == []
|
||||
|
||||
|
||||
def test_extract_text_body_falls_back_to_html() -> None:
|
||||
msg = EmailMessage()
|
||||
msg["From"] = "alice@example.com"
|
||||
msg["To"] = "bot@example.com"
|
||||
msg["Subject"] = "HTML only"
|
||||
msg.add_alternative("<p>Hello<br>world</p>", subtype="html")
|
||||
|
||||
text = EmailChannel._extract_text_body(msg)
|
||||
assert "Hello" in text
|
||||
assert "world" in text
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_returns_immediately_without_consent(monkeypatch) -> None:
|
||||
cfg = _make_config()
|
||||
cfg.consent_granted = False
|
||||
channel = EmailChannel(cfg, MessageBus())
|
||||
|
||||
called = {"fetch": False}
|
||||
|
||||
def _fake_fetch():
|
||||
called["fetch"] = True
|
||||
return []
|
||||
|
||||
monkeypatch.setattr(channel, "_fetch_new_messages", _fake_fetch)
|
||||
await channel.start()
|
||||
assert channel.is_running is False
|
||||
assert called["fetch"] is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_uses_smtp_and_reply_subject(monkeypatch) -> None:
|
||||
class FakeSMTP:
|
||||
def __init__(self, _host: str, _port: int, timeout: int = 30) -> None:
|
||||
self.timeout = timeout
|
||||
self.started_tls = False
|
||||
self.logged_in = False
|
||||
self.sent_messages: list[EmailMessage] = []
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
def starttls(self, context=None):
|
||||
self.started_tls = True
|
||||
|
||||
def login(self, _user: str, _pw: str):
|
||||
self.logged_in = True
|
||||
|
||||
def send_message(self, msg: EmailMessage):
|
||||
self.sent_messages.append(msg)
|
||||
|
||||
fake_instances: list[FakeSMTP] = []
|
||||
|
||||
def _smtp_factory(host: str, port: int, timeout: int = 30):
|
||||
instance = FakeSMTP(host, port, timeout=timeout)
|
||||
fake_instances.append(instance)
|
||||
return instance
|
||||
|
||||
monkeypatch.setattr("nanobot.channels.email.smtplib.SMTP", _smtp_factory)
|
||||
|
||||
channel = EmailChannel(_make_config(), MessageBus())
|
||||
channel._last_subject_by_chat["alice@example.com"] = "Invoice #42"
|
||||
channel._last_message_id_by_chat["alice@example.com"] = "<m1@example.com>"
|
||||
|
||||
await channel.send(
|
||||
OutboundMessage(
|
||||
channel="email",
|
||||
chat_id="alice@example.com",
|
||||
content="Acknowledged.",
|
||||
)
|
||||
)
|
||||
|
||||
assert len(fake_instances) == 1
|
||||
smtp = fake_instances[0]
|
||||
assert smtp.started_tls is True
|
||||
assert smtp.logged_in is True
|
||||
assert len(smtp.sent_messages) == 1
|
||||
sent = smtp.sent_messages[0]
|
||||
assert sent["Subject"] == "Re: Invoice #42"
|
||||
assert sent["To"] == "alice@example.com"
|
||||
assert sent["In-Reply-To"] == "<m1@example.com>"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_skips_reply_when_auto_reply_disabled(monkeypatch) -> None:
|
||||
"""When auto_reply_enabled=False, replies should be skipped but proactive sends allowed."""
|
||||
class FakeSMTP:
|
||||
def __init__(self, _host: str, _port: int, timeout: int = 30) -> None:
|
||||
self.sent_messages: list[EmailMessage] = []
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
def starttls(self, context=None):
|
||||
return None
|
||||
|
||||
def login(self, _user: str, _pw: str):
|
||||
return None
|
||||
|
||||
def send_message(self, msg: EmailMessage):
|
||||
self.sent_messages.append(msg)
|
||||
|
||||
fake_instances: list[FakeSMTP] = []
|
||||
|
||||
def _smtp_factory(host: str, port: int, timeout: int = 30):
|
||||
instance = FakeSMTP(host, port, timeout=timeout)
|
||||
fake_instances.append(instance)
|
||||
return instance
|
||||
|
||||
monkeypatch.setattr("nanobot.channels.email.smtplib.SMTP", _smtp_factory)
|
||||
|
||||
cfg = _make_config()
|
||||
cfg.auto_reply_enabled = False
|
||||
channel = EmailChannel(cfg, MessageBus())
|
||||
|
||||
# Mark alice as someone who sent us an email (making this a "reply")
|
||||
channel._last_subject_by_chat["alice@example.com"] = "Previous email"
|
||||
|
||||
# Reply should be skipped (auto_reply_enabled=False)
|
||||
await channel.send(
|
||||
OutboundMessage(
|
||||
channel="email",
|
||||
chat_id="alice@example.com",
|
||||
content="Should not send.",
|
||||
)
|
||||
)
|
||||
assert fake_instances == []
|
||||
|
||||
# Reply with force_send=True should be sent
|
||||
await channel.send(
|
||||
OutboundMessage(
|
||||
channel="email",
|
||||
chat_id="alice@example.com",
|
||||
content="Force send.",
|
||||
metadata={"force_send": True},
|
||||
)
|
||||
)
|
||||
assert len(fake_instances) == 1
|
||||
assert len(fake_instances[0].sent_messages) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_proactive_email_when_auto_reply_disabled(monkeypatch) -> None:
|
||||
"""Proactive emails (not replies) should be sent even when auto_reply_enabled=False."""
|
||||
class FakeSMTP:
|
||||
def __init__(self, _host: str, _port: int, timeout: int = 30) -> None:
|
||||
self.sent_messages: list[EmailMessage] = []
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
def starttls(self, context=None):
|
||||
return None
|
||||
|
||||
def login(self, _user: str, _pw: str):
|
||||
return None
|
||||
|
||||
def send_message(self, msg: EmailMessage):
|
||||
self.sent_messages.append(msg)
|
||||
|
||||
fake_instances: list[FakeSMTP] = []
|
||||
|
||||
def _smtp_factory(host: str, port: int, timeout: int = 30):
|
||||
instance = FakeSMTP(host, port, timeout=timeout)
|
||||
fake_instances.append(instance)
|
||||
return instance
|
||||
|
||||
monkeypatch.setattr("nanobot.channels.email.smtplib.SMTP", _smtp_factory)
|
||||
|
||||
cfg = _make_config()
|
||||
cfg.auto_reply_enabled = False
|
||||
channel = EmailChannel(cfg, MessageBus())
|
||||
|
||||
# bob@example.com has never sent us an email (proactive send)
|
||||
# This should be sent even with auto_reply_enabled=False
|
||||
await channel.send(
|
||||
OutboundMessage(
|
||||
channel="email",
|
||||
chat_id="bob@example.com",
|
||||
content="Hello, this is a proactive email.",
|
||||
)
|
||||
)
|
||||
assert len(fake_instances) == 1
|
||||
assert len(fake_instances[0].sent_messages) == 1
|
||||
sent = fake_instances[0].sent_messages[0]
|
||||
assert sent["To"] == "bob@example.com"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_skips_when_consent_not_granted(monkeypatch) -> None:
|
||||
class FakeSMTP:
|
||||
def __init__(self, _host: str, _port: int, timeout: int = 30) -> None:
|
||||
self.sent_messages: list[EmailMessage] = []
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
def starttls(self, context=None):
|
||||
return None
|
||||
|
||||
def login(self, _user: str, _pw: str):
|
||||
return None
|
||||
|
||||
def send_message(self, msg: EmailMessage):
|
||||
self.sent_messages.append(msg)
|
||||
|
||||
called = {"smtp": False}
|
||||
|
||||
def _smtp_factory(host: str, port: int, timeout: int = 30):
|
||||
called["smtp"] = True
|
||||
return FakeSMTP(host, port, timeout=timeout)
|
||||
|
||||
monkeypatch.setattr("nanobot.channels.email.smtplib.SMTP", _smtp_factory)
|
||||
|
||||
cfg = _make_config()
|
||||
cfg.consent_granted = False
|
||||
channel = EmailChannel(cfg, MessageBus())
|
||||
await channel.send(
|
||||
OutboundMessage(
|
||||
channel="email",
|
||||
chat_id="alice@example.com",
|
||||
content="Should not send.",
|
||||
metadata={"force_send": True},
|
||||
)
|
||||
)
|
||||
assert called["smtp"] is False
|
||||
|
||||
|
||||
def test_fetch_messages_between_dates_uses_imap_since_before_without_mark_seen(monkeypatch) -> None:
|
||||
raw = _make_raw_email(subject="Status", body="Yesterday update")
|
||||
|
||||
class FakeIMAP:
|
||||
def __init__(self) -> None:
|
||||
self.search_args = None
|
||||
self.store_calls: list[tuple[bytes, str, str]] = []
|
||||
|
||||
def login(self, _user: str, _pw: str):
|
||||
return "OK", [b"logged in"]
|
||||
|
||||
def select(self, _mailbox: str):
|
||||
return "OK", [b"1"]
|
||||
|
||||
def search(self, *_args):
|
||||
self.search_args = _args
|
||||
return "OK", [b"5"]
|
||||
|
||||
def fetch(self, _imap_id: bytes, _parts: str):
|
||||
return "OK", [(b"5 (UID 999 BODY[] {200})", raw), b")"]
|
||||
|
||||
def store(self, imap_id: bytes, op: str, flags: str):
|
||||
self.store_calls.append((imap_id, op, flags))
|
||||
return "OK", [b""]
|
||||
|
||||
def logout(self):
|
||||
return "BYE", [b""]
|
||||
|
||||
fake = FakeIMAP()
|
||||
monkeypatch.setattr("nanobot.channels.email.imaplib.IMAP4_SSL", lambda _h, _p: fake)
|
||||
|
||||
channel = EmailChannel(_make_config(), MessageBus())
|
||||
items = channel.fetch_messages_between_dates(
|
||||
start_date=date(2026, 2, 6),
|
||||
end_date=date(2026, 2, 7),
|
||||
limit=10,
|
||||
)
|
||||
|
||||
assert len(items) == 1
|
||||
assert items[0]["subject"] == "Status"
|
||||
# search(None, "SINCE", "06-Feb-2026", "BEFORE", "07-Feb-2026")
|
||||
assert fake.search_args is not None
|
||||
assert fake.search_args[1:] == ("SINCE", "06-Feb-2026", "BEFORE", "07-Feb-2026")
|
||||
assert fake.store_calls == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Security: Anti-spoofing tests for Authentication-Results verification
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_fake_imap(raw: bytes):
|
||||
"""Return a FakeIMAP class pre-loaded with the given raw email."""
|
||||
class FakeIMAP:
|
||||
def __init__(self) -> None:
|
||||
self.store_calls: list[tuple[bytes, str, str]] = []
|
||||
|
||||
def login(self, _user: str, _pw: str):
|
||||
return "OK", [b"logged in"]
|
||||
|
||||
def select(self, _mailbox: str):
|
||||
return "OK", [b"1"]
|
||||
|
||||
def search(self, *_args):
|
||||
return "OK", [b"1"]
|
||||
|
||||
def fetch(self, _imap_id: bytes, _parts: str):
|
||||
return "OK", [(b"1 (UID 500 BODY[] {200})", raw), b")"]
|
||||
|
||||
def store(self, imap_id: bytes, op: str, flags: str):
|
||||
self.store_calls.append((imap_id, op, flags))
|
||||
return "OK", [b""]
|
||||
|
||||
def logout(self):
|
||||
return "BYE", [b""]
|
||||
|
||||
return FakeIMAP()
|
||||
|
||||
|
||||
def test_spoofed_email_rejected_when_verify_enabled(monkeypatch) -> None:
|
||||
"""An email without Authentication-Results should be rejected when verify_dkim=True."""
|
||||
raw = _make_raw_email(subject="Spoofed", body="Malicious payload")
|
||||
fake = _make_fake_imap(raw)
|
||||
monkeypatch.setattr("nanobot.channels.email.imaplib.IMAP4_SSL", lambda _h, _p: fake)
|
||||
|
||||
cfg = _make_config(verify_dkim=True, verify_spf=True)
|
||||
channel = EmailChannel(cfg, MessageBus())
|
||||
items = channel._fetch_new_messages()
|
||||
|
||||
assert len(items) == 0, "Spoofed email without auth headers should be rejected"
|
||||
|
||||
|
||||
def test_email_with_valid_auth_results_accepted(monkeypatch) -> None:
|
||||
"""An email with spf=pass and dkim=pass should be accepted."""
|
||||
raw = _make_raw_email(
|
||||
subject="Legit",
|
||||
body="Hello from verified sender",
|
||||
auth_results="mx.example.com; spf=pass smtp.mailfrom=alice@example.com; dkim=pass header.d=example.com",
|
||||
)
|
||||
fake = _make_fake_imap(raw)
|
||||
monkeypatch.setattr("nanobot.channels.email.imaplib.IMAP4_SSL", lambda _h, _p: fake)
|
||||
|
||||
cfg = _make_config(verify_dkim=True, verify_spf=True)
|
||||
channel = EmailChannel(cfg, MessageBus())
|
||||
items = channel._fetch_new_messages()
|
||||
|
||||
assert len(items) == 1
|
||||
assert items[0]["sender"] == "alice@example.com"
|
||||
assert items[0]["subject"] == "Legit"
|
||||
|
||||
|
||||
def test_email_with_partial_auth_rejected(monkeypatch) -> None:
|
||||
"""An email with only spf=pass but no dkim=pass should be rejected when verify_dkim=True."""
|
||||
raw = _make_raw_email(
|
||||
subject="Partial",
|
||||
body="Only SPF passes",
|
||||
auth_results="mx.example.com; spf=pass smtp.mailfrom=alice@example.com; dkim=fail",
|
||||
)
|
||||
fake = _make_fake_imap(raw)
|
||||
monkeypatch.setattr("nanobot.channels.email.imaplib.IMAP4_SSL", lambda _h, _p: fake)
|
||||
|
||||
cfg = _make_config(verify_dkim=True, verify_spf=True)
|
||||
channel = EmailChannel(cfg, MessageBus())
|
||||
items = channel._fetch_new_messages()
|
||||
|
||||
assert len(items) == 0, "Email with dkim=fail should be rejected"
|
||||
|
||||
|
||||
def test_backward_compat_verify_disabled(monkeypatch) -> None:
|
||||
"""When verify_dkim=False and verify_spf=False, emails without auth headers are accepted."""
|
||||
raw = _make_raw_email(subject="NoAuth", body="No auth headers present")
|
||||
fake = _make_fake_imap(raw)
|
||||
monkeypatch.setattr("nanobot.channels.email.imaplib.IMAP4_SSL", lambda _h, _p: fake)
|
||||
|
||||
cfg = _make_config(verify_dkim=False, verify_spf=False)
|
||||
channel = EmailChannel(cfg, MessageBus())
|
||||
items = channel._fetch_new_messages()
|
||||
|
||||
assert len(items) == 1, "With verification disabled, emails should be accepted as before"
|
||||
|
||||
|
||||
def test_email_content_tagged_with_email_context(monkeypatch) -> None:
|
||||
"""Email content should be prefixed with [EMAIL-CONTEXT] for LLM isolation."""
|
||||
raw = _make_raw_email(subject="Tagged", body="Check the tag")
|
||||
fake = _make_fake_imap(raw)
|
||||
monkeypatch.setattr("nanobot.channels.email.imaplib.IMAP4_SSL", lambda _h, _p: fake)
|
||||
|
||||
cfg = _make_config(verify_dkim=False, verify_spf=False)
|
||||
channel = EmailChannel(cfg, MessageBus())
|
||||
items = channel._fetch_new_messages()
|
||||
|
||||
assert len(items) == 1
|
||||
assert items[0]["content"].startswith("[EMAIL-CONTEXT]"), (
|
||||
"Email content must be tagged with [EMAIL-CONTEXT]"
|
||||
)
|
||||
|
||||
|
||||
def test_check_authentication_results_method() -> None:
|
||||
"""Unit test for the _check_authentication_results static method."""
|
||||
from email.parser import BytesParser
|
||||
from email import policy
|
||||
|
||||
# No Authentication-Results header
|
||||
msg_no_auth = EmailMessage()
|
||||
msg_no_auth["From"] = "alice@example.com"
|
||||
msg_no_auth.set_content("test")
|
||||
parsed = BytesParser(policy=policy.default).parsebytes(msg_no_auth.as_bytes())
|
||||
spf, dkim = EmailChannel._check_authentication_results(parsed)
|
||||
assert spf is False
|
||||
assert dkim is False
|
||||
|
||||
# Both pass
|
||||
msg_both = EmailMessage()
|
||||
msg_both["From"] = "alice@example.com"
|
||||
msg_both["Authentication-Results"] = (
|
||||
"mx.google.com; spf=pass smtp.mailfrom=example.com; dkim=pass header.d=example.com"
|
||||
)
|
||||
msg_both.set_content("test")
|
||||
parsed = BytesParser(policy=policy.default).parsebytes(msg_both.as_bytes())
|
||||
spf, dkim = EmailChannel._check_authentication_results(parsed)
|
||||
assert spf is True
|
||||
assert dkim is True
|
||||
|
||||
# SPF pass, DKIM fail
|
||||
msg_spf_only = EmailMessage()
|
||||
msg_spf_only["From"] = "alice@example.com"
|
||||
msg_spf_only["Authentication-Results"] = (
|
||||
"mx.google.com; spf=pass smtp.mailfrom=example.com; dkim=fail"
|
||||
)
|
||||
msg_spf_only.set_content("test")
|
||||
parsed = BytesParser(policy=policy.default).parsebytes(msg_spf_only.as_bytes())
|
||||
spf, dkim = EmailChannel._check_authentication_results(parsed)
|
||||
assert spf is True
|
||||
assert dkim is False
|
||||
|
||||
# DKIM pass, SPF fail
|
||||
msg_dkim_only = EmailMessage()
|
||||
msg_dkim_only["From"] = "alice@example.com"
|
||||
msg_dkim_only["Authentication-Results"] = (
|
||||
"mx.google.com; spf=fail smtp.mailfrom=example.com; dkim=pass header.d=example.com"
|
||||
)
|
||||
msg_dkim_only.set_content("test")
|
||||
parsed = BytesParser(policy=policy.default).parsebytes(msg_dkim_only.as_bytes())
|
||||
spf, dkim = EmailChannel._check_authentication_results(parsed)
|
||||
assert spf is False
|
||||
assert dkim is True
|
||||
@@ -0,0 +1,68 @@
|
||||
# Check optional Feishu dependencies before running tests
|
||||
try:
|
||||
from nanobot.channels import feishu
|
||||
FEISHU_AVAILABLE = getattr(feishu, "FEISHU_AVAILABLE", False)
|
||||
except ImportError:
|
||||
FEISHU_AVAILABLE = False
|
||||
|
||||
if not FEISHU_AVAILABLE:
|
||||
import pytest
|
||||
pytest.skip("Feishu dependencies not installed (lark-oapi)", allow_module_level=True)
|
||||
|
||||
from nanobot.channels.feishu import FeishuChannel
|
||||
|
||||
|
||||
def test_parse_md_table_strips_markdown_formatting_in_headers_and_cells() -> None:
|
||||
table = FeishuChannel._parse_md_table(
|
||||
"""
|
||||
| **Name** | __Status__ | *Notes* | ~~State~~ |
|
||||
| --- | --- | --- | --- |
|
||||
| **Alice** | __Ready__ | *Fast* | ~~Old~~ |
|
||||
"""
|
||||
)
|
||||
|
||||
assert table is not None
|
||||
assert [col["display_name"] for col in table["columns"]] == [
|
||||
"Name",
|
||||
"Status",
|
||||
"Notes",
|
||||
"State",
|
||||
]
|
||||
assert table["rows"] == [
|
||||
{"c0": "Alice", "c1": "Ready", "c2": "Fast", "c3": "Old"}
|
||||
]
|
||||
|
||||
|
||||
def test_split_headings_strips_embedded_markdown_before_bolding() -> None:
|
||||
channel = FeishuChannel.__new__(FeishuChannel)
|
||||
|
||||
elements = channel._split_headings("# **Important** *status* ~~update~~")
|
||||
|
||||
assert elements == [
|
||||
{
|
||||
"tag": "div",
|
||||
"text": {
|
||||
"tag": "lark_md",
|
||||
"content": "**Important status update**",
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
def test_split_headings_keeps_markdown_body_and_code_blocks_intact() -> None:
|
||||
channel = FeishuChannel.__new__(FeishuChannel)
|
||||
|
||||
elements = channel._split_headings(
|
||||
"# **Heading**\n\nBody with **bold** text.\n\n```python\nprint('hi')\n```"
|
||||
)
|
||||
|
||||
assert elements[0] == {
|
||||
"tag": "div",
|
||||
"text": {
|
||||
"tag": "lark_md",
|
||||
"content": "**Heading**",
|
||||
},
|
||||
}
|
||||
assert elements[1]["tag"] == "markdown"
|
||||
assert "Body with **bold** text." in elements[1]["content"]
|
||||
assert "```python\nprint('hi')\n```" in elements[1]["content"]
|
||||
@@ -0,0 +1,76 @@
|
||||
# Check optional Feishu dependencies before running tests
|
||||
try:
|
||||
from nanobot.channels import feishu
|
||||
FEISHU_AVAILABLE = getattr(feishu, "FEISHU_AVAILABLE", False)
|
||||
except ImportError:
|
||||
FEISHU_AVAILABLE = False
|
||||
|
||||
if not FEISHU_AVAILABLE:
|
||||
import pytest
|
||||
pytest.skip("Feishu dependencies not installed (lark-oapi)", allow_module_level=True)
|
||||
|
||||
from nanobot.channels.feishu import FeishuChannel, _extract_post_content
|
||||
|
||||
|
||||
def test_extract_post_content_supports_post_wrapper_shape() -> None:
|
||||
payload = {
|
||||
"post": {
|
||||
"zh_cn": {
|
||||
"title": "日报",
|
||||
"content": [
|
||||
[
|
||||
{"tag": "text", "text": "完成"},
|
||||
{"tag": "img", "image_key": "img_1"},
|
||||
]
|
||||
],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
text, image_keys = _extract_post_content(payload)
|
||||
|
||||
assert text == "日报 完成"
|
||||
assert image_keys == ["img_1"]
|
||||
|
||||
|
||||
def test_extract_post_content_keeps_direct_shape_behavior() -> None:
|
||||
payload = {
|
||||
"title": "Daily",
|
||||
"content": [
|
||||
[
|
||||
{"tag": "text", "text": "report"},
|
||||
{"tag": "img", "image_key": "img_a"},
|
||||
{"tag": "img", "image_key": "img_b"},
|
||||
]
|
||||
],
|
||||
}
|
||||
|
||||
text, image_keys = _extract_post_content(payload)
|
||||
|
||||
assert text == "Daily report"
|
||||
assert image_keys == ["img_a", "img_b"]
|
||||
|
||||
|
||||
def test_register_optional_event_keeps_builder_when_method_missing() -> None:
|
||||
class Builder:
|
||||
pass
|
||||
|
||||
builder = Builder()
|
||||
same = FeishuChannel._register_optional_event(builder, "missing", object())
|
||||
assert same is builder
|
||||
|
||||
|
||||
def test_register_optional_event_calls_supported_method() -> None:
|
||||
called = []
|
||||
|
||||
class Builder:
|
||||
def register_event(self, handler):
|
||||
called.append(handler)
|
||||
return self
|
||||
|
||||
builder = Builder()
|
||||
handler = object()
|
||||
same = FeishuChannel._register_optional_event(builder, "register_event", handler)
|
||||
|
||||
assert same is builder
|
||||
assert called == [handler]
|
||||
@@ -0,0 +1,445 @@
|
||||
"""Tests for Feishu message reply (quote) feature."""
|
||||
import asyncio
|
||||
import json
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
# Check optional Feishu dependencies before running tests
|
||||
try:
|
||||
from nanobot.channels import feishu
|
||||
FEISHU_AVAILABLE = getattr(feishu, "FEISHU_AVAILABLE", False)
|
||||
except ImportError:
|
||||
FEISHU_AVAILABLE = False
|
||||
|
||||
if not FEISHU_AVAILABLE:
|
||||
pytest.skip("Feishu dependencies not installed (lark-oapi)", allow_module_level=True)
|
||||
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.channels.feishu import FeishuChannel, FeishuConfig
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_feishu_channel(reply_to_message: bool = False) -> FeishuChannel:
|
||||
config = FeishuConfig(
|
||||
enabled=True,
|
||||
app_id="cli_test",
|
||||
app_secret="secret",
|
||||
allow_from=["*"],
|
||||
reply_to_message=reply_to_message,
|
||||
)
|
||||
channel = FeishuChannel(config, MessageBus())
|
||||
channel._client = MagicMock()
|
||||
# _loop is only used by the WebSocket thread bridge; not needed for unit tests
|
||||
channel._loop = None
|
||||
return channel
|
||||
|
||||
|
||||
def _make_feishu_event(
|
||||
*,
|
||||
message_id: str = "om_001",
|
||||
chat_id: str = "oc_abc",
|
||||
chat_type: str = "p2p",
|
||||
msg_type: str = "text",
|
||||
content: str = '{"text": "hello"}',
|
||||
sender_open_id: str = "ou_alice",
|
||||
parent_id: str | None = None,
|
||||
root_id: str | None = None,
|
||||
):
|
||||
message = SimpleNamespace(
|
||||
message_id=message_id,
|
||||
chat_id=chat_id,
|
||||
chat_type=chat_type,
|
||||
message_type=msg_type,
|
||||
content=content,
|
||||
parent_id=parent_id,
|
||||
root_id=root_id,
|
||||
mentions=[],
|
||||
)
|
||||
sender = SimpleNamespace(
|
||||
sender_type="user",
|
||||
sender_id=SimpleNamespace(open_id=sender_open_id),
|
||||
)
|
||||
return SimpleNamespace(event=SimpleNamespace(message=message, sender=sender))
|
||||
|
||||
|
||||
def _make_get_message_response(text: str, msg_type: str = "text", success: bool = True):
|
||||
"""Build a fake im.v1.message.get response object."""
|
||||
body = SimpleNamespace(content=json.dumps({"text": text}))
|
||||
item = SimpleNamespace(msg_type=msg_type, body=body)
|
||||
data = SimpleNamespace(items=[item])
|
||||
resp = MagicMock()
|
||||
resp.success.return_value = success
|
||||
resp.data = data
|
||||
resp.code = 0
|
||||
resp.msg = "ok"
|
||||
return resp
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Config tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_feishu_config_reply_to_message_defaults_false() -> None:
|
||||
assert FeishuConfig().reply_to_message is False
|
||||
|
||||
|
||||
def test_feishu_config_reply_to_message_can_be_enabled() -> None:
|
||||
config = FeishuConfig(reply_to_message=True)
|
||||
assert config.reply_to_message is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _get_message_content_sync tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_get_message_content_sync_returns_reply_prefix() -> None:
|
||||
channel = _make_feishu_channel()
|
||||
channel._client.im.v1.message.get.return_value = _make_get_message_response("what time is it?")
|
||||
|
||||
result = channel._get_message_content_sync("om_parent")
|
||||
|
||||
assert result == "[Reply to: what time is it?]"
|
||||
|
||||
|
||||
def test_get_message_content_sync_truncates_long_text() -> None:
|
||||
channel = _make_feishu_channel()
|
||||
long_text = "x" * (FeishuChannel._REPLY_CONTEXT_MAX_LEN + 50)
|
||||
channel._client.im.v1.message.get.return_value = _make_get_message_response(long_text)
|
||||
|
||||
result = channel._get_message_content_sync("om_parent")
|
||||
|
||||
assert result is not None
|
||||
assert result.endswith("...]")
|
||||
inner = result[len("[Reply to: ") : -1]
|
||||
assert len(inner) == FeishuChannel._REPLY_CONTEXT_MAX_LEN + len("...")
|
||||
|
||||
|
||||
def test_get_message_content_sync_returns_none_on_api_failure() -> None:
|
||||
channel = _make_feishu_channel()
|
||||
resp = MagicMock()
|
||||
resp.success.return_value = False
|
||||
resp.code = 230002
|
||||
resp.msg = "bot not in group"
|
||||
channel._client.im.v1.message.get.return_value = resp
|
||||
|
||||
result = channel._get_message_content_sync("om_parent")
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_get_message_content_sync_returns_none_for_non_text_type() -> None:
|
||||
channel = _make_feishu_channel()
|
||||
body = SimpleNamespace(content=json.dumps({"image_key": "img_1"}))
|
||||
item = SimpleNamespace(msg_type="image", body=body)
|
||||
data = SimpleNamespace(items=[item])
|
||||
resp = MagicMock()
|
||||
resp.success.return_value = True
|
||||
resp.data = data
|
||||
channel._client.im.v1.message.get.return_value = resp
|
||||
|
||||
result = channel._get_message_content_sync("om_parent")
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_get_message_content_sync_returns_none_when_empty_text() -> None:
|
||||
channel = _make_feishu_channel()
|
||||
channel._client.im.v1.message.get.return_value = _make_get_message_response(" ")
|
||||
|
||||
result = channel._get_message_content_sync("om_parent")
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _reply_message_sync tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_reply_message_sync_returns_true_on_success() -> None:
|
||||
channel = _make_feishu_channel()
|
||||
resp = MagicMock()
|
||||
resp.success.return_value = True
|
||||
channel._client.im.v1.message.reply.return_value = resp
|
||||
|
||||
ok = channel._reply_message_sync("om_parent", "text", '{"text":"hi"}')
|
||||
|
||||
assert ok is True
|
||||
channel._client.im.v1.message.reply.assert_called_once()
|
||||
|
||||
|
||||
def test_reply_message_sync_returns_false_on_api_error() -> None:
|
||||
channel = _make_feishu_channel()
|
||||
resp = MagicMock()
|
||||
resp.success.return_value = False
|
||||
resp.code = 400
|
||||
resp.msg = "bad request"
|
||||
resp.get_log_id.return_value = "log_x"
|
||||
channel._client.im.v1.message.reply.return_value = resp
|
||||
|
||||
ok = channel._reply_message_sync("om_parent", "text", '{"text":"hi"}')
|
||||
|
||||
assert ok is False
|
||||
|
||||
|
||||
def test_reply_message_sync_returns_false_on_exception() -> None:
|
||||
channel = _make_feishu_channel()
|
||||
channel._client.im.v1.message.reply.side_effect = RuntimeError("network error")
|
||||
|
||||
ok = channel._reply_message_sync("om_parent", "text", '{"text":"hi"}')
|
||||
|
||||
assert ok is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
("filename", "expected_msg_type"),
|
||||
[
|
||||
("voice.opus", "audio"),
|
||||
("clip.mp4", "video"),
|
||||
("report.pdf", "file"),
|
||||
],
|
||||
)
|
||||
async def test_send_uses_expected_feishu_msg_type_for_uploaded_files(
|
||||
tmp_path: Path, filename: str, expected_msg_type: str
|
||||
) -> None:
|
||||
channel = _make_feishu_channel()
|
||||
file_path = tmp_path / filename
|
||||
file_path.write_bytes(b"demo")
|
||||
|
||||
send_calls: list[tuple[str, str, str, str]] = []
|
||||
|
||||
def _record_send(receive_id_type: str, receive_id: str, msg_type: str, content: str) -> None:
|
||||
send_calls.append((receive_id_type, receive_id, msg_type, content))
|
||||
|
||||
with patch.object(channel, "_upload_file_sync", return_value="file-key"), patch.object(
|
||||
channel, "_send_message_sync", side_effect=_record_send
|
||||
):
|
||||
await channel.send(
|
||||
OutboundMessage(
|
||||
channel="feishu",
|
||||
chat_id="oc_test",
|
||||
content="",
|
||||
media=[str(file_path)],
|
||||
metadata={},
|
||||
)
|
||||
)
|
||||
|
||||
assert len(send_calls) == 1
|
||||
receive_id_type, receive_id, msg_type, content = send_calls[0]
|
||||
assert receive_id_type == "chat_id"
|
||||
assert receive_id == "oc_test"
|
||||
assert msg_type == expected_msg_type
|
||||
assert json.loads(content) == {"file_key": "file-key"}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# send() — reply routing tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_uses_reply_api_when_configured() -> None:
|
||||
channel = _make_feishu_channel(reply_to_message=True)
|
||||
|
||||
reply_resp = MagicMock()
|
||||
reply_resp.success.return_value = True
|
||||
channel._client.im.v1.message.reply.return_value = reply_resp
|
||||
|
||||
await channel.send(OutboundMessage(
|
||||
channel="feishu",
|
||||
chat_id="oc_abc",
|
||||
content="hello",
|
||||
metadata={"message_id": "om_001"},
|
||||
))
|
||||
|
||||
channel._client.im.v1.message.reply.assert_called_once()
|
||||
channel._client.im.v1.message.create.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_uses_create_api_when_reply_disabled() -> None:
|
||||
channel = _make_feishu_channel(reply_to_message=False)
|
||||
|
||||
create_resp = MagicMock()
|
||||
create_resp.success.return_value = True
|
||||
channel._client.im.v1.message.create.return_value = create_resp
|
||||
|
||||
await channel.send(OutboundMessage(
|
||||
channel="feishu",
|
||||
chat_id="oc_abc",
|
||||
content="hello",
|
||||
metadata={"message_id": "om_001"},
|
||||
))
|
||||
|
||||
channel._client.im.v1.message.create.assert_called_once()
|
||||
channel._client.im.v1.message.reply.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_uses_create_api_when_no_message_id() -> None:
|
||||
channel = _make_feishu_channel(reply_to_message=True)
|
||||
|
||||
create_resp = MagicMock()
|
||||
create_resp.success.return_value = True
|
||||
channel._client.im.v1.message.create.return_value = create_resp
|
||||
|
||||
await channel.send(OutboundMessage(
|
||||
channel="feishu",
|
||||
chat_id="oc_abc",
|
||||
content="hello",
|
||||
metadata={},
|
||||
))
|
||||
|
||||
channel._client.im.v1.message.create.assert_called_once()
|
||||
channel._client.im.v1.message.reply.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_skips_reply_for_progress_messages() -> None:
|
||||
channel = _make_feishu_channel(reply_to_message=True)
|
||||
|
||||
create_resp = MagicMock()
|
||||
create_resp.success.return_value = True
|
||||
channel._client.im.v1.message.create.return_value = create_resp
|
||||
|
||||
await channel.send(OutboundMessage(
|
||||
channel="feishu",
|
||||
chat_id="oc_abc",
|
||||
content="thinking...",
|
||||
metadata={"message_id": "om_001", "_progress": True},
|
||||
))
|
||||
|
||||
channel._client.im.v1.message.create.assert_called_once()
|
||||
channel._client.im.v1.message.reply.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_fallback_to_create_when_reply_fails() -> None:
|
||||
channel = _make_feishu_channel(reply_to_message=True)
|
||||
|
||||
reply_resp = MagicMock()
|
||||
reply_resp.success.return_value = False
|
||||
reply_resp.code = 400
|
||||
reply_resp.msg = "error"
|
||||
reply_resp.get_log_id.return_value = "log_x"
|
||||
channel._client.im.v1.message.reply.return_value = reply_resp
|
||||
|
||||
create_resp = MagicMock()
|
||||
create_resp.success.return_value = True
|
||||
channel._client.im.v1.message.create.return_value = create_resp
|
||||
|
||||
await channel.send(OutboundMessage(
|
||||
channel="feishu",
|
||||
chat_id="oc_abc",
|
||||
content="hello",
|
||||
metadata={"message_id": "om_001"},
|
||||
))
|
||||
|
||||
# reply attempted first, then falls back to create
|
||||
channel._client.im.v1.message.reply.assert_called_once()
|
||||
channel._client.im.v1.message.create.assert_called_once()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _on_message — parent_id / root_id metadata tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_message_captures_parent_and_root_id_in_metadata() -> None:
|
||||
channel = _make_feishu_channel()
|
||||
channel._processed_message_ids.clear()
|
||||
channel._client.im.v1.message.react.return_value = MagicMock(success=lambda: True)
|
||||
|
||||
captured = []
|
||||
|
||||
async def _capture(**kwargs):
|
||||
captured.append(kwargs)
|
||||
|
||||
channel._handle_message = _capture
|
||||
|
||||
with patch.object(channel, "_add_reaction", return_value=None):
|
||||
await channel._on_message(
|
||||
_make_feishu_event(
|
||||
parent_id="om_parent",
|
||||
root_id="om_root",
|
||||
)
|
||||
)
|
||||
|
||||
assert len(captured) == 1
|
||||
meta = captured[0]["metadata"]
|
||||
assert meta["parent_id"] == "om_parent"
|
||||
assert meta["root_id"] == "om_root"
|
||||
assert meta["message_id"] == "om_001"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_message_parent_and_root_id_none_when_absent() -> None:
|
||||
channel = _make_feishu_channel()
|
||||
channel._processed_message_ids.clear()
|
||||
|
||||
captured = []
|
||||
|
||||
async def _capture(**kwargs):
|
||||
captured.append(kwargs)
|
||||
|
||||
channel._handle_message = _capture
|
||||
|
||||
with patch.object(channel, "_add_reaction", return_value=None):
|
||||
await channel._on_message(_make_feishu_event())
|
||||
|
||||
assert len(captured) == 1
|
||||
meta = captured[0]["metadata"]
|
||||
assert meta["parent_id"] is None
|
||||
assert meta["root_id"] is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_message_prepends_reply_context_when_parent_id_present() -> None:
|
||||
channel = _make_feishu_channel()
|
||||
channel._processed_message_ids.clear()
|
||||
channel._client.im.v1.message.get.return_value = _make_get_message_response("original question")
|
||||
|
||||
captured = []
|
||||
|
||||
async def _capture(**kwargs):
|
||||
captured.append(kwargs)
|
||||
|
||||
channel._handle_message = _capture
|
||||
|
||||
with patch.object(channel, "_add_reaction", return_value=None):
|
||||
await channel._on_message(
|
||||
_make_feishu_event(
|
||||
content='{"text": "my answer"}',
|
||||
parent_id="om_parent",
|
||||
)
|
||||
)
|
||||
|
||||
assert len(captured) == 1
|
||||
content = captured[0]["content"]
|
||||
assert content.startswith("[Reply to: original question]")
|
||||
assert "my answer" in content
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_message_no_extra_api_call_when_no_parent_id() -> None:
|
||||
channel = _make_feishu_channel()
|
||||
channel._processed_message_ids.clear()
|
||||
|
||||
captured = []
|
||||
|
||||
async def _capture(**kwargs):
|
||||
captured.append(kwargs)
|
||||
|
||||
channel._handle_message = _capture
|
||||
|
||||
with patch.object(channel, "_add_reaction", return_value=None):
|
||||
await channel._on_message(_make_feishu_event())
|
||||
|
||||
channel._client.im.v1.message.get.assert_not_called()
|
||||
assert len(captured) == 1
|
||||
@@ -0,0 +1,258 @@
|
||||
"""Tests for Feishu streaming (send_delta) via CardKit streaming API."""
|
||||
import time
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.channels.feishu import FeishuChannel, FeishuConfig, _FeishuStreamBuf
|
||||
|
||||
|
||||
def _make_channel(streaming: bool = True) -> FeishuChannel:
|
||||
config = FeishuConfig(
|
||||
enabled=True,
|
||||
app_id="cli_test",
|
||||
app_secret="secret",
|
||||
allow_from=["*"],
|
||||
streaming=streaming,
|
||||
)
|
||||
ch = FeishuChannel(config, MessageBus())
|
||||
ch._client = MagicMock()
|
||||
ch._loop = None
|
||||
return ch
|
||||
|
||||
|
||||
def _mock_create_card_response(card_id: str = "card_stream_001"):
|
||||
resp = MagicMock()
|
||||
resp.success.return_value = True
|
||||
resp.data = SimpleNamespace(card_id=card_id)
|
||||
return resp
|
||||
|
||||
|
||||
def _mock_send_response(message_id: str = "om_stream_001"):
|
||||
resp = MagicMock()
|
||||
resp.success.return_value = True
|
||||
resp.data = SimpleNamespace(message_id=message_id)
|
||||
return resp
|
||||
|
||||
|
||||
def _mock_content_response(success: bool = True):
|
||||
resp = MagicMock()
|
||||
resp.success.return_value = success
|
||||
resp.code = 0 if success else 99999
|
||||
resp.msg = "ok" if success else "error"
|
||||
return resp
|
||||
|
||||
|
||||
class TestFeishuStreamingConfig:
|
||||
def test_streaming_default_true(self):
|
||||
assert FeishuConfig().streaming is True
|
||||
|
||||
def test_supports_streaming_when_enabled(self):
|
||||
ch = _make_channel(streaming=True)
|
||||
assert ch.supports_streaming is True
|
||||
|
||||
def test_supports_streaming_disabled(self):
|
||||
ch = _make_channel(streaming=False)
|
||||
assert ch.supports_streaming is False
|
||||
|
||||
|
||||
class TestCreateStreamingCard:
|
||||
def test_returns_card_id_on_success(self):
|
||||
ch = _make_channel()
|
||||
ch._client.cardkit.v1.card.create.return_value = _mock_create_card_response("card_123")
|
||||
ch._client.im.v1.message.create.return_value = _mock_send_response()
|
||||
result = ch._create_streaming_card_sync("chat_id", "oc_chat1")
|
||||
assert result == "card_123"
|
||||
ch._client.cardkit.v1.card.create.assert_called_once()
|
||||
ch._client.im.v1.message.create.assert_called_once()
|
||||
|
||||
def test_returns_none_on_failure(self):
|
||||
ch = _make_channel()
|
||||
resp = MagicMock()
|
||||
resp.success.return_value = False
|
||||
resp.code = 99999
|
||||
resp.msg = "error"
|
||||
ch._client.cardkit.v1.card.create.return_value = resp
|
||||
assert ch._create_streaming_card_sync("chat_id", "oc_chat1") is None
|
||||
|
||||
def test_returns_none_on_exception(self):
|
||||
ch = _make_channel()
|
||||
ch._client.cardkit.v1.card.create.side_effect = RuntimeError("network")
|
||||
assert ch._create_streaming_card_sync("chat_id", "oc_chat1") is None
|
||||
|
||||
def test_returns_none_when_card_send_fails(self):
|
||||
ch = _make_channel()
|
||||
ch._client.cardkit.v1.card.create.return_value = _mock_create_card_response("card_123")
|
||||
resp = MagicMock()
|
||||
resp.success.return_value = False
|
||||
resp.code = 99999
|
||||
resp.msg = "error"
|
||||
resp.get_log_id.return_value = "log1"
|
||||
ch._client.im.v1.message.create.return_value = resp
|
||||
assert ch._create_streaming_card_sync("chat_id", "oc_chat1") is None
|
||||
|
||||
|
||||
class TestCloseStreamingMode:
|
||||
def test_returns_true_on_success(self):
|
||||
ch = _make_channel()
|
||||
ch._client.cardkit.v1.card.settings.return_value = _mock_content_response(True)
|
||||
assert ch._close_streaming_mode_sync("card_1", 10) is True
|
||||
|
||||
def test_returns_false_on_failure(self):
|
||||
ch = _make_channel()
|
||||
ch._client.cardkit.v1.card.settings.return_value = _mock_content_response(False)
|
||||
assert ch._close_streaming_mode_sync("card_1", 10) is False
|
||||
|
||||
def test_returns_false_on_exception(self):
|
||||
ch = _make_channel()
|
||||
ch._client.cardkit.v1.card.settings.side_effect = RuntimeError("err")
|
||||
assert ch._close_streaming_mode_sync("card_1", 10) is False
|
||||
|
||||
|
||||
class TestStreamUpdateText:
|
||||
def test_returns_true_on_success(self):
|
||||
ch = _make_channel()
|
||||
ch._client.cardkit.v1.card_element.content.return_value = _mock_content_response(True)
|
||||
assert ch._stream_update_text_sync("card_1", "hello", 1) is True
|
||||
|
||||
def test_returns_false_on_failure(self):
|
||||
ch = _make_channel()
|
||||
ch._client.cardkit.v1.card_element.content.return_value = _mock_content_response(False)
|
||||
assert ch._stream_update_text_sync("card_1", "hello", 1) is False
|
||||
|
||||
def test_returns_false_on_exception(self):
|
||||
ch = _make_channel()
|
||||
ch._client.cardkit.v1.card_element.content.side_effect = RuntimeError("err")
|
||||
assert ch._stream_update_text_sync("card_1", "hello", 1) is False
|
||||
|
||||
|
||||
class TestSendDelta:
|
||||
@pytest.mark.asyncio
|
||||
async def test_first_delta_creates_card_and_sends(self):
|
||||
ch = _make_channel()
|
||||
ch._client.cardkit.v1.card.create.return_value = _mock_create_card_response("card_new")
|
||||
ch._client.im.v1.message.create.return_value = _mock_send_response("om_new")
|
||||
ch._client.cardkit.v1.card_element.content.return_value = _mock_content_response()
|
||||
|
||||
await ch.send_delta("oc_chat1", "Hello ")
|
||||
|
||||
assert "oc_chat1" in ch._stream_bufs
|
||||
buf = ch._stream_bufs["oc_chat1"]
|
||||
assert buf.text == "Hello "
|
||||
assert buf.card_id == "card_new"
|
||||
assert buf.sequence == 1
|
||||
ch._client.cardkit.v1.card.create.assert_called_once()
|
||||
ch._client.im.v1.message.create.assert_called_once()
|
||||
ch._client.cardkit.v1.card_element.content.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_second_delta_within_interval_skips_update(self):
|
||||
ch = _make_channel()
|
||||
buf = _FeishuStreamBuf(text="Hello ", card_id="card_1", sequence=1, last_edit=time.monotonic())
|
||||
ch._stream_bufs["oc_chat1"] = buf
|
||||
|
||||
await ch.send_delta("oc_chat1", "world")
|
||||
|
||||
assert buf.text == "Hello world"
|
||||
ch._client.cardkit.v1.card_element.content.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delta_after_interval_updates_text(self):
|
||||
ch = _make_channel()
|
||||
buf = _FeishuStreamBuf(text="Hello ", card_id="card_1", sequence=1, last_edit=time.monotonic() - 1.0)
|
||||
ch._stream_bufs["oc_chat1"] = buf
|
||||
|
||||
ch._client.cardkit.v1.card_element.content.return_value = _mock_content_response()
|
||||
await ch.send_delta("oc_chat1", "world")
|
||||
|
||||
assert buf.text == "Hello world"
|
||||
assert buf.sequence == 2
|
||||
ch._client.cardkit.v1.card_element.content.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_end_sends_final_update(self):
|
||||
ch = _make_channel()
|
||||
ch._stream_bufs["oc_chat1"] = _FeishuStreamBuf(
|
||||
text="Final content", card_id="card_1", sequence=3, last_edit=0.0,
|
||||
)
|
||||
ch._client.cardkit.v1.card_element.content.return_value = _mock_content_response()
|
||||
ch._client.cardkit.v1.card.settings.return_value = _mock_content_response()
|
||||
|
||||
await ch.send_delta("oc_chat1", "", metadata={"_stream_end": True})
|
||||
|
||||
assert "oc_chat1" not in ch._stream_bufs
|
||||
ch._client.cardkit.v1.card_element.content.assert_called_once()
|
||||
ch._client.cardkit.v1.card.settings.assert_called_once()
|
||||
settings_call = ch._client.cardkit.v1.card.settings.call_args[0][0]
|
||||
assert settings_call.body.sequence == 5 # after final content seq 4
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_end_fallback_when_no_card_id(self):
|
||||
"""If card creation failed, stream_end falls back to a plain card message."""
|
||||
ch = _make_channel()
|
||||
ch._stream_bufs["oc_chat1"] = _FeishuStreamBuf(
|
||||
text="Fallback content", card_id=None, sequence=0, last_edit=0.0,
|
||||
)
|
||||
ch._client.im.v1.message.create.return_value = _mock_send_response("om_fb")
|
||||
|
||||
await ch.send_delta("oc_chat1", "", metadata={"_stream_end": True})
|
||||
|
||||
assert "oc_chat1" not in ch._stream_bufs
|
||||
ch._client.cardkit.v1.card_element.content.assert_not_called()
|
||||
ch._client.im.v1.message.create.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_end_without_buf_is_noop(self):
|
||||
ch = _make_channel()
|
||||
await ch.send_delta("oc_chat1", "", metadata={"_stream_end": True})
|
||||
ch._client.cardkit.v1.card_element.content.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_delta_skips_send(self):
|
||||
ch = _make_channel()
|
||||
await ch.send_delta("oc_chat1", " ")
|
||||
|
||||
assert "oc_chat1" in ch._stream_bufs
|
||||
ch._client.cardkit.v1.card.create.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_client_returns_early(self):
|
||||
ch = _make_channel()
|
||||
ch._client = None
|
||||
await ch.send_delta("oc_chat1", "text")
|
||||
assert "oc_chat1" not in ch._stream_bufs
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sequence_increments_correctly(self):
|
||||
ch = _make_channel()
|
||||
buf = _FeishuStreamBuf(text="a", card_id="card_1", sequence=5, last_edit=0.0)
|
||||
ch._stream_bufs["oc_chat1"] = buf
|
||||
|
||||
ch._client.cardkit.v1.card_element.content.return_value = _mock_content_response()
|
||||
await ch.send_delta("oc_chat1", "b")
|
||||
assert buf.sequence == 6
|
||||
|
||||
buf.last_edit = 0.0 # reset to bypass throttle
|
||||
await ch.send_delta("oc_chat1", "c")
|
||||
assert buf.sequence == 7
|
||||
|
||||
|
||||
class TestSendMessageReturnsId:
|
||||
def test_returns_message_id_on_success(self):
|
||||
ch = _make_channel()
|
||||
ch._client.im.v1.message.create.return_value = _mock_send_response("om_abc")
|
||||
result = ch._send_message_sync("chat_id", "oc_chat1", "text", '{"text":"hi"}')
|
||||
assert result == "om_abc"
|
||||
|
||||
def test_returns_none_on_failure(self):
|
||||
ch = _make_channel()
|
||||
resp = MagicMock()
|
||||
resp.success.return_value = False
|
||||
resp.code = 99999
|
||||
resp.msg = "error"
|
||||
resp.get_log_id.return_value = "log1"
|
||||
ch._client.im.v1.message.create.return_value = resp
|
||||
result = ch._send_message_sync("chat_id", "oc_chat1", "text", '{"text":"hi"}')
|
||||
assert result is None
|
||||
@@ -0,0 +1,115 @@
|
||||
"""Tests for FeishuChannel._split_elements_by_table_limit.
|
||||
|
||||
Feishu cards reject messages that contain more than one table element
|
||||
(API error 11310: card table number over limit). The helper splits a flat
|
||||
list of card elements into groups so that each group contains at most one
|
||||
table, allowing nanobot to send multiple cards instead of failing.
|
||||
"""
|
||||
|
||||
# Check optional Feishu dependencies before running tests
|
||||
try:
|
||||
from nanobot.channels import feishu
|
||||
FEISHU_AVAILABLE = getattr(feishu, "FEISHU_AVAILABLE", False)
|
||||
except ImportError:
|
||||
FEISHU_AVAILABLE = False
|
||||
|
||||
if not FEISHU_AVAILABLE:
|
||||
import pytest
|
||||
pytest.skip("Feishu dependencies not installed (lark-oapi)", allow_module_level=True)
|
||||
|
||||
from nanobot.channels.feishu import FeishuChannel
|
||||
|
||||
|
||||
def _md(text: str) -> dict:
|
||||
return {"tag": "markdown", "content": text}
|
||||
|
||||
|
||||
def _table() -> dict:
|
||||
return {
|
||||
"tag": "table",
|
||||
"columns": [{"tag": "column", "name": "c0", "display_name": "A", "width": "auto"}],
|
||||
"rows": [{"c0": "v"}],
|
||||
"page_size": 2,
|
||||
}
|
||||
|
||||
|
||||
split = FeishuChannel._split_elements_by_table_limit
|
||||
|
||||
|
||||
def test_empty_list_returns_single_empty_group() -> None:
|
||||
assert split([]) == [[]]
|
||||
|
||||
|
||||
def test_no_tables_returns_single_group() -> None:
|
||||
els = [_md("hello"), _md("world")]
|
||||
result = split(els)
|
||||
assert result == [els]
|
||||
|
||||
|
||||
def test_single_table_stays_in_one_group() -> None:
|
||||
els = [_md("intro"), _table(), _md("outro")]
|
||||
result = split(els)
|
||||
assert len(result) == 1
|
||||
assert result[0] == els
|
||||
|
||||
|
||||
def test_two_tables_split_into_two_groups() -> None:
|
||||
# Use different row values so the two tables are not equal
|
||||
t1 = {
|
||||
"tag": "table",
|
||||
"columns": [{"tag": "column", "name": "c0", "display_name": "A", "width": "auto"}],
|
||||
"rows": [{"c0": "table-one"}],
|
||||
"page_size": 2,
|
||||
}
|
||||
t2 = {
|
||||
"tag": "table",
|
||||
"columns": [{"tag": "column", "name": "c0", "display_name": "B", "width": "auto"}],
|
||||
"rows": [{"c0": "table-two"}],
|
||||
"page_size": 2,
|
||||
}
|
||||
els = [_md("before"), t1, _md("between"), t2, _md("after")]
|
||||
result = split(els)
|
||||
assert len(result) == 2
|
||||
# First group: text before table-1 + table-1
|
||||
assert t1 in result[0]
|
||||
assert t2 not in result[0]
|
||||
# Second group: text between tables + table-2 + text after
|
||||
assert t2 in result[1]
|
||||
assert t1 not in result[1]
|
||||
|
||||
|
||||
def test_three_tables_split_into_three_groups() -> None:
|
||||
tables = [
|
||||
{"tag": "table", "columns": [], "rows": [{"c0": f"t{i}"}], "page_size": 1}
|
||||
for i in range(3)
|
||||
]
|
||||
els = tables[:]
|
||||
result = split(els)
|
||||
assert len(result) == 3
|
||||
for i, group in enumerate(result):
|
||||
assert tables[i] in group
|
||||
|
||||
|
||||
def test_leading_markdown_stays_with_first_table() -> None:
|
||||
intro = _md("intro")
|
||||
t = _table()
|
||||
result = split([intro, t])
|
||||
assert len(result) == 1
|
||||
assert result[0] == [intro, t]
|
||||
|
||||
|
||||
def test_trailing_markdown_after_second_table() -> None:
|
||||
t1, t2 = _table(), _table()
|
||||
tail = _md("end")
|
||||
result = split([t1, t2, tail])
|
||||
assert len(result) == 2
|
||||
assert result[1] == [t2, tail]
|
||||
|
||||
|
||||
def test_non_table_elements_before_first_table_kept_in_first_group() -> None:
|
||||
head = _md("head")
|
||||
t1, t2 = _table(), _table()
|
||||
result = split([head, t1, t2])
|
||||
# head + t1 in group 0; t2 in group 1
|
||||
assert result[0] == [head, t1]
|
||||
assert result[1] == [t2]
|
||||
@@ -0,0 +1,148 @@
|
||||
"""Tests for FeishuChannel tool hint code block formatting."""
|
||||
|
||||
import json
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from pytest import mark
|
||||
|
||||
# Check optional Feishu dependencies before running tests
|
||||
try:
|
||||
from nanobot.channels import feishu
|
||||
FEISHU_AVAILABLE = getattr(feishu, "FEISHU_AVAILABLE", False)
|
||||
except ImportError:
|
||||
FEISHU_AVAILABLE = False
|
||||
|
||||
if not FEISHU_AVAILABLE:
|
||||
pytest.skip("Feishu dependencies not installed (lark-oapi)", allow_module_level=True)
|
||||
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.channels.feishu import FeishuChannel
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_feishu_channel():
|
||||
"""Create a FeishuChannel with mocked client."""
|
||||
config = MagicMock()
|
||||
config.app_id = "test_app_id"
|
||||
config.app_secret = "test_app_secret"
|
||||
config.encrypt_key = None
|
||||
config.verification_token = None
|
||||
bus = MagicMock()
|
||||
channel = FeishuChannel(config, bus)
|
||||
channel._client = MagicMock() # Simulate initialized client
|
||||
return channel
|
||||
|
||||
|
||||
@mark.asyncio
|
||||
async def test_tool_hint_sends_code_message(mock_feishu_channel):
|
||||
"""Tool hint messages should be sent as interactive cards with code blocks."""
|
||||
msg = OutboundMessage(
|
||||
channel="feishu",
|
||||
chat_id="oc_123456",
|
||||
content='web_search("test query")',
|
||||
metadata={"_tool_hint": True}
|
||||
)
|
||||
|
||||
with patch.object(mock_feishu_channel, '_send_message_sync') as mock_send:
|
||||
await mock_feishu_channel.send(msg)
|
||||
|
||||
# Verify interactive message with card was sent
|
||||
assert mock_send.call_count == 1
|
||||
call_args = mock_send.call_args[0]
|
||||
receive_id_type, receive_id, msg_type, content = call_args
|
||||
|
||||
assert receive_id_type == "chat_id"
|
||||
assert receive_id == "oc_123456"
|
||||
assert msg_type == "interactive"
|
||||
|
||||
# Parse content to verify card structure
|
||||
card = json.loads(content)
|
||||
assert card["config"]["wide_screen_mode"] is True
|
||||
assert len(card["elements"]) == 1
|
||||
assert card["elements"][0]["tag"] == "markdown"
|
||||
# Check that code block is properly formatted with language hint
|
||||
expected_md = "**Tool Calls**\n\n```text\nweb_search(\"test query\")\n```"
|
||||
assert card["elements"][0]["content"] == expected_md
|
||||
|
||||
|
||||
@mark.asyncio
|
||||
async def test_tool_hint_empty_content_does_not_send(mock_feishu_channel):
|
||||
"""Empty tool hint messages should not be sent."""
|
||||
msg = OutboundMessage(
|
||||
channel="feishu",
|
||||
chat_id="oc_123456",
|
||||
content=" ", # whitespace only
|
||||
metadata={"_tool_hint": True}
|
||||
)
|
||||
|
||||
with patch.object(mock_feishu_channel, '_send_message_sync') as mock_send:
|
||||
await mock_feishu_channel.send(msg)
|
||||
|
||||
# Should not send any message
|
||||
mock_send.assert_not_called()
|
||||
|
||||
|
||||
@mark.asyncio
|
||||
async def test_tool_hint_without_metadata_sends_as_normal(mock_feishu_channel):
|
||||
"""Regular messages without _tool_hint should use normal formatting."""
|
||||
msg = OutboundMessage(
|
||||
channel="feishu",
|
||||
chat_id="oc_123456",
|
||||
content="Hello, world!",
|
||||
metadata={}
|
||||
)
|
||||
|
||||
with patch.object(mock_feishu_channel, '_send_message_sync') as mock_send:
|
||||
await mock_feishu_channel.send(msg)
|
||||
|
||||
# Should send as text message (detected format)
|
||||
assert mock_send.call_count == 1
|
||||
call_args = mock_send.call_args[0]
|
||||
_, _, msg_type, content = call_args
|
||||
assert msg_type == "text"
|
||||
assert json.loads(content) == {"text": "Hello, world!"}
|
||||
|
||||
|
||||
@mark.asyncio
|
||||
async def test_tool_hint_multiple_tools_in_one_message(mock_feishu_channel):
|
||||
"""Multiple tool calls should be displayed each on its own line in a code block."""
|
||||
msg = OutboundMessage(
|
||||
channel="feishu",
|
||||
chat_id="oc_123456",
|
||||
content='web_search("query"), read_file("/path/to/file")',
|
||||
metadata={"_tool_hint": True}
|
||||
)
|
||||
|
||||
with patch.object(mock_feishu_channel, '_send_message_sync') as mock_send:
|
||||
await mock_feishu_channel.send(msg)
|
||||
|
||||
call_args = mock_send.call_args[0]
|
||||
msg_type = call_args[2]
|
||||
content = json.loads(call_args[3])
|
||||
assert msg_type == "interactive"
|
||||
# Each tool call should be on its own line
|
||||
expected_md = "**Tool Calls**\n\n```text\nweb_search(\"query\"),\nread_file(\"/path/to/file\")\n```"
|
||||
assert content["elements"][0]["content"] == expected_md
|
||||
|
||||
|
||||
@mark.asyncio
|
||||
async def test_tool_hint_keeps_commas_inside_arguments(mock_feishu_channel):
|
||||
"""Commas inside a single tool argument must not be split onto a new line."""
|
||||
msg = OutboundMessage(
|
||||
channel="feishu",
|
||||
chat_id="oc_123456",
|
||||
content='web_search("foo, bar"), read_file("/path/to/file")',
|
||||
metadata={"_tool_hint": True}
|
||||
)
|
||||
|
||||
with patch.object(mock_feishu_channel, '_send_message_sync') as mock_send:
|
||||
await mock_feishu_channel.send(msg)
|
||||
|
||||
content = json.loads(mock_send.call_args[0][3])
|
||||
expected_md = (
|
||||
"**Tool Calls**\n\n```text\n"
|
||||
"web_search(\"foo, bar\"),\n"
|
||||
"read_file(\"/path/to/file\")\n```"
|
||||
)
|
||||
assert content["elements"][0]["content"] == expected_md
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,172 @@
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
# Check optional QQ dependencies before running tests
|
||||
try:
|
||||
from nanobot.channels import qq
|
||||
QQ_AVAILABLE = getattr(qq, "QQ_AVAILABLE", False)
|
||||
except ImportError:
|
||||
QQ_AVAILABLE = False
|
||||
|
||||
if not QQ_AVAILABLE:
|
||||
pytest.skip("QQ dependencies not installed (qq-botpy)", allow_module_level=True)
|
||||
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.channels.qq import QQChannel, QQConfig
|
||||
|
||||
|
||||
class _FakeApi:
|
||||
def __init__(self) -> None:
|
||||
self.c2c_calls: list[dict] = []
|
||||
self.group_calls: list[dict] = []
|
||||
|
||||
async def post_c2c_message(self, **kwargs) -> None:
|
||||
self.c2c_calls.append(kwargs)
|
||||
|
||||
async def post_group_message(self, **kwargs) -> None:
|
||||
self.group_calls.append(kwargs)
|
||||
|
||||
|
||||
class _FakeClient:
|
||||
def __init__(self) -> None:
|
||||
self.api = _FakeApi()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_group_message_routes_to_group_chat_id() -> None:
|
||||
channel = QQChannel(QQConfig(app_id="app", secret="secret", allow_from=["user1"]), MessageBus())
|
||||
|
||||
data = SimpleNamespace(
|
||||
id="msg1",
|
||||
content="hello",
|
||||
group_openid="group123",
|
||||
author=SimpleNamespace(member_openid="user1"),
|
||||
attachments=[],
|
||||
)
|
||||
|
||||
await channel._on_message(data, is_group=True)
|
||||
|
||||
msg = await channel.bus.consume_inbound()
|
||||
assert msg.sender_id == "user1"
|
||||
assert msg.chat_id == "group123"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_group_message_uses_plain_text_group_api_with_msg_seq() -> None:
|
||||
channel = QQChannel(QQConfig(app_id="app", secret="secret", allow_from=["*"]), MessageBus())
|
||||
channel._client = _FakeClient()
|
||||
channel._chat_type_cache["group123"] = "group"
|
||||
|
||||
await channel.send(
|
||||
OutboundMessage(
|
||||
channel="qq",
|
||||
chat_id="group123",
|
||||
content="hello",
|
||||
metadata={"message_id": "msg1"},
|
||||
)
|
||||
)
|
||||
|
||||
assert len(channel._client.api.group_calls) == 1
|
||||
call = channel._client.api.group_calls[0]
|
||||
assert call == {
|
||||
"group_openid": "group123",
|
||||
"msg_type": 0,
|
||||
"content": "hello",
|
||||
"msg_id": "msg1",
|
||||
"msg_seq": 2,
|
||||
}
|
||||
assert not channel._client.api.c2c_calls
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_c2c_message_uses_plain_text_c2c_api_with_msg_seq() -> None:
|
||||
channel = QQChannel(QQConfig(app_id="app", secret="secret", allow_from=["*"]), MessageBus())
|
||||
channel._client = _FakeClient()
|
||||
|
||||
await channel.send(
|
||||
OutboundMessage(
|
||||
channel="qq",
|
||||
chat_id="user123",
|
||||
content="hello",
|
||||
metadata={"message_id": "msg1"},
|
||||
)
|
||||
)
|
||||
|
||||
assert len(channel._client.api.c2c_calls) == 1
|
||||
call = channel._client.api.c2c_calls[0]
|
||||
assert call == {
|
||||
"openid": "user123",
|
||||
"msg_type": 0,
|
||||
"content": "hello",
|
||||
"msg_id": "msg1",
|
||||
"msg_seq": 2,
|
||||
}
|
||||
assert not channel._client.api.group_calls
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_group_message_uses_markdown_when_configured() -> None:
|
||||
channel = QQChannel(
|
||||
QQConfig(app_id="app", secret="secret", allow_from=["*"], msg_format="markdown"),
|
||||
MessageBus(),
|
||||
)
|
||||
channel._client = _FakeClient()
|
||||
channel._chat_type_cache["group123"] = "group"
|
||||
|
||||
await channel.send(
|
||||
OutboundMessage(
|
||||
channel="qq",
|
||||
chat_id="group123",
|
||||
content="**hello**",
|
||||
metadata={"message_id": "msg1"},
|
||||
)
|
||||
)
|
||||
|
||||
assert len(channel._client.api.group_calls) == 1
|
||||
call = channel._client.api.group_calls[0]
|
||||
assert call == {
|
||||
"group_openid": "group123",
|
||||
"msg_type": 2,
|
||||
"markdown": {"content": "**hello**"},
|
||||
"msg_id": "msg1",
|
||||
"msg_seq": 2,
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_read_media_bytes_local_path() -> None:
|
||||
channel = QQChannel(QQConfig(app_id="app", secret="secret"), MessageBus())
|
||||
|
||||
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f:
|
||||
f.write(b"\x89PNG\r\n")
|
||||
tmp_path = f.name
|
||||
|
||||
data, filename = await channel._read_media_bytes(tmp_path)
|
||||
assert data == b"\x89PNG\r\n"
|
||||
assert filename == Path(tmp_path).name
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_read_media_bytes_file_uri() -> None:
|
||||
channel = QQChannel(QQConfig(app_id="app", secret="secret"), MessageBus())
|
||||
|
||||
with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as f:
|
||||
f.write(b"JFIF")
|
||||
tmp_path = f.name
|
||||
|
||||
data, filename = await channel._read_media_bytes(f"file://{tmp_path}")
|
||||
assert data == b"JFIF"
|
||||
assert filename == Path(tmp_path).name
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_read_media_bytes_missing_file() -> None:
|
||||
channel = QQChannel(QQConfig(app_id="app", secret="secret"), MessageBus())
|
||||
|
||||
data, filename = await channel._read_media_bytes("/nonexistent/path/image.png")
|
||||
assert data is None
|
||||
assert filename is None
|
||||
@@ -0,0 +1,153 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
# Check optional Slack dependencies before running tests
|
||||
try:
|
||||
import slack_sdk # noqa: F401
|
||||
except ImportError:
|
||||
pytest.skip("Slack dependencies not installed (slack-sdk)", allow_module_level=True)
|
||||
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.channels.slack import SlackChannel
|
||||
from nanobot.channels.slack import SlackConfig
|
||||
|
||||
|
||||
class _FakeAsyncWebClient:
|
||||
def __init__(self) -> None:
|
||||
self.chat_post_calls: list[dict[str, object | None]] = []
|
||||
self.file_upload_calls: list[dict[str, object | None]] = []
|
||||
self.reactions_add_calls: list[dict[str, object | None]] = []
|
||||
self.reactions_remove_calls: list[dict[str, object | None]] = []
|
||||
|
||||
async def chat_postMessage(
|
||||
self,
|
||||
*,
|
||||
channel: str,
|
||||
text: str,
|
||||
thread_ts: str | None = None,
|
||||
) -> None:
|
||||
self.chat_post_calls.append(
|
||||
{
|
||||
"channel": channel,
|
||||
"text": text,
|
||||
"thread_ts": thread_ts,
|
||||
}
|
||||
)
|
||||
|
||||
async def files_upload_v2(
|
||||
self,
|
||||
*,
|
||||
channel: str,
|
||||
file: str,
|
||||
thread_ts: str | None = None,
|
||||
) -> None:
|
||||
self.file_upload_calls.append(
|
||||
{
|
||||
"channel": channel,
|
||||
"file": file,
|
||||
"thread_ts": thread_ts,
|
||||
}
|
||||
)
|
||||
|
||||
async def reactions_add(
|
||||
self,
|
||||
*,
|
||||
channel: str,
|
||||
name: str,
|
||||
timestamp: str,
|
||||
) -> None:
|
||||
self.reactions_add_calls.append(
|
||||
{
|
||||
"channel": channel,
|
||||
"name": name,
|
||||
"timestamp": timestamp,
|
||||
}
|
||||
)
|
||||
|
||||
async def reactions_remove(
|
||||
self,
|
||||
*,
|
||||
channel: str,
|
||||
name: str,
|
||||
timestamp: str,
|
||||
) -> None:
|
||||
self.reactions_remove_calls.append(
|
||||
{
|
||||
"channel": channel,
|
||||
"name": name,
|
||||
"timestamp": timestamp,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_uses_thread_for_channel_messages() -> None:
|
||||
channel = SlackChannel(SlackConfig(enabled=True), MessageBus())
|
||||
fake_web = _FakeAsyncWebClient()
|
||||
channel._web_client = fake_web
|
||||
|
||||
await channel.send(
|
||||
OutboundMessage(
|
||||
channel="slack",
|
||||
chat_id="C123",
|
||||
content="hello",
|
||||
media=["/tmp/demo.txt"],
|
||||
metadata={"slack": {"thread_ts": "1700000000.000100", "channel_type": "channel"}},
|
||||
)
|
||||
)
|
||||
|
||||
assert len(fake_web.chat_post_calls) == 1
|
||||
assert fake_web.chat_post_calls[0]["text"] == "hello\n"
|
||||
assert fake_web.chat_post_calls[0]["thread_ts"] == "1700000000.000100"
|
||||
assert len(fake_web.file_upload_calls) == 1
|
||||
assert fake_web.file_upload_calls[0]["thread_ts"] == "1700000000.000100"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_omits_thread_for_dm_messages() -> None:
|
||||
channel = SlackChannel(SlackConfig(enabled=True), MessageBus())
|
||||
fake_web = _FakeAsyncWebClient()
|
||||
channel._web_client = fake_web
|
||||
|
||||
await channel.send(
|
||||
OutboundMessage(
|
||||
channel="slack",
|
||||
chat_id="D123",
|
||||
content="hello",
|
||||
media=["/tmp/demo.txt"],
|
||||
metadata={"slack": {"thread_ts": "1700000000.000100", "channel_type": "im"}},
|
||||
)
|
||||
)
|
||||
|
||||
assert len(fake_web.chat_post_calls) == 1
|
||||
assert fake_web.chat_post_calls[0]["text"] == "hello\n"
|
||||
assert fake_web.chat_post_calls[0]["thread_ts"] is None
|
||||
assert len(fake_web.file_upload_calls) == 1
|
||||
assert fake_web.file_upload_calls[0]["thread_ts"] is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_updates_reaction_when_final_response_sent() -> None:
|
||||
channel = SlackChannel(SlackConfig(enabled=True, react_emoji="eyes"), MessageBus())
|
||||
fake_web = _FakeAsyncWebClient()
|
||||
channel._web_client = fake_web
|
||||
|
||||
await channel.send(
|
||||
OutboundMessage(
|
||||
channel="slack",
|
||||
chat_id="C123",
|
||||
content="done",
|
||||
metadata={
|
||||
"slack": {"event": {"ts": "1700000000.000100"}, "channel_type": "channel"},
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
assert fake_web.reactions_remove_calls == [
|
||||
{"channel": "C123", "name": "eyes", "timestamp": "1700000000.000100"}
|
||||
]
|
||||
assert fake_web.reactions_add_calls == [
|
||||
{"channel": "C123", "name": "white_check_mark", "timestamp": "1700000000.000100"}
|
||||
]
|
||||
@@ -0,0 +1,966 @@
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
# Check optional Telegram dependencies before running tests
|
||||
try:
|
||||
import telegram # noqa: F401
|
||||
except ImportError:
|
||||
pytest.skip("Telegram dependencies not installed (python-telegram-bot)", allow_module_level=True)
|
||||
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.channels.telegram import TELEGRAM_REPLY_CONTEXT_MAX_LEN, TelegramChannel, _StreamBuf
|
||||
from nanobot.channels.telegram import TelegramConfig
|
||||
|
||||
|
||||
class _FakeHTTPXRequest:
|
||||
instances: list["_FakeHTTPXRequest"] = []
|
||||
|
||||
def __init__(self, **kwargs) -> None:
|
||||
self.kwargs = kwargs
|
||||
self.__class__.instances.append(self)
|
||||
|
||||
@classmethod
|
||||
def clear(cls) -> None:
|
||||
cls.instances.clear()
|
||||
|
||||
|
||||
class _FakeUpdater:
|
||||
def __init__(self, on_start_polling) -> None:
|
||||
self._on_start_polling = on_start_polling
|
||||
|
||||
async def start_polling(self, **kwargs) -> None:
|
||||
self._on_start_polling()
|
||||
|
||||
|
||||
class _FakeBot:
|
||||
def __init__(self) -> None:
|
||||
self.sent_messages: list[dict] = []
|
||||
self.sent_media: list[dict] = []
|
||||
self.get_me_calls = 0
|
||||
|
||||
async def get_me(self):
|
||||
self.get_me_calls += 1
|
||||
return SimpleNamespace(id=999, username="nanobot_test")
|
||||
|
||||
async def set_my_commands(self, commands) -> None:
|
||||
self.commands = commands
|
||||
|
||||
async def send_message(self, **kwargs):
|
||||
self.sent_messages.append(kwargs)
|
||||
return SimpleNamespace(message_id=len(self.sent_messages))
|
||||
|
||||
async def send_photo(self, **kwargs) -> None:
|
||||
self.sent_media.append({"kind": "photo", **kwargs})
|
||||
|
||||
async def send_voice(self, **kwargs) -> None:
|
||||
self.sent_media.append({"kind": "voice", **kwargs})
|
||||
|
||||
async def send_audio(self, **kwargs) -> None:
|
||||
self.sent_media.append({"kind": "audio", **kwargs})
|
||||
|
||||
async def send_document(self, **kwargs) -> None:
|
||||
self.sent_media.append({"kind": "document", **kwargs})
|
||||
|
||||
async def send_chat_action(self, **kwargs) -> None:
|
||||
pass
|
||||
|
||||
async def get_file(self, file_id: str):
|
||||
"""Return a fake file that 'downloads' to a path (for reply-to-media tests)."""
|
||||
async def _fake_download(path) -> None:
|
||||
pass
|
||||
return SimpleNamespace(download_to_drive=_fake_download)
|
||||
|
||||
|
||||
class _FakeApp:
|
||||
def __init__(self, on_start_polling) -> None:
|
||||
self.bot = _FakeBot()
|
||||
self.updater = _FakeUpdater(on_start_polling)
|
||||
self.handlers = []
|
||||
self.error_handlers = []
|
||||
|
||||
def add_error_handler(self, handler) -> None:
|
||||
self.error_handlers.append(handler)
|
||||
|
||||
def add_handler(self, handler) -> None:
|
||||
self.handlers.append(handler)
|
||||
|
||||
async def initialize(self) -> None:
|
||||
pass
|
||||
|
||||
async def start(self) -> None:
|
||||
pass
|
||||
|
||||
|
||||
class _FakeBuilder:
|
||||
def __init__(self, app: _FakeApp) -> None:
|
||||
self.app = app
|
||||
self.token_value = None
|
||||
self.request_value = None
|
||||
self.get_updates_request_value = None
|
||||
|
||||
def token(self, token: str):
|
||||
self.token_value = token
|
||||
return self
|
||||
|
||||
def request(self, request):
|
||||
self.request_value = request
|
||||
return self
|
||||
|
||||
def get_updates_request(self, request):
|
||||
self.get_updates_request_value = request
|
||||
return self
|
||||
|
||||
def proxy(self, _proxy):
|
||||
raise AssertionError("builder.proxy should not be called when request is set")
|
||||
|
||||
def get_updates_proxy(self, _proxy):
|
||||
raise AssertionError("builder.get_updates_proxy should not be called when request is set")
|
||||
|
||||
def build(self):
|
||||
return self.app
|
||||
|
||||
|
||||
def _make_telegram_update(
|
||||
*,
|
||||
chat_type: str = "group",
|
||||
text: str | None = None,
|
||||
caption: str | None = None,
|
||||
entities=None,
|
||||
caption_entities=None,
|
||||
reply_to_message=None,
|
||||
):
|
||||
user = SimpleNamespace(id=12345, username="alice", first_name="Alice")
|
||||
message = SimpleNamespace(
|
||||
chat=SimpleNamespace(type=chat_type, is_forum=False),
|
||||
chat_id=-100123,
|
||||
text=text,
|
||||
caption=caption,
|
||||
entities=entities or [],
|
||||
caption_entities=caption_entities or [],
|
||||
reply_to_message=reply_to_message,
|
||||
photo=None,
|
||||
voice=None,
|
||||
audio=None,
|
||||
document=None,
|
||||
media_group_id=None,
|
||||
message_thread_id=None,
|
||||
message_id=1,
|
||||
)
|
||||
return SimpleNamespace(message=message, effective_user=user)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_creates_separate_pools_with_proxy(monkeypatch) -> None:
|
||||
_FakeHTTPXRequest.clear()
|
||||
config = TelegramConfig(
|
||||
enabled=True,
|
||||
token="123:abc",
|
||||
allow_from=["*"],
|
||||
proxy="http://127.0.0.1:7890",
|
||||
)
|
||||
bus = MessageBus()
|
||||
channel = TelegramChannel(config, bus)
|
||||
app = _FakeApp(lambda: setattr(channel, "_running", False))
|
||||
builder = _FakeBuilder(app)
|
||||
|
||||
monkeypatch.setattr("nanobot.channels.telegram.HTTPXRequest", _FakeHTTPXRequest)
|
||||
monkeypatch.setattr(
|
||||
"nanobot.channels.telegram.Application",
|
||||
SimpleNamespace(builder=lambda: builder),
|
||||
)
|
||||
|
||||
await channel.start()
|
||||
|
||||
assert len(_FakeHTTPXRequest.instances) == 2
|
||||
api_req, poll_req = _FakeHTTPXRequest.instances
|
||||
assert api_req.kwargs["proxy"] == config.proxy
|
||||
assert poll_req.kwargs["proxy"] == config.proxy
|
||||
assert api_req.kwargs["connection_pool_size"] == 32
|
||||
assert poll_req.kwargs["connection_pool_size"] == 4
|
||||
assert builder.request_value is api_req
|
||||
assert builder.get_updates_request_value is poll_req
|
||||
assert any(cmd.command == "status" for cmd in app.bot.commands)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_respects_custom_pool_config(monkeypatch) -> None:
|
||||
_FakeHTTPXRequest.clear()
|
||||
config = TelegramConfig(
|
||||
enabled=True,
|
||||
token="123:abc",
|
||||
allow_from=["*"],
|
||||
connection_pool_size=32,
|
||||
pool_timeout=10.0,
|
||||
)
|
||||
bus = MessageBus()
|
||||
channel = TelegramChannel(config, bus)
|
||||
app = _FakeApp(lambda: setattr(channel, "_running", False))
|
||||
builder = _FakeBuilder(app)
|
||||
|
||||
monkeypatch.setattr("nanobot.channels.telegram.HTTPXRequest", _FakeHTTPXRequest)
|
||||
monkeypatch.setattr(
|
||||
"nanobot.channels.telegram.Application",
|
||||
SimpleNamespace(builder=lambda: builder),
|
||||
)
|
||||
|
||||
await channel.start()
|
||||
|
||||
api_req = _FakeHTTPXRequest.instances[0]
|
||||
poll_req = _FakeHTTPXRequest.instances[1]
|
||||
assert api_req.kwargs["connection_pool_size"] == 32
|
||||
assert api_req.kwargs["pool_timeout"] == 10.0
|
||||
assert poll_req.kwargs["pool_timeout"] == 10.0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_text_retries_on_timeout() -> None:
|
||||
"""_send_text retries on TimedOut before succeeding."""
|
||||
from telegram.error import TimedOut
|
||||
|
||||
channel = TelegramChannel(
|
||||
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"]),
|
||||
MessageBus(),
|
||||
)
|
||||
channel._app = _FakeApp(lambda: None)
|
||||
|
||||
call_count = 0
|
||||
original_send = channel._app.bot.send_message
|
||||
|
||||
async def flaky_send(**kwargs):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count <= 2:
|
||||
raise TimedOut()
|
||||
return await original_send(**kwargs)
|
||||
|
||||
channel._app.bot.send_message = flaky_send
|
||||
|
||||
import nanobot.channels.telegram as tg_mod
|
||||
orig_delay = tg_mod._SEND_RETRY_BASE_DELAY
|
||||
tg_mod._SEND_RETRY_BASE_DELAY = 0.01
|
||||
try:
|
||||
await channel._send_text(123, "hello", None, {})
|
||||
finally:
|
||||
tg_mod._SEND_RETRY_BASE_DELAY = orig_delay
|
||||
|
||||
assert call_count == 3
|
||||
assert len(channel._app.bot.sent_messages) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_text_gives_up_after_max_retries() -> None:
|
||||
"""_send_text raises TimedOut after exhausting all retries."""
|
||||
from telegram.error import TimedOut
|
||||
|
||||
channel = TelegramChannel(
|
||||
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"]),
|
||||
MessageBus(),
|
||||
)
|
||||
channel._app = _FakeApp(lambda: None)
|
||||
|
||||
async def always_timeout(**kwargs):
|
||||
raise TimedOut()
|
||||
|
||||
channel._app.bot.send_message = always_timeout
|
||||
|
||||
import nanobot.channels.telegram as tg_mod
|
||||
orig_delay = tg_mod._SEND_RETRY_BASE_DELAY
|
||||
tg_mod._SEND_RETRY_BASE_DELAY = 0.01
|
||||
try:
|
||||
with pytest.raises(TimedOut):
|
||||
await channel._send_text(123, "hello", None, {})
|
||||
finally:
|
||||
tg_mod._SEND_RETRY_BASE_DELAY = orig_delay
|
||||
|
||||
assert channel._app.bot.sent_messages == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_error_logs_network_issues_as_warning(monkeypatch) -> None:
|
||||
from telegram.error import NetworkError
|
||||
|
||||
channel = TelegramChannel(
|
||||
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"]),
|
||||
MessageBus(),
|
||||
)
|
||||
recorded: list[tuple[str, str]] = []
|
||||
|
||||
monkeypatch.setattr(
|
||||
"nanobot.channels.telegram.logger.warning",
|
||||
lambda message, error: recorded.append(("warning", message.format(error))),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"nanobot.channels.telegram.logger.error",
|
||||
lambda message, error: recorded.append(("error", message.format(error))),
|
||||
)
|
||||
|
||||
await channel._on_error(object(), SimpleNamespace(error=NetworkError("proxy disconnected")))
|
||||
|
||||
assert recorded == [("warning", "Telegram network issue: proxy disconnected")]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_error_keeps_non_network_exceptions_as_error(monkeypatch) -> None:
|
||||
channel = TelegramChannel(
|
||||
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"]),
|
||||
MessageBus(),
|
||||
)
|
||||
recorded: list[tuple[str, str]] = []
|
||||
|
||||
monkeypatch.setattr(
|
||||
"nanobot.channels.telegram.logger.warning",
|
||||
lambda message, error: recorded.append(("warning", message.format(error))),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"nanobot.channels.telegram.logger.error",
|
||||
lambda message, error: recorded.append(("error", message.format(error))),
|
||||
)
|
||||
|
||||
await channel._on_error(object(), SimpleNamespace(error=RuntimeError("boom")))
|
||||
|
||||
assert recorded == [("error", "Telegram error: boom")]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_delta_stream_end_raises_and_keeps_buffer_on_failure() -> None:
|
||||
channel = TelegramChannel(
|
||||
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"]),
|
||||
MessageBus(),
|
||||
)
|
||||
channel._app = _FakeApp(lambda: None)
|
||||
channel._app.bot.edit_message_text = AsyncMock(side_effect=RuntimeError("boom"))
|
||||
channel._stream_bufs["123"] = _StreamBuf(text="hello", message_id=7, last_edit=0.0)
|
||||
|
||||
with pytest.raises(RuntimeError, match="boom"):
|
||||
await channel.send_delta("123", "", {"_stream_end": True})
|
||||
|
||||
assert "123" in channel._stream_bufs
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_delta_stream_end_treats_not_modified_as_success() -> None:
|
||||
from telegram.error import BadRequest
|
||||
|
||||
channel = TelegramChannel(
|
||||
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"]),
|
||||
MessageBus(),
|
||||
)
|
||||
channel._app = _FakeApp(lambda: None)
|
||||
channel._app.bot.edit_message_text = AsyncMock(side_effect=BadRequest("Message is not modified"))
|
||||
channel._stream_bufs["123"] = _StreamBuf(text="hello", message_id=7, last_edit=0.0, stream_id="s:0")
|
||||
|
||||
await channel.send_delta("123", "", {"_stream_end": True, "_stream_id": "s:0"})
|
||||
|
||||
assert "123" not in channel._stream_bufs
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_delta_new_stream_id_replaces_stale_buffer() -> None:
|
||||
channel = TelegramChannel(
|
||||
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"]),
|
||||
MessageBus(),
|
||||
)
|
||||
channel._app = _FakeApp(lambda: None)
|
||||
channel._stream_bufs["123"] = _StreamBuf(
|
||||
text="hello",
|
||||
message_id=7,
|
||||
last_edit=0.0,
|
||||
stream_id="old:0",
|
||||
)
|
||||
|
||||
await channel.send_delta("123", "world", {"_stream_delta": True, "_stream_id": "new:0"})
|
||||
|
||||
buf = channel._stream_bufs["123"]
|
||||
assert buf.text == "world"
|
||||
assert buf.stream_id == "new:0"
|
||||
assert buf.message_id == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_delta_incremental_edit_treats_not_modified_as_success() -> None:
|
||||
from telegram.error import BadRequest
|
||||
|
||||
channel = TelegramChannel(
|
||||
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"]),
|
||||
MessageBus(),
|
||||
)
|
||||
channel._app = _FakeApp(lambda: None)
|
||||
channel._stream_bufs["123"] = _StreamBuf(text="hello", message_id=7, last_edit=0.0, stream_id="s:0")
|
||||
channel._app.bot.edit_message_text = AsyncMock(side_effect=BadRequest("Message is not modified"))
|
||||
|
||||
await channel.send_delta("123", "", {"_stream_delta": True, "_stream_id": "s:0"})
|
||||
|
||||
assert channel._stream_bufs["123"].last_edit > 0.0
|
||||
|
||||
|
||||
def test_derive_topic_session_key_uses_thread_id() -> None:
|
||||
message = SimpleNamespace(
|
||||
chat=SimpleNamespace(type="supergroup"),
|
||||
chat_id=-100123,
|
||||
message_thread_id=42,
|
||||
)
|
||||
|
||||
assert TelegramChannel._derive_topic_session_key(message) == "telegram:-100123:topic:42"
|
||||
|
||||
|
||||
def test_get_extension_falls_back_to_original_filename() -> None:
|
||||
channel = TelegramChannel(TelegramConfig(), MessageBus())
|
||||
|
||||
assert channel._get_extension("file", None, "report.pdf") == ".pdf"
|
||||
assert channel._get_extension("file", None, "archive.tar.gz") == ".tar.gz"
|
||||
|
||||
|
||||
def test_telegram_group_policy_defaults_to_mention() -> None:
|
||||
assert TelegramConfig().group_policy == "mention"
|
||||
|
||||
|
||||
def test_is_allowed_accepts_legacy_telegram_id_username_formats() -> None:
|
||||
channel = TelegramChannel(TelegramConfig(allow_from=["12345", "alice", "67890|bob"]), MessageBus())
|
||||
|
||||
assert channel.is_allowed("12345|carol") is True
|
||||
assert channel.is_allowed("99999|alice") is True
|
||||
assert channel.is_allowed("67890|bob") is True
|
||||
|
||||
|
||||
def test_is_allowed_rejects_invalid_legacy_telegram_sender_shapes() -> None:
|
||||
channel = TelegramChannel(TelegramConfig(allow_from=["alice"]), MessageBus())
|
||||
|
||||
assert channel.is_allowed("attacker|alice|extra") is False
|
||||
assert channel.is_allowed("not-a-number|alice") is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_progress_keeps_message_in_topic() -> None:
|
||||
config = TelegramConfig(enabled=True, token="123:abc", allow_from=["*"])
|
||||
channel = TelegramChannel(config, MessageBus())
|
||||
channel._app = _FakeApp(lambda: None)
|
||||
|
||||
await channel.send(
|
||||
OutboundMessage(
|
||||
channel="telegram",
|
||||
chat_id="123",
|
||||
content="hello",
|
||||
metadata={"_progress": True, "message_thread_id": 42},
|
||||
)
|
||||
)
|
||||
|
||||
assert channel._app.bot.sent_messages[0]["message_thread_id"] == 42
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_reply_infers_topic_from_message_id_cache() -> None:
|
||||
config = TelegramConfig(enabled=True, token="123:abc", allow_from=["*"], reply_to_message=True)
|
||||
channel = TelegramChannel(config, MessageBus())
|
||||
channel._app = _FakeApp(lambda: None)
|
||||
channel._message_threads[("123", 10)] = 42
|
||||
|
||||
await channel.send(
|
||||
OutboundMessage(
|
||||
channel="telegram",
|
||||
chat_id="123",
|
||||
content="hello",
|
||||
metadata={"message_id": 10},
|
||||
)
|
||||
)
|
||||
|
||||
assert channel._app.bot.sent_messages[0]["message_thread_id"] == 42
|
||||
assert channel._app.bot.sent_messages[0]["reply_parameters"].message_id == 10
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_remote_media_url_after_security_validation(monkeypatch) -> None:
|
||||
channel = TelegramChannel(
|
||||
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"]),
|
||||
MessageBus(),
|
||||
)
|
||||
channel._app = _FakeApp(lambda: None)
|
||||
monkeypatch.setattr("nanobot.channels.telegram.validate_url_target", lambda url: (True, ""))
|
||||
|
||||
await channel.send(
|
||||
OutboundMessage(
|
||||
channel="telegram",
|
||||
chat_id="123",
|
||||
content="",
|
||||
media=["https://example.com/cat.jpg"],
|
||||
)
|
||||
)
|
||||
|
||||
assert channel._app.bot.sent_media == [
|
||||
{
|
||||
"kind": "photo",
|
||||
"chat_id": 123,
|
||||
"photo": "https://example.com/cat.jpg",
|
||||
"reply_parameters": None,
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_blocks_unsafe_remote_media_url(monkeypatch) -> None:
|
||||
channel = TelegramChannel(
|
||||
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"]),
|
||||
MessageBus(),
|
||||
)
|
||||
channel._app = _FakeApp(lambda: None)
|
||||
monkeypatch.setattr(
|
||||
"nanobot.channels.telegram.validate_url_target",
|
||||
lambda url: (False, "Blocked: example.com resolves to private/internal address 127.0.0.1"),
|
||||
)
|
||||
|
||||
await channel.send(
|
||||
OutboundMessage(
|
||||
channel="telegram",
|
||||
chat_id="123",
|
||||
content="",
|
||||
media=["http://example.com/internal.jpg"],
|
||||
)
|
||||
)
|
||||
|
||||
assert channel._app.bot.sent_media == []
|
||||
assert channel._app.bot.sent_messages == [
|
||||
{
|
||||
"chat_id": 123,
|
||||
"text": "[Failed to send: internal.jpg]",
|
||||
"reply_parameters": None,
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_group_policy_mention_ignores_unmentioned_group_message() -> None:
|
||||
channel = TelegramChannel(
|
||||
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"], group_policy="mention"),
|
||||
MessageBus(),
|
||||
)
|
||||
channel._app = _FakeApp(lambda: None)
|
||||
|
||||
handled = []
|
||||
|
||||
async def capture_handle(**kwargs) -> None:
|
||||
handled.append(kwargs)
|
||||
|
||||
channel._handle_message = capture_handle
|
||||
channel._start_typing = lambda _chat_id: None
|
||||
|
||||
await channel._on_message(_make_telegram_update(text="hello everyone"), None)
|
||||
|
||||
assert handled == []
|
||||
assert channel._app.bot.get_me_calls == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_group_policy_mention_accepts_text_mention_and_caches_bot_identity() -> None:
|
||||
channel = TelegramChannel(
|
||||
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"], group_policy="mention"),
|
||||
MessageBus(),
|
||||
)
|
||||
channel._app = _FakeApp(lambda: None)
|
||||
|
||||
handled = []
|
||||
|
||||
async def capture_handle(**kwargs) -> None:
|
||||
handled.append(kwargs)
|
||||
|
||||
channel._handle_message = capture_handle
|
||||
channel._start_typing = lambda _chat_id: None
|
||||
|
||||
mention = SimpleNamespace(type="mention", offset=0, length=13)
|
||||
await channel._on_message(_make_telegram_update(text="@nanobot_test hi", entities=[mention]), None)
|
||||
await channel._on_message(_make_telegram_update(text="@nanobot_test again", entities=[mention]), None)
|
||||
|
||||
assert len(handled) == 2
|
||||
assert channel._app.bot.get_me_calls == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_group_policy_mention_accepts_caption_mention() -> None:
|
||||
channel = TelegramChannel(
|
||||
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"], group_policy="mention"),
|
||||
MessageBus(),
|
||||
)
|
||||
channel._app = _FakeApp(lambda: None)
|
||||
|
||||
handled = []
|
||||
|
||||
async def capture_handle(**kwargs) -> None:
|
||||
handled.append(kwargs)
|
||||
|
||||
channel._handle_message = capture_handle
|
||||
channel._start_typing = lambda _chat_id: None
|
||||
|
||||
mention = SimpleNamespace(type="mention", offset=0, length=13)
|
||||
await channel._on_message(
|
||||
_make_telegram_update(caption="@nanobot_test photo", caption_entities=[mention]),
|
||||
None,
|
||||
)
|
||||
|
||||
assert len(handled) == 1
|
||||
assert handled[0]["content"] == "@nanobot_test photo"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_group_policy_mention_accepts_reply_to_bot() -> None:
|
||||
channel = TelegramChannel(
|
||||
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"], group_policy="mention"),
|
||||
MessageBus(),
|
||||
)
|
||||
channel._app = _FakeApp(lambda: None)
|
||||
|
||||
handled = []
|
||||
|
||||
async def capture_handle(**kwargs) -> None:
|
||||
handled.append(kwargs)
|
||||
|
||||
channel._handle_message = capture_handle
|
||||
channel._start_typing = lambda _chat_id: None
|
||||
|
||||
reply = SimpleNamespace(from_user=SimpleNamespace(id=999))
|
||||
await channel._on_message(_make_telegram_update(text="reply", reply_to_message=reply), None)
|
||||
|
||||
assert len(handled) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_group_policy_open_accepts_plain_group_message() -> None:
|
||||
channel = TelegramChannel(
|
||||
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"], group_policy="open"),
|
||||
MessageBus(),
|
||||
)
|
||||
channel._app = _FakeApp(lambda: None)
|
||||
|
||||
handled = []
|
||||
|
||||
async def capture_handle(**kwargs) -> None:
|
||||
handled.append(kwargs)
|
||||
|
||||
channel._handle_message = capture_handle
|
||||
channel._start_typing = lambda _chat_id: None
|
||||
|
||||
await channel._on_message(_make_telegram_update(text="hello group"), None)
|
||||
|
||||
assert len(handled) == 1
|
||||
assert channel._app.bot.get_me_calls == 0
|
||||
|
||||
|
||||
def test_extract_reply_context_no_reply() -> None:
|
||||
"""When there is no reply_to_message, _extract_reply_context returns None."""
|
||||
message = SimpleNamespace(reply_to_message=None)
|
||||
assert TelegramChannel._extract_reply_context(message) is None
|
||||
|
||||
|
||||
def test_extract_reply_context_with_text() -> None:
|
||||
"""When reply has text, return prefixed string."""
|
||||
reply = SimpleNamespace(text="Hello world", caption=None)
|
||||
message = SimpleNamespace(reply_to_message=reply)
|
||||
assert TelegramChannel._extract_reply_context(message) == "[Reply to: Hello world]"
|
||||
|
||||
|
||||
def test_extract_reply_context_with_caption_only() -> None:
|
||||
"""When reply has only caption (no text), caption is used."""
|
||||
reply = SimpleNamespace(text=None, caption="Photo caption")
|
||||
message = SimpleNamespace(reply_to_message=reply)
|
||||
assert TelegramChannel._extract_reply_context(message) == "[Reply to: Photo caption]"
|
||||
|
||||
|
||||
def test_extract_reply_context_truncation() -> None:
|
||||
"""Reply text is truncated at TELEGRAM_REPLY_CONTEXT_MAX_LEN."""
|
||||
long_text = "x" * (TELEGRAM_REPLY_CONTEXT_MAX_LEN + 100)
|
||||
reply = SimpleNamespace(text=long_text, caption=None)
|
||||
message = SimpleNamespace(reply_to_message=reply)
|
||||
result = TelegramChannel._extract_reply_context(message)
|
||||
assert result is not None
|
||||
assert result.startswith("[Reply to: ")
|
||||
assert result.endswith("...]")
|
||||
assert len(result) == len("[Reply to: ]") + TELEGRAM_REPLY_CONTEXT_MAX_LEN + len("...")
|
||||
|
||||
|
||||
def test_extract_reply_context_no_text_returns_none() -> None:
|
||||
"""When reply has no text/caption, _extract_reply_context returns None (media handled separately)."""
|
||||
reply = SimpleNamespace(text=None, caption=None)
|
||||
message = SimpleNamespace(reply_to_message=reply)
|
||||
assert TelegramChannel._extract_reply_context(message) is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_message_includes_reply_context() -> None:
|
||||
"""When user replies to a message, content passed to bus starts with reply context."""
|
||||
channel = TelegramChannel(
|
||||
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"], group_policy="open"),
|
||||
MessageBus(),
|
||||
)
|
||||
channel._app = _FakeApp(lambda: None)
|
||||
handled = []
|
||||
async def capture_handle(**kwargs) -> None:
|
||||
handled.append(kwargs)
|
||||
channel._handle_message = capture_handle
|
||||
channel._start_typing = lambda _chat_id: None
|
||||
|
||||
reply = SimpleNamespace(text="Hello", message_id=2, from_user=SimpleNamespace(id=1))
|
||||
update = _make_telegram_update(text="translate this", reply_to_message=reply)
|
||||
await channel._on_message(update, None)
|
||||
|
||||
assert len(handled) == 1
|
||||
assert handled[0]["content"].startswith("[Reply to: Hello]")
|
||||
assert "translate this" in handled[0]["content"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_download_message_media_returns_path_when_download_succeeds(
|
||||
monkeypatch, tmp_path
|
||||
) -> None:
|
||||
"""_download_message_media returns (paths, content_parts) when bot.get_file and download succeed."""
|
||||
media_dir = tmp_path / "media" / "telegram"
|
||||
media_dir.mkdir(parents=True)
|
||||
monkeypatch.setattr(
|
||||
"nanobot.channels.telegram.get_media_dir",
|
||||
lambda channel=None: media_dir if channel else tmp_path / "media",
|
||||
)
|
||||
|
||||
channel = TelegramChannel(
|
||||
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"]),
|
||||
MessageBus(),
|
||||
)
|
||||
channel._app = _FakeApp(lambda: None)
|
||||
channel._app.bot.get_file = AsyncMock(
|
||||
return_value=SimpleNamespace(download_to_drive=AsyncMock(return_value=None))
|
||||
)
|
||||
|
||||
msg = SimpleNamespace(
|
||||
photo=[SimpleNamespace(file_id="fid123", mime_type="image/jpeg")],
|
||||
voice=None,
|
||||
audio=None,
|
||||
document=None,
|
||||
video=None,
|
||||
video_note=None,
|
||||
animation=None,
|
||||
)
|
||||
paths, parts = await channel._download_message_media(msg)
|
||||
assert len(paths) == 1
|
||||
assert len(parts) == 1
|
||||
assert "fid123" in paths[0]
|
||||
assert "[image:" in parts[0]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_download_message_media_uses_file_unique_id_when_available(
|
||||
monkeypatch, tmp_path
|
||||
) -> None:
|
||||
media_dir = tmp_path / "media" / "telegram"
|
||||
media_dir.mkdir(parents=True)
|
||||
monkeypatch.setattr(
|
||||
"nanobot.channels.telegram.get_media_dir",
|
||||
lambda channel=None: media_dir if channel else tmp_path / "media",
|
||||
)
|
||||
|
||||
downloaded: dict[str, str] = {}
|
||||
|
||||
async def _download_to_drive(path: str) -> None:
|
||||
downloaded["path"] = path
|
||||
|
||||
channel = TelegramChannel(
|
||||
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"]),
|
||||
MessageBus(),
|
||||
)
|
||||
app = _FakeApp(lambda: None)
|
||||
app.bot.get_file = AsyncMock(
|
||||
return_value=SimpleNamespace(download_to_drive=_download_to_drive)
|
||||
)
|
||||
channel._app = app
|
||||
|
||||
msg = SimpleNamespace(
|
||||
photo=[
|
||||
SimpleNamespace(
|
||||
file_id="file-id-that-should-not-be-used",
|
||||
file_unique_id="stable-unique-id",
|
||||
mime_type="image/jpeg",
|
||||
file_name=None,
|
||||
)
|
||||
],
|
||||
voice=None,
|
||||
audio=None,
|
||||
document=None,
|
||||
video=None,
|
||||
video_note=None,
|
||||
animation=None,
|
||||
)
|
||||
|
||||
paths, parts = await channel._download_message_media(msg)
|
||||
|
||||
assert downloaded["path"].endswith("stable-unique-id.jpg")
|
||||
assert paths == [str(media_dir / "stable-unique-id.jpg")]
|
||||
assert parts == [f"[image: {media_dir / 'stable-unique-id.jpg'}]"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_message_attaches_reply_to_media_when_available(monkeypatch, tmp_path) -> None:
|
||||
"""When user replies to a message with media, that media is downloaded and attached to the turn."""
|
||||
media_dir = tmp_path / "media" / "telegram"
|
||||
media_dir.mkdir(parents=True)
|
||||
monkeypatch.setattr(
|
||||
"nanobot.channels.telegram.get_media_dir",
|
||||
lambda channel=None: media_dir if channel else tmp_path / "media",
|
||||
)
|
||||
|
||||
channel = TelegramChannel(
|
||||
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"], group_policy="open"),
|
||||
MessageBus(),
|
||||
)
|
||||
app = _FakeApp(lambda: None)
|
||||
app.bot.get_file = AsyncMock(
|
||||
return_value=SimpleNamespace(download_to_drive=AsyncMock(return_value=None))
|
||||
)
|
||||
channel._app = app
|
||||
handled = []
|
||||
async def capture_handle(**kwargs) -> None:
|
||||
handled.append(kwargs)
|
||||
channel._handle_message = capture_handle
|
||||
channel._start_typing = lambda _chat_id: None
|
||||
|
||||
reply_with_photo = SimpleNamespace(
|
||||
text=None,
|
||||
caption=None,
|
||||
photo=[SimpleNamespace(file_id="reply_photo_fid", mime_type="image/jpeg")],
|
||||
document=None,
|
||||
voice=None,
|
||||
audio=None,
|
||||
video=None,
|
||||
video_note=None,
|
||||
animation=None,
|
||||
)
|
||||
update = _make_telegram_update(
|
||||
text="what is the image?",
|
||||
reply_to_message=reply_with_photo,
|
||||
)
|
||||
await channel._on_message(update, None)
|
||||
|
||||
assert len(handled) == 1
|
||||
assert handled[0]["content"].startswith("[Reply to: [image:")
|
||||
assert "what is the image?" in handled[0]["content"]
|
||||
assert len(handled[0]["media"]) == 1
|
||||
assert "reply_photo_fid" in handled[0]["media"][0]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_message_reply_to_media_fallback_when_download_fails() -> None:
|
||||
"""When reply has media but download fails, no media attached and no reply tag."""
|
||||
channel = TelegramChannel(
|
||||
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"], group_policy="open"),
|
||||
MessageBus(),
|
||||
)
|
||||
channel._app = _FakeApp(lambda: None)
|
||||
channel._app.bot.get_file = None
|
||||
handled = []
|
||||
async def capture_handle(**kwargs) -> None:
|
||||
handled.append(kwargs)
|
||||
channel._handle_message = capture_handle
|
||||
channel._start_typing = lambda _chat_id: None
|
||||
|
||||
reply_with_photo = SimpleNamespace(
|
||||
text=None,
|
||||
caption=None,
|
||||
photo=[SimpleNamespace(file_id="x", mime_type="image/jpeg")],
|
||||
document=None,
|
||||
voice=None,
|
||||
audio=None,
|
||||
video=None,
|
||||
video_note=None,
|
||||
animation=None,
|
||||
)
|
||||
update = _make_telegram_update(text="what is this?", reply_to_message=reply_with_photo)
|
||||
await channel._on_message(update, None)
|
||||
|
||||
assert len(handled) == 1
|
||||
assert "what is this?" in handled[0]["content"]
|
||||
assert handled[0]["media"] == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_message_reply_to_caption_and_media(monkeypatch, tmp_path) -> None:
|
||||
"""When replying to a message with caption + photo, both text context and media are included."""
|
||||
media_dir = tmp_path / "media" / "telegram"
|
||||
media_dir.mkdir(parents=True)
|
||||
monkeypatch.setattr(
|
||||
"nanobot.channels.telegram.get_media_dir",
|
||||
lambda channel=None: media_dir if channel else tmp_path / "media",
|
||||
)
|
||||
|
||||
channel = TelegramChannel(
|
||||
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"], group_policy="open"),
|
||||
MessageBus(),
|
||||
)
|
||||
app = _FakeApp(lambda: None)
|
||||
app.bot.get_file = AsyncMock(
|
||||
return_value=SimpleNamespace(download_to_drive=AsyncMock(return_value=None))
|
||||
)
|
||||
channel._app = app
|
||||
handled = []
|
||||
async def capture_handle(**kwargs) -> None:
|
||||
handled.append(kwargs)
|
||||
channel._handle_message = capture_handle
|
||||
channel._start_typing = lambda _chat_id: None
|
||||
|
||||
reply_with_caption_and_photo = SimpleNamespace(
|
||||
text=None,
|
||||
caption="A cute cat",
|
||||
photo=[SimpleNamespace(file_id="cat_fid", mime_type="image/jpeg")],
|
||||
document=None,
|
||||
voice=None,
|
||||
audio=None,
|
||||
video=None,
|
||||
video_note=None,
|
||||
animation=None,
|
||||
)
|
||||
update = _make_telegram_update(
|
||||
text="what breed is this?",
|
||||
reply_to_message=reply_with_caption_and_photo,
|
||||
)
|
||||
await channel._on_message(update, None)
|
||||
|
||||
assert len(handled) == 1
|
||||
assert "[Reply to: A cute cat]" in handled[0]["content"]
|
||||
assert "what breed is this?" in handled[0]["content"]
|
||||
assert len(handled[0]["media"]) == 1
|
||||
assert "cat_fid" in handled[0]["media"][0]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_forward_command_does_not_inject_reply_context() -> None:
|
||||
"""Slash commands forwarded via _forward_command must not include reply context."""
|
||||
channel = TelegramChannel(
|
||||
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"], group_policy="open"),
|
||||
MessageBus(),
|
||||
)
|
||||
channel._app = _FakeApp(lambda: None)
|
||||
handled = []
|
||||
async def capture_handle(**kwargs) -> None:
|
||||
handled.append(kwargs)
|
||||
channel._handle_message = capture_handle
|
||||
|
||||
reply = SimpleNamespace(text="some old message", message_id=2, from_user=SimpleNamespace(id=1))
|
||||
update = _make_telegram_update(text="/new", reply_to_message=reply)
|
||||
await channel._forward_command(update, None)
|
||||
|
||||
assert len(handled) == 1
|
||||
assert handled[0]["content"] == "/new"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_help_includes_restart_command() -> None:
|
||||
channel = TelegramChannel(
|
||||
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"], group_policy="open"),
|
||||
MessageBus(),
|
||||
)
|
||||
update = _make_telegram_update(text="/help", chat_type="private")
|
||||
update.message.reply_text = AsyncMock()
|
||||
|
||||
await channel._on_help(update, None)
|
||||
|
||||
update.message.reply_text.assert_awaited_once()
|
||||
help_text = update.message.reply_text.await_args.args[0]
|
||||
assert "/restart" in help_text
|
||||
assert "/status" in help_text
|
||||
@@ -0,0 +1,280 @@
|
||||
import asyncio
|
||||
import json
|
||||
import tempfile
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.channels.weixin import (
|
||||
ITEM_IMAGE,
|
||||
ITEM_TEXT,
|
||||
MESSAGE_TYPE_BOT,
|
||||
WEIXIN_CHANNEL_VERSION,
|
||||
WeixinChannel,
|
||||
WeixinConfig,
|
||||
)
|
||||
|
||||
|
||||
def _make_channel() -> tuple[WeixinChannel, MessageBus]:
|
||||
bus = MessageBus()
|
||||
channel = WeixinChannel(
|
||||
WeixinConfig(
|
||||
enabled=True,
|
||||
allow_from=["*"],
|
||||
state_dir=tempfile.mkdtemp(prefix="nanobot-weixin-test-"),
|
||||
),
|
||||
bus,
|
||||
)
|
||||
return channel, bus
|
||||
|
||||
|
||||
def test_make_headers_includes_route_tag_when_configured() -> None:
|
||||
bus = MessageBus()
|
||||
channel = WeixinChannel(
|
||||
WeixinConfig(enabled=True, allow_from=["*"], route_tag=123),
|
||||
bus,
|
||||
)
|
||||
channel._token = "token"
|
||||
|
||||
headers = channel._make_headers()
|
||||
|
||||
assert headers["Authorization"] == "Bearer token"
|
||||
assert headers["SKRouteTag"] == "123"
|
||||
|
||||
|
||||
def test_channel_version_matches_reference_plugin_version() -> None:
|
||||
assert WEIXIN_CHANNEL_VERSION == "1.0.3"
|
||||
|
||||
|
||||
def test_save_and_load_state_persists_context_tokens(tmp_path) -> None:
|
||||
bus = MessageBus()
|
||||
channel = WeixinChannel(
|
||||
WeixinConfig(enabled=True, allow_from=["*"], state_dir=str(tmp_path)),
|
||||
bus,
|
||||
)
|
||||
channel._token = "token"
|
||||
channel._get_updates_buf = "cursor"
|
||||
channel._context_tokens = {"wx-user": "ctx-1"}
|
||||
|
||||
channel._save_state()
|
||||
|
||||
saved = json.loads((tmp_path / "account.json").read_text())
|
||||
assert saved["context_tokens"] == {"wx-user": "ctx-1"}
|
||||
|
||||
restored = WeixinChannel(
|
||||
WeixinConfig(enabled=True, allow_from=["*"], state_dir=str(tmp_path)),
|
||||
bus,
|
||||
)
|
||||
|
||||
assert restored._load_state() is True
|
||||
assert restored._context_tokens == {"wx-user": "ctx-1"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_message_deduplicates_inbound_ids() -> None:
|
||||
channel, bus = _make_channel()
|
||||
msg = {
|
||||
"message_type": 1,
|
||||
"message_id": "m1",
|
||||
"from_user_id": "wx-user",
|
||||
"context_token": "ctx-1",
|
||||
"item_list": [
|
||||
{"type": ITEM_TEXT, "text_item": {"text": "hello"}},
|
||||
],
|
||||
}
|
||||
|
||||
await channel._process_message(msg)
|
||||
first = await asyncio.wait_for(bus.consume_inbound(), timeout=1.0)
|
||||
await channel._process_message(msg)
|
||||
|
||||
assert first.sender_id == "wx-user"
|
||||
assert first.chat_id == "wx-user"
|
||||
assert first.content == "hello"
|
||||
assert bus.inbound_size == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_message_caches_context_token_and_send_uses_it() -> None:
|
||||
channel, _bus = _make_channel()
|
||||
channel._client = object()
|
||||
channel._token = "token"
|
||||
channel._send_text = AsyncMock()
|
||||
|
||||
await channel._process_message(
|
||||
{
|
||||
"message_type": 1,
|
||||
"message_id": "m2",
|
||||
"from_user_id": "wx-user",
|
||||
"context_token": "ctx-2",
|
||||
"item_list": [
|
||||
{"type": ITEM_TEXT, "text_item": {"text": "ping"}},
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
await channel.send(
|
||||
type("Msg", (), {"chat_id": "wx-user", "content": "pong", "media": [], "metadata": {}})()
|
||||
)
|
||||
|
||||
channel._send_text.assert_awaited_once_with("wx-user", "pong", "ctx-2")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_message_persists_context_token_to_state_file(tmp_path) -> None:
|
||||
bus = MessageBus()
|
||||
channel = WeixinChannel(
|
||||
WeixinConfig(enabled=True, allow_from=["*"], state_dir=str(tmp_path)),
|
||||
bus,
|
||||
)
|
||||
|
||||
await channel._process_message(
|
||||
{
|
||||
"message_type": 1,
|
||||
"message_id": "m2b",
|
||||
"from_user_id": "wx-user",
|
||||
"context_token": "ctx-2b",
|
||||
"item_list": [
|
||||
{"type": ITEM_TEXT, "text_item": {"text": "ping"}},
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
saved = json.loads((tmp_path / "account.json").read_text())
|
||||
assert saved["context_tokens"] == {"wx-user": "ctx-2b"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_message_extracts_media_and_preserves_paths() -> None:
|
||||
channel, bus = _make_channel()
|
||||
channel._download_media_item = AsyncMock(return_value="/tmp/test.jpg")
|
||||
|
||||
await channel._process_message(
|
||||
{
|
||||
"message_type": 1,
|
||||
"message_id": "m3",
|
||||
"from_user_id": "wx-user",
|
||||
"context_token": "ctx-3",
|
||||
"item_list": [
|
||||
{"type": ITEM_IMAGE, "image_item": {"media": {"encrypt_query_param": "x"}}},
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
inbound = await asyncio.wait_for(bus.consume_inbound(), timeout=1.0)
|
||||
|
||||
assert "[image]" in inbound.content
|
||||
assert "/tmp/test.jpg" in inbound.content
|
||||
assert inbound.media == ["/tmp/test.jpg"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_without_context_token_does_not_send_text() -> None:
|
||||
channel, _bus = _make_channel()
|
||||
channel._client = object()
|
||||
channel._token = "token"
|
||||
channel._send_text = AsyncMock()
|
||||
|
||||
await channel.send(
|
||||
type("Msg", (), {"chat_id": "unknown-user", "content": "pong", "media": [], "metadata": {}})()
|
||||
)
|
||||
|
||||
channel._send_text.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_does_not_send_when_session_is_paused() -> None:
|
||||
channel, _bus = _make_channel()
|
||||
channel._client = object()
|
||||
channel._token = "token"
|
||||
channel._context_tokens["wx-user"] = "ctx-2"
|
||||
channel._pause_session(60)
|
||||
channel._send_text = AsyncMock()
|
||||
|
||||
await channel.send(
|
||||
type("Msg", (), {"chat_id": "wx-user", "content": "pong", "media": [], "metadata": {}})()
|
||||
)
|
||||
|
||||
channel._send_text.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_poll_once_pauses_session_on_expired_errcode() -> None:
|
||||
channel, _bus = _make_channel()
|
||||
channel._client = SimpleNamespace(timeout=None)
|
||||
channel._token = "token"
|
||||
channel._api_post = AsyncMock(return_value={"ret": 0, "errcode": -14, "errmsg": "expired"})
|
||||
|
||||
await channel._poll_once()
|
||||
|
||||
assert channel._session_pause_remaining_s() > 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_qr_login_refreshes_expired_qr_and_then_succeeds() -> None:
|
||||
channel, _bus = _make_channel()
|
||||
channel._running = True
|
||||
channel._save_state = lambda: None
|
||||
channel._print_qr_code = lambda url: None
|
||||
channel._api_get = AsyncMock(
|
||||
side_effect=[
|
||||
{"qrcode": "qr-1", "qrcode_img_content": "url-1"},
|
||||
{"status": "expired"},
|
||||
{"qrcode": "qr-2", "qrcode_img_content": "url-2"},
|
||||
{
|
||||
"status": "confirmed",
|
||||
"bot_token": "token-2",
|
||||
"ilink_bot_id": "bot-2",
|
||||
"baseurl": "https://example.test",
|
||||
"ilink_user_id": "wx-user",
|
||||
},
|
||||
]
|
||||
)
|
||||
|
||||
ok = await channel._qr_login()
|
||||
|
||||
assert ok is True
|
||||
assert channel._token == "token-2"
|
||||
assert channel.config.base_url == "https://example.test"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_qr_login_returns_false_after_too_many_expired_qr_codes() -> None:
|
||||
channel, _bus = _make_channel()
|
||||
channel._running = True
|
||||
channel._print_qr_code = lambda url: None
|
||||
channel._api_get = AsyncMock(
|
||||
side_effect=[
|
||||
{"qrcode": "qr-1", "qrcode_img_content": "url-1"},
|
||||
{"status": "expired"},
|
||||
{"qrcode": "qr-2", "qrcode_img_content": "url-2"},
|
||||
{"status": "expired"},
|
||||
{"qrcode": "qr-3", "qrcode_img_content": "url-3"},
|
||||
{"status": "expired"},
|
||||
{"qrcode": "qr-4", "qrcode_img_content": "url-4"},
|
||||
{"status": "expired"},
|
||||
]
|
||||
)
|
||||
|
||||
ok = await channel._qr_login()
|
||||
|
||||
assert ok is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_message_skips_bot_messages() -> None:
|
||||
channel, bus = _make_channel()
|
||||
|
||||
await channel._process_message(
|
||||
{
|
||||
"message_type": MESSAGE_TYPE_BOT,
|
||||
"message_id": "m4",
|
||||
"from_user_id": "wx-user",
|
||||
"item_list": [
|
||||
{"type": ITEM_TEXT, "text_item": {"text": "hello"}},
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
assert bus.inbound_size == 0
|
||||
@@ -0,0 +1,157 @@
|
||||
"""Tests for WhatsApp channel outbound media support."""
|
||||
|
||||
import json
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.channels.whatsapp import WhatsAppChannel
|
||||
|
||||
|
||||
def _make_channel() -> WhatsAppChannel:
|
||||
bus = MagicMock()
|
||||
ch = WhatsAppChannel({"enabled": True}, bus)
|
||||
ch._ws = AsyncMock()
|
||||
ch._connected = True
|
||||
return ch
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_text_only():
|
||||
ch = _make_channel()
|
||||
msg = OutboundMessage(channel="whatsapp", chat_id="123@s.whatsapp.net", content="hello")
|
||||
|
||||
await ch.send(msg)
|
||||
|
||||
ch._ws.send.assert_called_once()
|
||||
payload = json.loads(ch._ws.send.call_args[0][0])
|
||||
assert payload["type"] == "send"
|
||||
assert payload["text"] == "hello"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_media_dispatches_send_media_command():
|
||||
ch = _make_channel()
|
||||
msg = OutboundMessage(
|
||||
channel="whatsapp",
|
||||
chat_id="123@s.whatsapp.net",
|
||||
content="check this out",
|
||||
media=["/tmp/photo.jpg"],
|
||||
)
|
||||
|
||||
await ch.send(msg)
|
||||
|
||||
assert ch._ws.send.call_count == 2
|
||||
text_payload = json.loads(ch._ws.send.call_args_list[0][0][0])
|
||||
media_payload = json.loads(ch._ws.send.call_args_list[1][0][0])
|
||||
|
||||
assert text_payload["type"] == "send"
|
||||
assert text_payload["text"] == "check this out"
|
||||
|
||||
assert media_payload["type"] == "send_media"
|
||||
assert media_payload["filePath"] == "/tmp/photo.jpg"
|
||||
assert media_payload["mimetype"] == "image/jpeg"
|
||||
assert media_payload["fileName"] == "photo.jpg"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_media_only_no_text():
|
||||
ch = _make_channel()
|
||||
msg = OutboundMessage(
|
||||
channel="whatsapp",
|
||||
chat_id="123@s.whatsapp.net",
|
||||
content="",
|
||||
media=["/tmp/doc.pdf"],
|
||||
)
|
||||
|
||||
await ch.send(msg)
|
||||
|
||||
ch._ws.send.assert_called_once()
|
||||
payload = json.loads(ch._ws.send.call_args[0][0])
|
||||
assert payload["type"] == "send_media"
|
||||
assert payload["mimetype"] == "application/pdf"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_multiple_media():
|
||||
ch = _make_channel()
|
||||
msg = OutboundMessage(
|
||||
channel="whatsapp",
|
||||
chat_id="123@s.whatsapp.net",
|
||||
content="",
|
||||
media=["/tmp/a.png", "/tmp/b.mp4"],
|
||||
)
|
||||
|
||||
await ch.send(msg)
|
||||
|
||||
assert ch._ws.send.call_count == 2
|
||||
p1 = json.loads(ch._ws.send.call_args_list[0][0][0])
|
||||
p2 = json.loads(ch._ws.send.call_args_list[1][0][0])
|
||||
assert p1["mimetype"] == "image/png"
|
||||
assert p2["mimetype"] == "video/mp4"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_when_disconnected_is_noop():
|
||||
ch = _make_channel()
|
||||
ch._connected = False
|
||||
|
||||
msg = OutboundMessage(
|
||||
channel="whatsapp",
|
||||
chat_id="123@s.whatsapp.net",
|
||||
content="hello",
|
||||
media=["/tmp/x.jpg"],
|
||||
)
|
||||
await ch.send(msg)
|
||||
|
||||
ch._ws.send.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_group_policy_mention_skips_unmentioned_group_message():
|
||||
ch = WhatsAppChannel({"enabled": True, "groupPolicy": "mention"}, MagicMock())
|
||||
ch._handle_message = AsyncMock()
|
||||
|
||||
await ch._handle_bridge_message(
|
||||
json.dumps(
|
||||
{
|
||||
"type": "message",
|
||||
"id": "m1",
|
||||
"sender": "12345@g.us",
|
||||
"pn": "user@s.whatsapp.net",
|
||||
"content": "hello group",
|
||||
"timestamp": 1,
|
||||
"isGroup": True,
|
||||
"wasMentioned": False,
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
ch._handle_message.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_group_policy_mention_accepts_mentioned_group_message():
|
||||
ch = WhatsAppChannel({"enabled": True, "groupPolicy": "mention"}, MagicMock())
|
||||
ch._handle_message = AsyncMock()
|
||||
|
||||
await ch._handle_bridge_message(
|
||||
json.dumps(
|
||||
{
|
||||
"type": "message",
|
||||
"id": "m1",
|
||||
"sender": "12345@g.us",
|
||||
"pn": "user@s.whatsapp.net",
|
||||
"content": "hello @bot",
|
||||
"timestamp": 1,
|
||||
"isGroup": True,
|
||||
"wasMentioned": True,
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
ch._handle_message.assert_awaited_once()
|
||||
kwargs = ch._handle_message.await_args.kwargs
|
||||
assert kwargs["chat_id"] == "12345@g.us"
|
||||
assert kwargs["sender_id"] == "user"
|
||||
Reference in New Issue
Block a user