refactor chat run socket (#739)
This commit is contained in:
@@ -353,6 +353,7 @@ class AgentPool:
|
||||
self._sessions: dict[str, AgentSession] = {}
|
||||
self._runs: dict[str, RunRecord] = {}
|
||||
self._lock = threading.RLock()
|
||||
self._run_lock = threading.Lock()
|
||||
self._db = SessionDbHolder()
|
||||
self._approval_requests: dict[str, queue.Queue[str]] = {}
|
||||
self._compression_requests: dict[str, queue.Queue[dict[str, Any]]] = {}
|
||||
@@ -755,87 +756,88 @@ class AgentPool:
|
||||
return record
|
||||
|
||||
def _run_chat(self, session: AgentSession, record: RunRecord, message: Any, instructions: str | None = None, conversation_history: list[dict[str, Any]] | None = None, profile: str | None = None, force_compress: bool = False) -> None:
|
||||
def stream_callback(delta: str) -> None:
|
||||
with self._lock:
|
||||
record.deltas.append(str(delta))
|
||||
with self._run_lock:
|
||||
def stream_callback(delta: str) -> None:
|
||||
with self._lock:
|
||||
record.deltas.append(str(delta))
|
||||
|
||||
try:
|
||||
previous_approval_callback = None
|
||||
previous_exec_ask = os.environ.get("HERMES_EXEC_ASK")
|
||||
approval_session_token = None
|
||||
try:
|
||||
from tools.terminal_tool import _get_approval_callback, set_approval_callback
|
||||
from tools.approval import set_current_session_key
|
||||
|
||||
previous_approval_callback = _get_approval_callback()
|
||||
set_approval_callback(self._approval_callback(session.session_id))
|
||||
approval_session_token = set_current_session_key(session.session_id)
|
||||
os.environ["HERMES_EXEC_ASK"] = "1"
|
||||
except Exception:
|
||||
previous_approval_callback = None
|
||||
self._prepersist_user_message(session, message, conversation_history, profile)
|
||||
if force_compress:
|
||||
compress = getattr(session.agent, "_compress_context", None)
|
||||
if callable(compress):
|
||||
compressed_history, compressed_system = compress(
|
||||
conversation_history if isinstance(conversation_history, list) else [],
|
||||
instructions,
|
||||
approx_tokens=None,
|
||||
focus_topic="debug_force_compress",
|
||||
)
|
||||
if isinstance(compressed_history, list):
|
||||
conversation_history = compressed_history
|
||||
if isinstance(compressed_system, str):
|
||||
instructions = compressed_system
|
||||
kwargs: dict[str, Any] = dict(
|
||||
task_id=session.session_id,
|
||||
stream_callback=stream_callback,
|
||||
)
|
||||
if instructions:
|
||||
kwargs["system_message"] = instructions
|
||||
if conversation_history is not None:
|
||||
kwargs["conversation_history"] = conversation_history
|
||||
result = session.agent.run_conversation(
|
||||
message,
|
||||
**kwargs,
|
||||
)
|
||||
result = _jsonable(result if isinstance(result, dict) else {"value": result})
|
||||
with session.lock:
|
||||
if isinstance(result.get("messages"), list):
|
||||
session.history = result["messages"]
|
||||
record.status = "interrupted" if result.get("interrupted") else "complete"
|
||||
record.result = result
|
||||
record.ended_at = time.time()
|
||||
session.running = False
|
||||
session.current_run_id = None
|
||||
session.last_used_at = time.time()
|
||||
except Exception as exc:
|
||||
with session.lock:
|
||||
record.status = "error"
|
||||
record.error = str(exc)
|
||||
record.result = {"error": str(exc), "traceback": traceback.format_exc()}
|
||||
record.ended_at = time.time()
|
||||
session.running = False
|
||||
session.current_run_id = None
|
||||
session.last_used_at = time.time()
|
||||
finally:
|
||||
try:
|
||||
from tools.terminal_tool import set_approval_callback
|
||||
|
||||
set_approval_callback(previous_approval_callback)
|
||||
except Exception:
|
||||
pass
|
||||
if approval_session_token is not None:
|
||||
previous_exec_ask = os.environ.get("HERMES_EXEC_ASK")
|
||||
approval_session_token = None
|
||||
try:
|
||||
from tools.approval import reset_current_session_key
|
||||
from tools.terminal_tool import _get_approval_callback, set_approval_callback
|
||||
from tools.approval import set_current_session_key
|
||||
|
||||
reset_current_session_key(approval_session_token)
|
||||
previous_approval_callback = _get_approval_callback()
|
||||
set_approval_callback(self._approval_callback(session.session_id))
|
||||
approval_session_token = set_current_session_key(session.session_id)
|
||||
os.environ["HERMES_EXEC_ASK"] = "1"
|
||||
except Exception:
|
||||
previous_approval_callback = None
|
||||
self._prepersist_user_message(session, message, conversation_history, profile)
|
||||
if force_compress:
|
||||
compress = getattr(session.agent, "_compress_context", None)
|
||||
if callable(compress):
|
||||
compressed_history, compressed_system = compress(
|
||||
conversation_history if isinstance(conversation_history, list) else [],
|
||||
instructions,
|
||||
approx_tokens=None,
|
||||
focus_topic="debug_force_compress",
|
||||
)
|
||||
if isinstance(compressed_history, list):
|
||||
conversation_history = compressed_history
|
||||
if isinstance(compressed_system, str):
|
||||
instructions = compressed_system
|
||||
kwargs: dict[str, Any] = dict(
|
||||
task_id=session.session_id,
|
||||
stream_callback=stream_callback,
|
||||
)
|
||||
if instructions:
|
||||
kwargs["system_message"] = instructions
|
||||
if conversation_history is not None:
|
||||
kwargs["conversation_history"] = conversation_history
|
||||
result = session.agent.run_conversation(
|
||||
message,
|
||||
**kwargs,
|
||||
)
|
||||
result = _jsonable(result if isinstance(result, dict) else {"value": result})
|
||||
with session.lock:
|
||||
if isinstance(result.get("messages"), list):
|
||||
session.history = result["messages"]
|
||||
record.status = "interrupted" if result.get("interrupted") else "complete"
|
||||
record.result = result
|
||||
record.ended_at = time.time()
|
||||
session.running = False
|
||||
session.current_run_id = None
|
||||
session.last_used_at = time.time()
|
||||
except Exception as exc:
|
||||
with session.lock:
|
||||
record.status = "error"
|
||||
record.error = str(exc)
|
||||
record.result = {"error": str(exc), "traceback": traceback.format_exc()}
|
||||
record.ended_at = time.time()
|
||||
session.running = False
|
||||
session.current_run_id = None
|
||||
session.last_used_at = time.time()
|
||||
finally:
|
||||
try:
|
||||
from tools.terminal_tool import set_approval_callback
|
||||
|
||||
set_approval_callback(previous_approval_callback)
|
||||
except Exception:
|
||||
pass
|
||||
if previous_exec_ask is None:
|
||||
os.environ.pop("HERMES_EXEC_ASK", None)
|
||||
else:
|
||||
os.environ["HERMES_EXEC_ASK"] = previous_exec_ask
|
||||
if approval_session_token is not None:
|
||||
try:
|
||||
from tools.approval import reset_current_session_key
|
||||
|
||||
reset_current_session_key(approval_session_token)
|
||||
except Exception:
|
||||
pass
|
||||
if previous_exec_ask is None:
|
||||
os.environ.pop("HERMES_EXEC_ASK", None)
|
||||
else:
|
||||
os.environ["HERMES_EXEC_ASK"] = previous_exec_ask
|
||||
|
||||
def interrupt(self, session_id: str, message: str | None = None) -> dict[str, Any]:
|
||||
with self._lock:
|
||||
@@ -845,7 +847,15 @@ class AgentPool:
|
||||
if not hasattr(session.agent, "interrupt"):
|
||||
raise RuntimeError("agent does not support interrupt")
|
||||
session.agent.interrupt(message)
|
||||
return {"status": "interrupted", "session_id": session_id}
|
||||
deadline = time.time() + 10.0
|
||||
synced = False
|
||||
while time.time() < deadline:
|
||||
with session.lock:
|
||||
if not session.running:
|
||||
synced = True
|
||||
break
|
||||
time.sleep(0.05)
|
||||
return {"status": "interrupted", "session_id": session_id, "synced": synced}
|
||||
|
||||
def steer(self, session_id: str, text: str) -> dict[str, Any]:
|
||||
with self._lock:
|
||||
|
||||
Reference in New Issue
Block a user