Account for full context tokens in compression (#908)

* Account for full context tokens in compression

* Fix group chat final context updates

---------

Co-authored-by: Codex <codex@openai.com>
This commit is contained in:
ekko
2026-05-21 19:40:52 +08:00
committed by GitHub
parent b2ec321990
commit 39ead94352
16 changed files with 730 additions and 35 deletions
@@ -602,6 +602,17 @@ class AgentPool:
def wrapped_compress_context(messages, system_message, **kwargs):
before_count = len(messages) if isinstance(messages, list) else 0
approx_tokens = kwargs.get("approx_tokens")
if not isinstance(approx_tokens, int) or approx_tokens <= 0:
approx_tokens = self._estimate_context_tokens(agent, messages, system_message)
print(
"[hermes_bridge] compression requested "
f"session={session_id} messages={before_count} "
f"tokens={approx_tokens if approx_tokens is not None else 'unknown'} "
f"focus={kwargs.get('focus_topic') or ''}",
file=sys.stderr,
flush=True,
)
request_id = uuid.uuid4().hex
response_queue: queue.Queue[dict[str, Any]] = queue.Queue(maxsize=1)
with self._lock:
@@ -610,7 +621,7 @@ class AgentPool:
"event": "bridge.compression.requested",
"request_id": request_id,
"message_count": before_count,
"approx_tokens": kwargs.get("approx_tokens"),
"approx_tokens": approx_tokens,
"focus_topic": kwargs.get("focus_topic"),
"messages": _jsonable(messages),
})
@@ -622,12 +633,14 @@ class AgentPool:
if not isinstance(compressed_messages, list):
raise RuntimeError("bridge compression response missing messages")
next_system_message = response.get("system_message", system_message)
result_approx_tokens = self._estimate_context_tokens(agent, compressed_messages, next_system_message)
self._append_event(session_id, {
"event": "bridge.compression.completed",
"request_id": request_id,
"message_count": before_count,
"result_messages": len(compressed_messages),
"approx_tokens": kwargs.get("approx_tokens"),
"approx_tokens": approx_tokens,
"result_approx_tokens": result_approx_tokens,
"compressed": True,
})
return compressed_messages, next_system_message
@@ -636,7 +649,7 @@ class AgentPool:
"event": "bridge.compression.failed",
"request_id": request_id,
"message_count": before_count,
"approx_tokens": kwargs.get("approx_tokens"),
"approx_tokens": approx_tokens,
"error": "bridge compression timed out",
})
raise RuntimeError("bridge compression timed out")
@@ -645,7 +658,7 @@ class AgentPool:
"event": "bridge.compression.failed",
"request_id": request_id,
"message_count": before_count,
"approx_tokens": kwargs.get("approx_tokens"),
"approx_tokens": approx_tokens,
"error": str(exc),
})
raise
@@ -655,6 +668,61 @@ class AgentPool:
agent._compress_context = wrapped_compress_context
def _estimate_context_tokens(self, agent: Any, messages: Any, system_message: Any = None) -> int | None:
try:
from agent.model_metadata import estimate_request_tokens_rough
except Exception:
return None
prompt = str(getattr(agent, "_cached_system_prompt", "") or "")
if not prompt:
try:
build_prompt = getattr(agent, "_build_system_prompt", None)
if callable(build_prompt):
prompt = str(build_prompt(system_message) or "")
except Exception:
prompt = str(system_message or "")
try:
estimate = estimate_request_tokens_rough(
messages if isinstance(messages, list) else [],
system_prompt=prompt,
tools=getattr(agent, "tools", None) or None,
)
return int(estimate) if isinstance(estimate, (int, float)) and estimate > 0 else None
except Exception:
return None
def estimate_context(
self,
session_id: str,
messages: list[dict[str, Any]] | None = None,
instructions: str | None = None,
profile: str | None = None,
model: str | None = None,
provider: str | None = None,
) -> dict[str, Any]:
session = self.get_or_create(session_id, profile=profile, model=model, provider=provider)
token_count = self._estimate_context_tokens(session.agent, messages or [], instructions)
tools = getattr(session.agent, "tools", None) or []
prompt = str(getattr(session.agent, "_cached_system_prompt", "") or "")
print(
"[hermes_bridge] context estimate "
f"session={session_id} profile={profile or 'default'} "
f"messages={len(messages or [])} system_prompt_chars={len(prompt)} "
f"tools={len(tools) if isinstance(tools, list) else 0} "
f"tokens={token_count if token_count is not None else 'unknown'}",
file=sys.stderr,
flush=True,
)
return {
"session_id": session_id,
"token_count": token_count,
"message_count": len(messages or []),
"tool_count": len(tools) if isinstance(tools, list) else 0,
"system_prompt_chars": len(prompt),
}
def respond_compression(
self,
request_id: str,
@@ -1329,6 +1397,20 @@ class BridgeServer:
return self.pool.get_result(record.run_id)
return {"run_id": record.run_id, "session_id": session_id, "status": record.status}
if action == "context_estimate":
session_id = str(req.get("session_id") or "").strip() or uuid.uuid4().hex
messages = req.get("messages") or req.get("conversation_history") or []
if not isinstance(messages, list):
raise ValueError("messages must be a list")
return self.pool.estimate_context(
session_id,
messages=messages,
instructions=req.get("instructions") or req.get("system_message"),
profile=req.get("profile"),
model=req.get("model"),
provider=req.get("provider"),
)
if action == "get_result":
return self.pool.get_result(str(req.get("run_id") or ""))
@@ -1870,6 +1952,10 @@ class BridgeBroker:
profile = self._normalize_profile(req.get("profile"))
return self._forward(profile, req)
if action == "context_estimate":
profile = self._normalize_profile(req.get("profile"))
return self._forward(profile, req)
if action in {"get_result", "get_output"}:
profile = self._profile_for_run(str(req.get("run_id") or ""))
return self._forward(profile, req)