Update 2026-05-13 16:43:53
This commit is contained in:
@@ -0,0 +1,264 @@
|
||||
"""Channel manager for coordinating chat channels."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import Any
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.channels.base import BaseChannel
|
||||
from nanobot.config.schema import Config
|
||||
|
||||
# Retry delays for message sending (exponential backoff: 1s, 2s, 4s)
|
||||
_SEND_RETRY_DELAYS = (1, 2, 4)
|
||||
|
||||
|
||||
class ChannelManager:
|
||||
"""
|
||||
Manages chat channels and coordinates message routing.
|
||||
|
||||
Responsibilities:
|
||||
- Initialize enabled channels (Telegram, WhatsApp, etc.)
|
||||
- Start/stop channels
|
||||
- Route outbound messages
|
||||
"""
|
||||
|
||||
def __init__(self, config: Config, bus: MessageBus):
|
||||
self.config = config
|
||||
self.bus = bus
|
||||
self.channels: dict[str, BaseChannel] = {}
|
||||
self._dispatch_task: asyncio.Task | None = None
|
||||
|
||||
self._init_channels()
|
||||
|
||||
def _init_channels(self) -> None:
|
||||
"""Initialize channels discovered via pkgutil scan + entry_points plugins."""
|
||||
from nanobot.channels.registry import discover_all
|
||||
|
||||
groq_key = self.config.providers.groq.api_key
|
||||
|
||||
for name, cls in discover_all().items():
|
||||
section = getattr(self.config.channels, name, None)
|
||||
if section is None:
|
||||
continue
|
||||
enabled = (
|
||||
section.get("enabled", False)
|
||||
if isinstance(section, dict)
|
||||
else getattr(section, "enabled", False)
|
||||
)
|
||||
if not enabled:
|
||||
continue
|
||||
try:
|
||||
channel = cls(section, self.bus)
|
||||
channel.transcription_api_key = groq_key
|
||||
self.channels[name] = channel
|
||||
logger.info("{} channel enabled", cls.display_name)
|
||||
except Exception as e:
|
||||
logger.warning("{} channel not available: {}", name, e)
|
||||
|
||||
self._validate_allow_from()
|
||||
|
||||
def _validate_allow_from(self) -> None:
|
||||
for name, ch in self.channels.items():
|
||||
if getattr(ch.config, "allow_from", None) == []:
|
||||
raise SystemExit(
|
||||
f'Error: "{name}" has empty allowFrom (denies all). '
|
||||
f'Set ["*"] to allow everyone, or add specific user IDs.'
|
||||
)
|
||||
|
||||
async def _start_channel(self, name: str, channel: BaseChannel) -> None:
|
||||
"""Start a channel and log any exceptions."""
|
||||
try:
|
||||
await channel.start()
|
||||
except Exception as e:
|
||||
logger.error("Failed to start channel {}: {}", name, e)
|
||||
|
||||
async def start_all(self) -> None:
|
||||
"""Start all channels and the outbound dispatcher."""
|
||||
if not self.channels:
|
||||
logger.warning("No channels enabled")
|
||||
return
|
||||
|
||||
# Start outbound dispatcher
|
||||
self._dispatch_task = asyncio.create_task(self._dispatch_outbound())
|
||||
|
||||
# Start channels
|
||||
tasks = []
|
||||
for name, channel in self.channels.items():
|
||||
logger.info("Starting {} channel...", name)
|
||||
tasks.append(asyncio.create_task(self._start_channel(name, channel)))
|
||||
|
||||
# Wait for all to complete (they should run forever)
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
async def stop_all(self) -> None:
|
||||
"""Stop all channels and the dispatcher."""
|
||||
logger.info("Stopping all channels...")
|
||||
|
||||
# Stop dispatcher
|
||||
if self._dispatch_task:
|
||||
self._dispatch_task.cancel()
|
||||
try:
|
||||
await self._dispatch_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
# Stop all channels
|
||||
for name, channel in self.channels.items():
|
||||
try:
|
||||
await channel.stop()
|
||||
logger.info("Stopped {} channel", name)
|
||||
except Exception as e:
|
||||
logger.error("Error stopping {}: {}", name, e)
|
||||
|
||||
async def _dispatch_outbound(self) -> None:
|
||||
"""Dispatch outbound messages to the appropriate channel."""
|
||||
logger.info("Outbound dispatcher started")
|
||||
|
||||
# Buffer for messages that couldn't be processed during delta coalescing
|
||||
# (since asyncio.Queue doesn't support push_front)
|
||||
pending: list[OutboundMessage] = []
|
||||
|
||||
while True:
|
||||
try:
|
||||
# First check pending buffer before waiting on queue
|
||||
if pending:
|
||||
msg = pending.pop(0)
|
||||
else:
|
||||
msg = await asyncio.wait_for(
|
||||
self.bus.consume_outbound(),
|
||||
timeout=1.0
|
||||
)
|
||||
|
||||
if msg.metadata.get("_progress"):
|
||||
if msg.metadata.get("_tool_hint") and not self.config.channels.send_tool_hints:
|
||||
continue
|
||||
if not msg.metadata.get("_tool_hint") and not self.config.channels.send_progress:
|
||||
continue
|
||||
|
||||
# Coalesce consecutive _stream_delta messages for the same (channel, chat_id)
|
||||
# to reduce API calls and improve streaming latency
|
||||
if msg.metadata.get("_stream_delta") and not msg.metadata.get("_stream_end"):
|
||||
msg, extra_pending = self._coalesce_stream_deltas(msg)
|
||||
pending.extend(extra_pending)
|
||||
|
||||
channel = self.channels.get(msg.channel)
|
||||
if channel:
|
||||
await self._send_with_retry(channel, msg)
|
||||
else:
|
||||
logger.warning("Unknown channel: {}", msg.channel)
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
continue
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
|
||||
@staticmethod
|
||||
async def _send_once(channel: BaseChannel, msg: OutboundMessage) -> None:
|
||||
"""Send one outbound message without retry policy."""
|
||||
if msg.metadata.get("_stream_delta") or msg.metadata.get("_stream_end"):
|
||||
await channel.send_delta(msg.chat_id, msg.content, msg.metadata)
|
||||
elif not msg.metadata.get("_streamed"):
|
||||
await channel.send(msg)
|
||||
|
||||
def _coalesce_stream_deltas(
|
||||
self, first_msg: OutboundMessage
|
||||
) -> tuple[OutboundMessage, list[OutboundMessage]]:
|
||||
"""Merge consecutive _stream_delta messages for the same (channel, chat_id).
|
||||
|
||||
This reduces the number of API calls when the queue has accumulated multiple
|
||||
deltas, which happens when LLM generates faster than the channel can process.
|
||||
|
||||
Returns:
|
||||
tuple of (merged_message, list_of_non_matching_messages)
|
||||
"""
|
||||
target_key = (first_msg.channel, first_msg.chat_id)
|
||||
combined_content = first_msg.content
|
||||
final_metadata = dict(first_msg.metadata or {})
|
||||
non_matching: list[OutboundMessage] = []
|
||||
|
||||
# Only merge consecutive deltas. As soon as we hit any other message,
|
||||
# stop and hand that boundary back to the dispatcher via `pending`.
|
||||
while True:
|
||||
try:
|
||||
next_msg = self.bus.outbound.get_nowait()
|
||||
except asyncio.QueueEmpty:
|
||||
break
|
||||
|
||||
# Check if this message belongs to the same stream
|
||||
same_target = (next_msg.channel, next_msg.chat_id) == target_key
|
||||
is_delta = next_msg.metadata and next_msg.metadata.get("_stream_delta")
|
||||
is_end = next_msg.metadata and next_msg.metadata.get("_stream_end")
|
||||
|
||||
if same_target and is_delta and not final_metadata.get("_stream_end"):
|
||||
# Accumulate content
|
||||
combined_content += next_msg.content
|
||||
# If we see _stream_end, remember it and stop coalescing this stream
|
||||
if is_end:
|
||||
final_metadata["_stream_end"] = True
|
||||
# Stream ended - stop coalescing this stream
|
||||
break
|
||||
else:
|
||||
# First non-matching message defines the coalescing boundary.
|
||||
non_matching.append(next_msg)
|
||||
break
|
||||
|
||||
merged = OutboundMessage(
|
||||
channel=first_msg.channel,
|
||||
chat_id=first_msg.chat_id,
|
||||
content=combined_content,
|
||||
metadata=final_metadata,
|
||||
)
|
||||
return merged, non_matching
|
||||
|
||||
async def _send_with_retry(self, channel: BaseChannel, msg: OutboundMessage) -> None:
|
||||
"""Send a message with retry on failure using exponential backoff.
|
||||
|
||||
Note: CancelledError is re-raised to allow graceful shutdown.
|
||||
"""
|
||||
max_attempts = max(self.config.channels.send_max_retries, 1)
|
||||
|
||||
for attempt in range(max_attempts):
|
||||
try:
|
||||
await self._send_once(channel, msg)
|
||||
return # Send succeeded
|
||||
except asyncio.CancelledError:
|
||||
raise # Propagate cancellation for graceful shutdown
|
||||
except Exception as e:
|
||||
if attempt == max_attempts - 1:
|
||||
logger.error(
|
||||
"Failed to send to {} after {} attempts: {} - {}",
|
||||
msg.channel, max_attempts, type(e).__name__, e
|
||||
)
|
||||
return
|
||||
delay = _SEND_RETRY_DELAYS[min(attempt, len(_SEND_RETRY_DELAYS) - 1)]
|
||||
logger.warning(
|
||||
"Send to {} failed (attempt {}/{}): {}, retrying in {}s",
|
||||
msg.channel, attempt + 1, max_attempts, type(e).__name__, delay
|
||||
)
|
||||
try:
|
||||
await asyncio.sleep(delay)
|
||||
except asyncio.CancelledError:
|
||||
raise # Propagate cancellation during sleep
|
||||
|
||||
def get_channel(self, name: str) -> BaseChannel | None:
|
||||
"""Get a channel by name."""
|
||||
return self.channels.get(name)
|
||||
|
||||
def get_status(self) -> dict[str, Any]:
|
||||
"""Get status of all channels."""
|
||||
return {
|
||||
name: {
|
||||
"enabled": True,
|
||||
"running": channel.is_running
|
||||
}
|
||||
for name, channel in self.channels.items()
|
||||
}
|
||||
|
||||
@property
|
||||
def enabled_channels(self) -> list[str]:
|
||||
"""Get list of enabled channel names."""
|
||||
return list(self.channels.keys())
|
||||
Reference in New Issue
Block a user