Account for full context tokens in compression (#908)
* Account for full context tokens in compression * Fix group chat final context updates --------- Co-authored-by: Codex <codex@openai.com>
This commit is contained in:
@@ -602,6 +602,17 @@ class AgentPool:
|
||||
|
||||
def wrapped_compress_context(messages, system_message, **kwargs):
|
||||
before_count = len(messages) if isinstance(messages, list) else 0
|
||||
approx_tokens = kwargs.get("approx_tokens")
|
||||
if not isinstance(approx_tokens, int) or approx_tokens <= 0:
|
||||
approx_tokens = self._estimate_context_tokens(agent, messages, system_message)
|
||||
print(
|
||||
"[hermes_bridge] compression requested "
|
||||
f"session={session_id} messages={before_count} "
|
||||
f"tokens={approx_tokens if approx_tokens is not None else 'unknown'} "
|
||||
f"focus={kwargs.get('focus_topic') or ''}",
|
||||
file=sys.stderr,
|
||||
flush=True,
|
||||
)
|
||||
request_id = uuid.uuid4().hex
|
||||
response_queue: queue.Queue[dict[str, Any]] = queue.Queue(maxsize=1)
|
||||
with self._lock:
|
||||
@@ -610,7 +621,7 @@ class AgentPool:
|
||||
"event": "bridge.compression.requested",
|
||||
"request_id": request_id,
|
||||
"message_count": before_count,
|
||||
"approx_tokens": kwargs.get("approx_tokens"),
|
||||
"approx_tokens": approx_tokens,
|
||||
"focus_topic": kwargs.get("focus_topic"),
|
||||
"messages": _jsonable(messages),
|
||||
})
|
||||
@@ -622,12 +633,14 @@ class AgentPool:
|
||||
if not isinstance(compressed_messages, list):
|
||||
raise RuntimeError("bridge compression response missing messages")
|
||||
next_system_message = response.get("system_message", system_message)
|
||||
result_approx_tokens = self._estimate_context_tokens(agent, compressed_messages, next_system_message)
|
||||
self._append_event(session_id, {
|
||||
"event": "bridge.compression.completed",
|
||||
"request_id": request_id,
|
||||
"message_count": before_count,
|
||||
"result_messages": len(compressed_messages),
|
||||
"approx_tokens": kwargs.get("approx_tokens"),
|
||||
"approx_tokens": approx_tokens,
|
||||
"result_approx_tokens": result_approx_tokens,
|
||||
"compressed": True,
|
||||
})
|
||||
return compressed_messages, next_system_message
|
||||
@@ -636,7 +649,7 @@ class AgentPool:
|
||||
"event": "bridge.compression.failed",
|
||||
"request_id": request_id,
|
||||
"message_count": before_count,
|
||||
"approx_tokens": kwargs.get("approx_tokens"),
|
||||
"approx_tokens": approx_tokens,
|
||||
"error": "bridge compression timed out",
|
||||
})
|
||||
raise RuntimeError("bridge compression timed out")
|
||||
@@ -645,7 +658,7 @@ class AgentPool:
|
||||
"event": "bridge.compression.failed",
|
||||
"request_id": request_id,
|
||||
"message_count": before_count,
|
||||
"approx_tokens": kwargs.get("approx_tokens"),
|
||||
"approx_tokens": approx_tokens,
|
||||
"error": str(exc),
|
||||
})
|
||||
raise
|
||||
@@ -655,6 +668,61 @@ class AgentPool:
|
||||
|
||||
agent._compress_context = wrapped_compress_context
|
||||
|
||||
def _estimate_context_tokens(self, agent: Any, messages: Any, system_message: Any = None) -> int | None:
|
||||
try:
|
||||
from agent.model_metadata import estimate_request_tokens_rough
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
prompt = str(getattr(agent, "_cached_system_prompt", "") or "")
|
||||
if not prompt:
|
||||
try:
|
||||
build_prompt = getattr(agent, "_build_system_prompt", None)
|
||||
if callable(build_prompt):
|
||||
prompt = str(build_prompt(system_message) or "")
|
||||
except Exception:
|
||||
prompt = str(system_message or "")
|
||||
|
||||
try:
|
||||
estimate = estimate_request_tokens_rough(
|
||||
messages if isinstance(messages, list) else [],
|
||||
system_prompt=prompt,
|
||||
tools=getattr(agent, "tools", None) or None,
|
||||
)
|
||||
return int(estimate) if isinstance(estimate, (int, float)) and estimate > 0 else None
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def estimate_context(
|
||||
self,
|
||||
session_id: str,
|
||||
messages: list[dict[str, Any]] | None = None,
|
||||
instructions: str | None = None,
|
||||
profile: str | None = None,
|
||||
model: str | None = None,
|
||||
provider: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
session = self.get_or_create(session_id, profile=profile, model=model, provider=provider)
|
||||
token_count = self._estimate_context_tokens(session.agent, messages or [], instructions)
|
||||
tools = getattr(session.agent, "tools", None) or []
|
||||
prompt = str(getattr(session.agent, "_cached_system_prompt", "") or "")
|
||||
print(
|
||||
"[hermes_bridge] context estimate "
|
||||
f"session={session_id} profile={profile or 'default'} "
|
||||
f"messages={len(messages or [])} system_prompt_chars={len(prompt)} "
|
||||
f"tools={len(tools) if isinstance(tools, list) else 0} "
|
||||
f"tokens={token_count if token_count is not None else 'unknown'}",
|
||||
file=sys.stderr,
|
||||
flush=True,
|
||||
)
|
||||
return {
|
||||
"session_id": session_id,
|
||||
"token_count": token_count,
|
||||
"message_count": len(messages or []),
|
||||
"tool_count": len(tools) if isinstance(tools, list) else 0,
|
||||
"system_prompt_chars": len(prompt),
|
||||
}
|
||||
|
||||
def respond_compression(
|
||||
self,
|
||||
request_id: str,
|
||||
@@ -1329,6 +1397,20 @@ class BridgeServer:
|
||||
return self.pool.get_result(record.run_id)
|
||||
return {"run_id": record.run_id, "session_id": session_id, "status": record.status}
|
||||
|
||||
if action == "context_estimate":
|
||||
session_id = str(req.get("session_id") or "").strip() or uuid.uuid4().hex
|
||||
messages = req.get("messages") or req.get("conversation_history") or []
|
||||
if not isinstance(messages, list):
|
||||
raise ValueError("messages must be a list")
|
||||
return self.pool.estimate_context(
|
||||
session_id,
|
||||
messages=messages,
|
||||
instructions=req.get("instructions") or req.get("system_message"),
|
||||
profile=req.get("profile"),
|
||||
model=req.get("model"),
|
||||
provider=req.get("provider"),
|
||||
)
|
||||
|
||||
if action == "get_result":
|
||||
return self.pool.get_result(str(req.get("run_id") or ""))
|
||||
|
||||
@@ -1870,6 +1952,10 @@ class BridgeBroker:
|
||||
profile = self._normalize_profile(req.get("profile"))
|
||||
return self._forward(profile, req)
|
||||
|
||||
if action == "context_estimate":
|
||||
profile = self._normalize_profile(req.get("profile"))
|
||||
return self._forward(profile, req)
|
||||
|
||||
if action in {"get_result", "get_output"}:
|
||||
profile = self._profile_for_run(str(req.get("run_id") or ""))
|
||||
return self._forward(profile, req)
|
||||
|
||||
Reference in New Issue
Block a user