Allow bridge sessions to run concurrently (#932)
* Allow bridge sessions to run concurrently * Stabilize bridge concurrency test * Set bridge approval timeout to 120 seconds * harden bridge approval concurrency --------- Co-authored-by: Codex <codex@openai.com>
This commit is contained in:
@@ -30,12 +30,14 @@ from contextlib import contextmanager
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from urllib.parse import urlparse
|
||||
from typing import Any
|
||||
from typing import Any, Callable
|
||||
|
||||
|
||||
DEFAULT_ENDPOINT = "tcp://127.0.0.1:18765" if os.name == "nt" else "ipc:///tmp/hermes-agent-bridge.sock"
|
||||
DEFAULT_AGENT_ROOT = "~/.hermes/hermes-agent"
|
||||
DEFAULT_HERMES_HOME = "~/.hermes"
|
||||
APPROVAL_TIMEOUT_SECONDS = 120
|
||||
APPROVAL_TIMEOUT_MS = APPROVAL_TIMEOUT_SECONDS * 1000
|
||||
|
||||
|
||||
def _bridge_platform() -> str:
|
||||
@@ -501,11 +503,14 @@ class AgentPool:
|
||||
self._sessions: dict[str, AgentSession] = {}
|
||||
self._runs: dict[str, RunRecord] = {}
|
||||
self._lock = threading.RLock()
|
||||
self._run_lock = threading.Lock()
|
||||
self._db = SessionDbHolder()
|
||||
self._approval_requests: dict[str, queue.Queue[str]] = {}
|
||||
self._gateway_approval_requests: dict[str, str] = {}
|
||||
self._compression_requests: dict[str, queue.Queue[dict[str, Any]]] = {}
|
||||
self._run_context = threading.local()
|
||||
self._approval_handlers: dict[str, Callable[..., str]] = {}
|
||||
self._exec_ask_depth = 0
|
||||
self._exec_ask_previous: str | None = None
|
||||
|
||||
def get_or_create(
|
||||
self,
|
||||
@@ -927,10 +932,10 @@ class AgentPool:
|
||||
"description": str(description or ""),
|
||||
"choices": choices,
|
||||
"allow_permanent": bool(allow_permanent),
|
||||
"timeout_ms": 60_000,
|
||||
"timeout_ms": APPROVAL_TIMEOUT_MS,
|
||||
})
|
||||
try:
|
||||
choice = response_queue.get(timeout=60)
|
||||
choice = response_queue.get(timeout=APPROVAL_TIMEOUT_SECONDS)
|
||||
except queue.Empty:
|
||||
choice = "deny"
|
||||
finally:
|
||||
@@ -945,6 +950,44 @@ class AgentPool:
|
||||
|
||||
return callback
|
||||
|
||||
def _approval_dispatcher(self, command: str, description: str, *, allow_permanent: bool = True) -> str:
|
||||
session_id = str(getattr(self._run_context, "session_id", "") or "")
|
||||
if not session_id:
|
||||
return "deny"
|
||||
with self._lock:
|
||||
handler = self._approval_handlers.get(session_id)
|
||||
if handler is None:
|
||||
return "deny"
|
||||
return handler(command, description, allow_permanent=allow_permanent)
|
||||
|
||||
def _install_approval_dispatcher_for_current_thread(self) -> None:
|
||||
from tools.terminal_tool import set_approval_callback
|
||||
|
||||
# terminal_tool stores callbacks in threading.local(), so each run
|
||||
# thread must bind the shared dispatcher for itself.
|
||||
set_approval_callback(self._approval_dispatcher)
|
||||
|
||||
def _enter_exec_ask_scope(self) -> None:
|
||||
with self._lock:
|
||||
if self._exec_ask_depth == 0:
|
||||
self._exec_ask_previous = os.environ.get("HERMES_EXEC_ASK")
|
||||
os.environ["HERMES_EXEC_ASK"] = "1"
|
||||
self._exec_ask_depth += 1
|
||||
|
||||
def _exit_exec_ask_scope(self) -> None:
|
||||
with self._lock:
|
||||
if self._exec_ask_depth <= 0:
|
||||
return
|
||||
self._exec_ask_depth -= 1
|
||||
if self._exec_ask_depth > 0:
|
||||
return
|
||||
previous = self._exec_ask_previous
|
||||
self._exec_ask_previous = None
|
||||
if previous is None:
|
||||
os.environ.pop("HERMES_EXEC_ASK", None)
|
||||
else:
|
||||
os.environ["HERMES_EXEC_ASK"] = previous
|
||||
|
||||
def _gateway_approval_notify(self, session_id: str):
|
||||
def callback(approval_data: dict[str, Any]) -> None:
|
||||
approval_id = uuid.uuid4().hex
|
||||
@@ -1124,102 +1167,103 @@ class AgentPool:
|
||||
return record
|
||||
|
||||
def _run_chat(self, session: AgentSession, record: RunRecord, message: Any, storage_message: Any | None = None, instructions: str | None = None, conversation_history: list[dict[str, Any]] | None = None, profile: str | None = None, force_compress: bool = False, source: str | None = None) -> None:
|
||||
with self._run_lock:
|
||||
with _profile_env(profile):
|
||||
def stream_callback(delta: str) -> None:
|
||||
with self._lock:
|
||||
record.deltas.append(str(delta))
|
||||
with _profile_env(profile):
|
||||
def stream_callback(delta: str) -> None:
|
||||
with self._lock:
|
||||
record.deltas.append(str(delta))
|
||||
|
||||
approval_session_token = None
|
||||
registered_gateway_approval_session = None
|
||||
exec_ask_scope_entered = False
|
||||
try:
|
||||
try:
|
||||
previous_approval_callback = None
|
||||
previous_exec_ask = os.environ.get("HERMES_EXEC_ASK")
|
||||
approval_session_token = None
|
||||
registered_gateway_approval_session = None
|
||||
try:
|
||||
from tools.terminal_tool import _get_approval_callback, set_approval_callback
|
||||
from tools.approval import register_gateway_notify, set_current_session_key
|
||||
self._enter_exec_ask_scope()
|
||||
exec_ask_scope_entered = True
|
||||
self._install_approval_dispatcher_for_current_thread()
|
||||
with self._lock:
|
||||
self._approval_handlers[session.session_id] = self._approval_callback(session.session_id)
|
||||
self._run_context.session_id = session.session_id
|
||||
except Exception:
|
||||
self._run_context.session_id = session.session_id
|
||||
try:
|
||||
from tools.approval import register_gateway_notify, set_current_session_key
|
||||
|
||||
previous_approval_callback = _get_approval_callback()
|
||||
set_approval_callback(self._approval_callback(session.session_id))
|
||||
approval_session_token = set_current_session_key(session.session_id)
|
||||
register_gateway_notify(session.session_id, self._gateway_approval_notify(session.session_id))
|
||||
registered_gateway_approval_session = session.session_id
|
||||
os.environ["HERMES_EXEC_ASK"] = "1"
|
||||
except Exception:
|
||||
previous_approval_callback = None
|
||||
self._prepersist_user_message(session, message, storage_message, conversation_history, profile, source)
|
||||
db_count_after_prepersist = self._session_db_message_count(session.session_id, profile)
|
||||
if force_compress:
|
||||
compress = getattr(session.agent, "_compress_context", None)
|
||||
if callable(compress):
|
||||
compressed_history, compressed_system = compress(
|
||||
conversation_history if isinstance(conversation_history, list) else [],
|
||||
instructions,
|
||||
approx_tokens=None,
|
||||
focus_topic="debug_force_compress",
|
||||
)
|
||||
if isinstance(compressed_history, list):
|
||||
conversation_history = compressed_history
|
||||
if isinstance(compressed_system, str):
|
||||
instructions = compressed_system
|
||||
kwargs: dict[str, Any] = dict(
|
||||
task_id=session.session_id,
|
||||
stream_callback=stream_callback,
|
||||
)
|
||||
if instructions:
|
||||
kwargs["system_message"] = instructions
|
||||
if conversation_history is not None:
|
||||
kwargs["conversation_history"] = conversation_history
|
||||
result = session.agent.run_conversation(
|
||||
message,
|
||||
**kwargs,
|
||||
)
|
||||
result = _jsonable(result if isinstance(result, dict) else {"value": result})
|
||||
self._sync_result_tail_to_session_db(
|
||||
session,
|
||||
result,
|
||||
conversation_history,
|
||||
profile,
|
||||
db_count_after_prepersist,
|
||||
)
|
||||
with session.lock:
|
||||
if isinstance(result.get("messages"), list):
|
||||
session.history = result["messages"]
|
||||
record.status = "interrupted" if result.get("interrupted") else "complete"
|
||||
record.result = result
|
||||
record.ended_at = time.time()
|
||||
session.running = False
|
||||
session.current_run_id = None
|
||||
session.last_used_at = time.time()
|
||||
except Exception as exc:
|
||||
with session.lock:
|
||||
record.status = "error"
|
||||
record.error = str(exc)
|
||||
record.result = {"error": str(exc), "traceback": traceback.format_exc()}
|
||||
record.ended_at = time.time()
|
||||
session.running = False
|
||||
session.current_run_id = None
|
||||
session.last_used_at = time.time()
|
||||
finally:
|
||||
approval_session_token = set_current_session_key(session.session_id)
|
||||
register_gateway_notify(session.session_id, self._gateway_approval_notify(session.session_id))
|
||||
registered_gateway_approval_session = session.session_id
|
||||
except Exception:
|
||||
pass
|
||||
self._prepersist_user_message(session, message, storage_message, conversation_history, profile, source)
|
||||
db_count_after_prepersist = self._session_db_message_count(session.session_id, profile)
|
||||
if force_compress:
|
||||
compress = getattr(session.agent, "_compress_context", None)
|
||||
if callable(compress):
|
||||
compressed_history, compressed_system = compress(
|
||||
conversation_history if isinstance(conversation_history, list) else [],
|
||||
instructions,
|
||||
approx_tokens=None,
|
||||
focus_topic="debug_force_compress",
|
||||
)
|
||||
if isinstance(compressed_history, list):
|
||||
conversation_history = compressed_history
|
||||
if isinstance(compressed_system, str):
|
||||
instructions = compressed_system
|
||||
kwargs: dict[str, Any] = dict(
|
||||
task_id=session.session_id,
|
||||
stream_callback=stream_callback,
|
||||
)
|
||||
if instructions:
|
||||
kwargs["system_message"] = instructions
|
||||
if conversation_history is not None:
|
||||
kwargs["conversation_history"] = conversation_history
|
||||
result = session.agent.run_conversation(
|
||||
message,
|
||||
**kwargs,
|
||||
)
|
||||
result = _jsonable(result if isinstance(result, dict) else {"value": result})
|
||||
self._sync_result_tail_to_session_db(
|
||||
session,
|
||||
result,
|
||||
conversation_history,
|
||||
profile,
|
||||
db_count_after_prepersist,
|
||||
)
|
||||
with session.lock:
|
||||
if isinstance(result.get("messages"), list):
|
||||
session.history = result["messages"]
|
||||
record.status = "interrupted" if result.get("interrupted") else "complete"
|
||||
record.result = result
|
||||
record.ended_at = time.time()
|
||||
session.running = False
|
||||
session.current_run_id = None
|
||||
session.last_used_at = time.time()
|
||||
except Exception as exc:
|
||||
with session.lock:
|
||||
record.status = "error"
|
||||
record.error = str(exc)
|
||||
record.result = {"error": str(exc), "traceback": traceback.format_exc()}
|
||||
record.ended_at = time.time()
|
||||
session.running = False
|
||||
session.current_run_id = None
|
||||
session.last_used_at = time.time()
|
||||
finally:
|
||||
with self._lock:
|
||||
self._approval_handlers.pop(session.session_id, None)
|
||||
try:
|
||||
del self._run_context.session_id
|
||||
except AttributeError:
|
||||
pass
|
||||
if approval_session_token is not None:
|
||||
try:
|
||||
from tools.terminal_tool import set_approval_callback
|
||||
from tools.approval import reset_current_session_key, unregister_gateway_notify
|
||||
|
||||
set_approval_callback(previous_approval_callback)
|
||||
if registered_gateway_approval_session is not None:
|
||||
unregister_gateway_notify(registered_gateway_approval_session)
|
||||
reset_current_session_key(approval_session_token)
|
||||
except Exception:
|
||||
pass
|
||||
if approval_session_token is not None:
|
||||
try:
|
||||
from tools.approval import reset_current_session_key, unregister_gateway_notify
|
||||
|
||||
if registered_gateway_approval_session is not None:
|
||||
unregister_gateway_notify(registered_gateway_approval_session)
|
||||
reset_current_session_key(approval_session_token)
|
||||
except Exception:
|
||||
pass
|
||||
if previous_exec_ask is None:
|
||||
os.environ.pop("HERMES_EXEC_ASK", None)
|
||||
else:
|
||||
os.environ["HERMES_EXEC_ASK"] = previous_exec_ask
|
||||
if exec_ask_scope_entered:
|
||||
self._exit_exec_ask_scope()
|
||||
|
||||
def interrupt(self, session_id: str, message: str | None = None) -> dict[str, Any]:
|
||||
with self._lock:
|
||||
@@ -2043,26 +2087,27 @@ class BridgeBroker:
|
||||
if action == "destroy_all":
|
||||
with self._lock:
|
||||
workers = list(self._workers.values())
|
||||
self._workers.clear()
|
||||
self._run_profile.clear()
|
||||
self._session_profile.clear()
|
||||
self._approval_profile.clear()
|
||||
self._compression_profile.clear()
|
||||
destroyed = 0
|
||||
for worker in workers:
|
||||
if not worker.running:
|
||||
worker.stop()
|
||||
continue
|
||||
try:
|
||||
resp = worker.request({"action": "destroy_all"})
|
||||
destroyed += int(resp.get("destroyed") or 0)
|
||||
if worker.running:
|
||||
resp = worker.request({"action": "destroy_all"})
|
||||
destroyed += int(resp.get("destroyed") or 0)
|
||||
except Exception:
|
||||
pass
|
||||
finally:
|
||||
worker.stop()
|
||||
return {"destroyed": destroyed}
|
||||
|
||||
if action == "destroy_profile":
|
||||
profile = self._normalize_profile(req.get("profile"))
|
||||
with self._lock:
|
||||
worker = self._workers.get(profile)
|
||||
worker = self._workers.pop(profile, None)
|
||||
self._run_profile = {key: value for key, value in self._run_profile.items() if value != profile}
|
||||
self._session_profile = {key: value for key, value in self._session_profile.items() if value != profile}
|
||||
self._approval_profile = {key: value for key, value in self._approval_profile.items() if value != profile}
|
||||
@@ -2075,9 +2120,12 @@ class BridgeBroker:
|
||||
|
||||
try:
|
||||
resp = worker.request({"action": "destroy_all"})
|
||||
return {"profile": profile, "destroyed": int(resp.get("destroyed") or 0)}
|
||||
destroyed = int(resp.get("destroyed") or 0)
|
||||
except Exception:
|
||||
return {"profile": profile, "destroyed": 0}
|
||||
destroyed = 0
|
||||
finally:
|
||||
worker.stop()
|
||||
return {"profile": profile, "destroyed": destroyed}
|
||||
|
||||
if action == "list":
|
||||
sessions: list[Any] = []
|
||||
|
||||
Reference in New Issue
Block a user