From 3f16c4a20e8ce555c4bf99c144d26494e84cd02e Mon Sep 17 00:00:00 2001 From: ekko <152005280+EKKOLearnAI@users.noreply.github.com> Date: Fri, 22 May 2026 15:52:00 +0800 Subject: [PATCH] Allow bridge sessions to run concurrently (#932) * Allow bridge sessions to run concurrently * Stabilize bridge concurrency test * Set bridge approval timeout to 120 seconds * harden bridge approval concurrency --------- Co-authored-by: Codex --- .../hermes/agent-bridge/hermes_bridge.py | 250 ++++++----- .../agent-bridge-python-concurrency.test.ts | 398 ++++++++++++++++++ 2 files changed, 547 insertions(+), 101 deletions(-) create mode 100644 tests/server/agent-bridge-python-concurrency.test.ts diff --git a/packages/server/src/services/hermes/agent-bridge/hermes_bridge.py b/packages/server/src/services/hermes/agent-bridge/hermes_bridge.py index 63a38c6..87500a9 100755 --- a/packages/server/src/services/hermes/agent-bridge/hermes_bridge.py +++ b/packages/server/src/services/hermes/agent-bridge/hermes_bridge.py @@ -30,12 +30,14 @@ from contextlib import contextmanager from dataclasses import dataclass, field from pathlib import Path from urllib.parse import urlparse -from typing import Any +from typing import Any, Callable DEFAULT_ENDPOINT = "tcp://127.0.0.1:18765" if os.name == "nt" else "ipc:///tmp/hermes-agent-bridge.sock" DEFAULT_AGENT_ROOT = "~/.hermes/hermes-agent" DEFAULT_HERMES_HOME = "~/.hermes" +APPROVAL_TIMEOUT_SECONDS = 120 +APPROVAL_TIMEOUT_MS = APPROVAL_TIMEOUT_SECONDS * 1000 def _bridge_platform() -> str: @@ -501,11 +503,14 @@ class AgentPool: self._sessions: dict[str, AgentSession] = {} self._runs: dict[str, RunRecord] = {} self._lock = threading.RLock() - self._run_lock = threading.Lock() self._db = SessionDbHolder() self._approval_requests: dict[str, queue.Queue[str]] = {} self._gateway_approval_requests: dict[str, str] = {} self._compression_requests: dict[str, queue.Queue[dict[str, Any]]] = {} + self._run_context = threading.local() + self._approval_handlers: dict[str, Callable[..., str]] = {} + self._exec_ask_depth = 0 + self._exec_ask_previous: str | None = None def get_or_create( self, @@ -927,10 +932,10 @@ class AgentPool: "description": str(description or ""), "choices": choices, "allow_permanent": bool(allow_permanent), - "timeout_ms": 60_000, + "timeout_ms": APPROVAL_TIMEOUT_MS, }) try: - choice = response_queue.get(timeout=60) + choice = response_queue.get(timeout=APPROVAL_TIMEOUT_SECONDS) except queue.Empty: choice = "deny" finally: @@ -945,6 +950,44 @@ class AgentPool: return callback + def _approval_dispatcher(self, command: str, description: str, *, allow_permanent: bool = True) -> str: + session_id = str(getattr(self._run_context, "session_id", "") or "") + if not session_id: + return "deny" + with self._lock: + handler = self._approval_handlers.get(session_id) + if handler is None: + return "deny" + return handler(command, description, allow_permanent=allow_permanent) + + def _install_approval_dispatcher_for_current_thread(self) -> None: + from tools.terminal_tool import set_approval_callback + + # terminal_tool stores callbacks in threading.local(), so each run + # thread must bind the shared dispatcher for itself. + set_approval_callback(self._approval_dispatcher) + + def _enter_exec_ask_scope(self) -> None: + with self._lock: + if self._exec_ask_depth == 0: + self._exec_ask_previous = os.environ.get("HERMES_EXEC_ASK") + os.environ["HERMES_EXEC_ASK"] = "1" + self._exec_ask_depth += 1 + + def _exit_exec_ask_scope(self) -> None: + with self._lock: + if self._exec_ask_depth <= 0: + return + self._exec_ask_depth -= 1 + if self._exec_ask_depth > 0: + return + previous = self._exec_ask_previous + self._exec_ask_previous = None + if previous is None: + os.environ.pop("HERMES_EXEC_ASK", None) + else: + os.environ["HERMES_EXEC_ASK"] = previous + def _gateway_approval_notify(self, session_id: str): def callback(approval_data: dict[str, Any]) -> None: approval_id = uuid.uuid4().hex @@ -1124,102 +1167,103 @@ class AgentPool: return record def _run_chat(self, session: AgentSession, record: RunRecord, message: Any, storage_message: Any | None = None, instructions: str | None = None, conversation_history: list[dict[str, Any]] | None = None, profile: str | None = None, force_compress: bool = False, source: str | None = None) -> None: - with self._run_lock: - with _profile_env(profile): - def stream_callback(delta: str) -> None: - with self._lock: - record.deltas.append(str(delta)) + with _profile_env(profile): + def stream_callback(delta: str) -> None: + with self._lock: + record.deltas.append(str(delta)) + approval_session_token = None + registered_gateway_approval_session = None + exec_ask_scope_entered = False + try: try: - previous_approval_callback = None - previous_exec_ask = os.environ.get("HERMES_EXEC_ASK") - approval_session_token = None - registered_gateway_approval_session = None - try: - from tools.terminal_tool import _get_approval_callback, set_approval_callback - from tools.approval import register_gateway_notify, set_current_session_key + self._enter_exec_ask_scope() + exec_ask_scope_entered = True + self._install_approval_dispatcher_for_current_thread() + with self._lock: + self._approval_handlers[session.session_id] = self._approval_callback(session.session_id) + self._run_context.session_id = session.session_id + except Exception: + self._run_context.session_id = session.session_id + try: + from tools.approval import register_gateway_notify, set_current_session_key - previous_approval_callback = _get_approval_callback() - set_approval_callback(self._approval_callback(session.session_id)) - approval_session_token = set_current_session_key(session.session_id) - register_gateway_notify(session.session_id, self._gateway_approval_notify(session.session_id)) - registered_gateway_approval_session = session.session_id - os.environ["HERMES_EXEC_ASK"] = "1" - except Exception: - previous_approval_callback = None - self._prepersist_user_message(session, message, storage_message, conversation_history, profile, source) - db_count_after_prepersist = self._session_db_message_count(session.session_id, profile) - if force_compress: - compress = getattr(session.agent, "_compress_context", None) - if callable(compress): - compressed_history, compressed_system = compress( - conversation_history if isinstance(conversation_history, list) else [], - instructions, - approx_tokens=None, - focus_topic="debug_force_compress", - ) - if isinstance(compressed_history, list): - conversation_history = compressed_history - if isinstance(compressed_system, str): - instructions = compressed_system - kwargs: dict[str, Any] = dict( - task_id=session.session_id, - stream_callback=stream_callback, - ) - if instructions: - kwargs["system_message"] = instructions - if conversation_history is not None: - kwargs["conversation_history"] = conversation_history - result = session.agent.run_conversation( - message, - **kwargs, - ) - result = _jsonable(result if isinstance(result, dict) else {"value": result}) - self._sync_result_tail_to_session_db( - session, - result, - conversation_history, - profile, - db_count_after_prepersist, - ) - with session.lock: - if isinstance(result.get("messages"), list): - session.history = result["messages"] - record.status = "interrupted" if result.get("interrupted") else "complete" - record.result = result - record.ended_at = time.time() - session.running = False - session.current_run_id = None - session.last_used_at = time.time() - except Exception as exc: - with session.lock: - record.status = "error" - record.error = str(exc) - record.result = {"error": str(exc), "traceback": traceback.format_exc()} - record.ended_at = time.time() - session.running = False - session.current_run_id = None - session.last_used_at = time.time() - finally: + approval_session_token = set_current_session_key(session.session_id) + register_gateway_notify(session.session_id, self._gateway_approval_notify(session.session_id)) + registered_gateway_approval_session = session.session_id + except Exception: + pass + self._prepersist_user_message(session, message, storage_message, conversation_history, profile, source) + db_count_after_prepersist = self._session_db_message_count(session.session_id, profile) + if force_compress: + compress = getattr(session.agent, "_compress_context", None) + if callable(compress): + compressed_history, compressed_system = compress( + conversation_history if isinstance(conversation_history, list) else [], + instructions, + approx_tokens=None, + focus_topic="debug_force_compress", + ) + if isinstance(compressed_history, list): + conversation_history = compressed_history + if isinstance(compressed_system, str): + instructions = compressed_system + kwargs: dict[str, Any] = dict( + task_id=session.session_id, + stream_callback=stream_callback, + ) + if instructions: + kwargs["system_message"] = instructions + if conversation_history is not None: + kwargs["conversation_history"] = conversation_history + result = session.agent.run_conversation( + message, + **kwargs, + ) + result = _jsonable(result if isinstance(result, dict) else {"value": result}) + self._sync_result_tail_to_session_db( + session, + result, + conversation_history, + profile, + db_count_after_prepersist, + ) + with session.lock: + if isinstance(result.get("messages"), list): + session.history = result["messages"] + record.status = "interrupted" if result.get("interrupted") else "complete" + record.result = result + record.ended_at = time.time() + session.running = False + session.current_run_id = None + session.last_used_at = time.time() + except Exception as exc: + with session.lock: + record.status = "error" + record.error = str(exc) + record.result = {"error": str(exc), "traceback": traceback.format_exc()} + record.ended_at = time.time() + session.running = False + session.current_run_id = None + session.last_used_at = time.time() + finally: + with self._lock: + self._approval_handlers.pop(session.session_id, None) + try: + del self._run_context.session_id + except AttributeError: + pass + if approval_session_token is not None: try: - from tools.terminal_tool import set_approval_callback + from tools.approval import reset_current_session_key, unregister_gateway_notify - set_approval_callback(previous_approval_callback) + if registered_gateway_approval_session is not None: + unregister_gateway_notify(registered_gateway_approval_session) + reset_current_session_key(approval_session_token) except Exception: pass - if approval_session_token is not None: - try: - from tools.approval import reset_current_session_key, unregister_gateway_notify - - if registered_gateway_approval_session is not None: - unregister_gateway_notify(registered_gateway_approval_session) - reset_current_session_key(approval_session_token) - except Exception: - pass - if previous_exec_ask is None: - os.environ.pop("HERMES_EXEC_ASK", None) - else: - os.environ["HERMES_EXEC_ASK"] = previous_exec_ask + if exec_ask_scope_entered: + self._exit_exec_ask_scope() def interrupt(self, session_id: str, message: str | None = None) -> dict[str, Any]: with self._lock: @@ -2043,26 +2087,27 @@ class BridgeBroker: if action == "destroy_all": with self._lock: workers = list(self._workers.values()) + self._workers.clear() self._run_profile.clear() self._session_profile.clear() self._approval_profile.clear() self._compression_profile.clear() destroyed = 0 for worker in workers: - if not worker.running: - worker.stop() - continue try: - resp = worker.request({"action": "destroy_all"}) - destroyed += int(resp.get("destroyed") or 0) + if worker.running: + resp = worker.request({"action": "destroy_all"}) + destroyed += int(resp.get("destroyed") or 0) except Exception: pass + finally: + worker.stop() return {"destroyed": destroyed} if action == "destroy_profile": profile = self._normalize_profile(req.get("profile")) with self._lock: - worker = self._workers.get(profile) + worker = self._workers.pop(profile, None) self._run_profile = {key: value for key, value in self._run_profile.items() if value != profile} self._session_profile = {key: value for key, value in self._session_profile.items() if value != profile} self._approval_profile = {key: value for key, value in self._approval_profile.items() if value != profile} @@ -2075,9 +2120,12 @@ class BridgeBroker: try: resp = worker.request({"action": "destroy_all"}) - return {"profile": profile, "destroyed": int(resp.get("destroyed") or 0)} + destroyed = int(resp.get("destroyed") or 0) except Exception: - return {"profile": profile, "destroyed": 0} + destroyed = 0 + finally: + worker.stop() + return {"profile": profile, "destroyed": destroyed} if action == "list": sessions: list[Any] = [] diff --git a/tests/server/agent-bridge-python-concurrency.test.ts b/tests/server/agent-bridge-python-concurrency.test.ts new file mode 100644 index 0000000..d736d4d --- /dev/null +++ b/tests/server/agent-bridge-python-concurrency.test.ts @@ -0,0 +1,398 @@ +import { execFileSync } from 'child_process' +import { describe, it } from 'vitest' + +function runPython(script: string): void { + try { + execFileSync('python3', ['-c', script], { + cwd: process.cwd(), + encoding: 'utf-8', + stdio: 'pipe', + }) + } catch (error) { + const err = error as { stdout?: string; stderr?: string; message?: string } + throw new Error([ + err.message || 'Python bridge concurrency script failed', + err.stdout ? `stdout:\n${err.stdout}` : '', + err.stderr ? `stderr:\n${err.stderr}` : '', + ].filter(Boolean).join('\n\n')) + } +} + +const harness = String.raw` +import contextvars +import importlib.util +import os +import sys +import threading +import time +import types +from pathlib import Path + +os.environ["HERMES_AGENT_BRIDGE_WORKER_PROFILE"] = "default" + +tools_pkg = types.ModuleType("tools") +tools_pkg.__path__ = [] +sys.modules["tools"] = tools_pkg + +terminal_tool = types.ModuleType("tools.terminal_tool") +terminal_tool._callback_tls = threading.local() + +def set_approval_callback(callback): + terminal_tool._callback_tls.callback = callback + +def _get_approval_callback(): + return getattr(terminal_tool._callback_tls, "callback", None) + +terminal_tool.set_approval_callback = set_approval_callback +terminal_tool._get_approval_callback = _get_approval_callback +sys.modules["tools.terminal_tool"] = terminal_tool + +approval = types.ModuleType("tools.approval") +approval._session_key = contextvars.ContextVar("approval_session_key", default="") +approval._notify = {} +approval._resolved_gateway = [] + +def set_current_session_key(session_key): + return approval._session_key.set(session_key or "") + +def reset_current_session_key(token): + approval._session_key.reset(token) + +def get_current_session_key(default=""): + return approval._session_key.get() or default + +def register_gateway_notify(session_key, callback): + approval._notify[session_key] = callback + +def unregister_gateway_notify(session_key): + approval._notify.pop(session_key, None) + +def resolve_gateway_approval(session_key, choice): + approval._resolved_gateway.append((session_key, choice)) + return 1 + +approval.set_current_session_key = set_current_session_key +approval.reset_current_session_key = reset_current_session_key +approval.get_current_session_key = get_current_session_key +approval.register_gateway_notify = register_gateway_notify +approval.unregister_gateway_notify = unregister_gateway_notify +approval.resolve_gateway_approval = resolve_gateway_approval +sys.modules["tools.approval"] = approval + +path = Path("packages/server/src/services/hermes/agent-bridge/hermes_bridge.py") +spec = importlib.util.spec_from_file_location("hermes_bridge", path) +bridge = importlib.util.module_from_spec(spec) +sys.modules[spec.name] = bridge +spec.loader.exec_module(bridge) + +class FakeDb: + def __init__(self): + self.lock = threading.Lock() + self.messages = {} + self.sessions = set() + + def create_session(self, session_id, **kwargs): + with self.lock: + self.sessions.add(session_id) + self.messages.setdefault(session_id, []) + + def get_messages(self, session_id): + with self.lock: + return list(self.messages.get(session_id, [])) + + def append_message(self, session_id, role, content=None, **kwargs): + with self.lock: + self.messages.setdefault(session_id, []).append({ + "role": role, + "content": content, + **kwargs, + }) + +class FakeDbHolder: + error = None + + def __init__(self, db): + self.db = db + + def get_for_profile(self, profile): + return self.db + +def make_pool(): + pool = bridge.AgentPool() + fake_db = FakeDb() + pool._db = FakeDbHolder(fake_db) + return pool, fake_db + +def start_manual_run(pool, session_id, agent, message=None): + session = bridge.AgentSession(session_id=session_id, agent=agent) + run_id = f"run-{session_id}" + record = bridge.RunRecord(run_id=run_id, session_id=session_id) + session.running = True + session.current_run_id = run_id + with pool._lock: + pool._sessions[session_id] = session + pool._runs[run_id] = record + thread = threading.Thread( + target=pool._run_chat, + args=(session, record, message or f"message:{session_id}", None, None, [], "default", False, "api_server"), + daemon=True, + ) + thread.start() + return session, record, thread + +def wait_for(condition, timeout=20): + deadline = time.time() + timeout + while time.time() < deadline: + if condition(): + return True + time.sleep(0.01) + return False +` + +describe('agent bridge Python session concurrency', () => { + it('routes terminal/gateway approvals and stream callbacks per concurrent session', () => { + runPython(String.raw` +${harness} + +barrier = threading.Barrier(2) +os.environ["HERMES_EXEC_ASK"] = "preexisting-exec-ask" + +class FakeAgent: + def __init__(self, session_id): + self.session_id = session_id + + def run_conversation(self, message, **kwargs): + barrier.wait(timeout=20) + notify = approval._notify.get(self.session_id) + if notify is None: + raise RuntimeError(f"missing gateway notify for {self.session_id}") + notify({ + "command": f"gateway:{self.session_id}", + "description": f"gateway-desc:{self.session_id}", + }) + kwargs["stream_callback"](f"delta:{self.session_id}") + callback = _get_approval_callback() + if callback is None: + raise RuntimeError(f"missing approval callback for {self.session_id}") + assert get_current_session_key("") == self.session_id + choice = callback(f"cmd:{self.session_id}", f"desc:{self.session_id}", allow_permanent=False) + return { + "messages": [{"role": "assistant", "content": f"done:{self.session_id}:{choice}"}], + "choice": choice, + "completed": True, + } + +pool, fake_db = make_pool() +records = {} +threads = [] + +for sid in ("session-a", "session-b"): + _session, record, thread = start_manual_run(pool, sid, FakeAgent(sid)) + records[sid] = record + threads.append(thread) + +terminal_approval_ids = {} +gateway_approval_ids = {} +def approvals_ready(): + with pool._lock: + for sid, record in records.items(): + for event in record.events: + if event.get("event") != "approval.requested": + continue + command = event.get("command") + if command == f"cmd:{sid}": + terminal_approval_ids[sid] = event["approval_id"] + if command == f"gateway:{sid}": + gateway_approval_ids[sid] = event["approval_id"] + return ( + set(terminal_approval_ids) == {"session-a", "session-b"} and + set(gateway_approval_ids) == {"session-a", "session-b"} + ) + +if not wait_for(approvals_ready): + diagnostics = { + sid: { + "status": record.status, + "error": record.error, + "events": record.events, + "result": record.result, + } + for sid, record in records.items() + } + raise AssertionError({ + "terminal_approval_ids": terminal_approval_ids, + "gateway_approval_ids": gateway_approval_ids, + "records": diagnostics, + }) + +assert os.environ.get("HERMES_EXEC_ASK") == "1" +assert pool._exec_ask_depth == 2 + +pool.respond_approval(gateway_approval_ids["session-b"], "always") +pool.respond_approval(gateway_approval_ids["session-a"], "session") +pool.respond_approval(terminal_approval_ids["session-b"], "deny") +pool.respond_approval(terminal_approval_ids["session-a"], "once") + +for thread in threads: + thread.join(timeout=20) + assert not thread.is_alive() + +assert records["session-a"].status == "complete" +assert records["session-b"].status == "complete" +assert records["session-a"].result["choice"] == "once" +assert records["session-b"].result["choice"] == "deny" +assert records["session-a"].deltas == ["delta:session-a"] +assert records["session-b"].deltas == ["delta:session-b"] +assert fake_db.get_messages("session-a")[0]["content"] == "message:session-a" +assert fake_db.get_messages("session-b")[0]["content"] == "message:session-b" +assert os.environ.get("HERMES_EXEC_ASK") == "preexisting-exec-ask" +assert pool._exec_ask_depth == 0 +assert pool._approval_handlers == {} +assert approval._notify == {} +assert sorted(approval._resolved_gateway) == [ + ("session-a", "session"), + ("session-b", "always"), +] + +terminal_commands = {} +gateway_commands = {} +timeouts = {} +for sid, record in records.items(): + for event in record.events: + if event.get("event") != "approval.requested": + continue + command = event.get("command") + if command == f"cmd:{sid}": + terminal_commands[sid] = command + timeouts[sid] = event.get("timeout_ms") + if command == f"gateway:{sid}": + gateway_commands[sid] = command + +assert terminal_commands == { + "session-a": "cmd:session-a", + "session-b": "cmd:session-b", +} +assert gateway_commands == { + "session-a": "gateway:session-a", + "session-b": "gateway:session-b", +} +assert timeouts == { + "session-a": 120000, + "session-b": 120000, +} + +same_session = bridge.AgentSession(session_id="same-session", agent=FakeAgent("same-session")) +same_session.running = True +pool.get_or_create = lambda *args, **kwargs: same_session +try: + pool.start_chat("same-session", "second") + raise AssertionError("same-session concurrent run was accepted") +except RuntimeError as exc: + assert "already running" in str(exc) + +class FakeWorker: + def __init__(self, destroyed): + self.running = True + self.destroyed = destroyed + self.requests = [] + self.stopped = False + + def request(self, req): + self.requests.append(req) + return {"ok": True, "destroyed": self.destroyed} + + def stop(self): + self.running = False + self.stopped = True + +broker = bridge.BridgeBroker("ipc:///tmp/unused.sock") +profile_worker = FakeWorker(2) +broker._workers["default"] = profile_worker +broker._run_profile["run-session-a"] = "default" +broker._session_profile["session-a"] = "default" +broker._approval_profile["approval-a"] = "default" +broker._compression_profile["compression-a"] = "default" + +destroy_profile_result = broker.handle({"action": "destroy_profile", "profile": "default"}) +assert destroy_profile_result == {"profile": "default", "destroyed": 2} +assert profile_worker.stopped +assert "default" not in broker._workers +assert broker._run_profile == {} +assert broker._session_profile == {} +assert broker._approval_profile == {} +assert broker._compression_profile == {} + +worker_a = FakeWorker(1) +worker_b = FakeWorker(3) +broker._workers["a"] = worker_a +broker._workers["b"] = worker_b +broker._run_profile["run-a"] = "a" +broker._session_profile["session-b"] = "b" + +destroy_all_result = broker.handle({"action": "destroy_all"}) +assert destroy_all_result == {"destroyed": 4} +assert worker_a.stopped +assert worker_b.stopped +assert broker._workers == {} +assert broker._run_profile == {} +assert broker._session_profile == {} +`) + }) + + it('restores approval env and clears handlers when a run fails', () => { + runPython(String.raw` +${harness} + +os.environ.pop("HERMES_EXEC_ASK", None) + +class FailingAgent: + def run_conversation(self, message, **kwargs): + assert os.environ.get("HERMES_EXEC_ASK") == "1" + assert _get_approval_callback() is not None + raise RuntimeError("boom") + +pool, fake_db = make_pool() +session, record, thread = start_manual_run(pool, "error-session", FailingAgent()) +thread.join(timeout=20) +assert not thread.is_alive() + +assert record.status == "error" +assert "boom" in (record.error or "") +assert session.running is False +assert session.current_run_id is None +assert "HERMES_EXEC_ASK" not in os.environ +assert pool._exec_ask_depth == 0 +assert pool._exec_ask_previous is None +assert pool._approval_handlers == {} +assert approval._notify == {} +assert fake_db.get_messages("error-session")[0]["content"] == "message:error-session" +`) + }) + + it('fails closed when approval dispatch loses run thread context', () => { + runPython(String.raw` +${harness} + +pool, _fake_db = make_pool() +calls = [] + +def handler(command, description, *, allow_permanent=True): + calls.append((command, description, allow_permanent)) + return "once" + +with pool._lock: + pool._approval_handlers["session-a"] = handler + +assert pool._approval_dispatcher("cmd", "desc") == "deny" +assert calls == [] + +pool._run_context.session_id = "missing-session" +assert pool._approval_dispatcher("cmd", "desc") == "deny" +assert calls == [] + +pool._run_context.session_id = "session-a" +assert pool._approval_dispatcher("cmd", "desc", allow_permanent=False) == "once" +assert calls == [("cmd", "desc", False)] +`) + }) +})