fix compression context usage accounting (#924)

This commit is contained in:
ekko
2026-05-22 09:46:50 +08:00
committed by GitHub
parent b5f0215beb
commit c3538a6b44
11 changed files with 454 additions and 61 deletions
@@ -668,30 +668,79 @@ class AgentPool:
agent._compress_context = wrapped_compress_context
def _estimate_context_tokens(self, agent: Any, messages: Any, system_message: Any = None) -> int | None:
def _agent_system_prompt(self, agent: Any, system_message: Any = None) -> str:
prompt = str(getattr(agent, "_cached_system_prompt", "") or "")
if prompt:
return prompt
try:
build_prompt = getattr(agent, "_build_system_prompt", None)
if callable(build_prompt):
return str(build_prompt(system_message) or "")
except Exception:
return str(system_message or "")
return str(system_message or "")
def _agent_tool_names(self, tools: Any) -> list[str]:
if not isinstance(tools, list):
return []
names: list[str] = []
for tool in tools:
name = ""
if isinstance(tool, dict):
function = tool.get("function")
if isinstance(function, dict):
name = str(function.get("name") or "")
if not name:
name = str(tool.get("name") or "")
else:
name = str(getattr(tool, "name", "") or "")
if name:
names.append(name)
return names
def _estimate_context_info(self, agent: Any, messages: Any, system_message: Any = None) -> dict[str, Any]:
try:
from agent.model_metadata import estimate_request_tokens_rough
except Exception:
return None
prompt = str(getattr(agent, "_cached_system_prompt", "") or "")
if not prompt:
try:
build_prompt = getattr(agent, "_build_system_prompt", None)
if callable(build_prompt):
prompt = str(build_prompt(system_message) or "")
except Exception:
prompt = str(system_message or "")
return {}
prompt = self._agent_system_prompt(agent, system_message)
tools = getattr(agent, "tools", None) or []
message_list = messages if isinstance(messages, list) else []
try:
estimate = estimate_request_tokens_rough(
messages if isinstance(messages, list) else [],
system_prompt=prompt,
tools=getattr(agent, "tools", None) or None,
)
return int(estimate) if isinstance(estimate, (int, float)) and estimate > 0 else None
token_count = estimate_request_tokens_rough(message_list, system_prompt=prompt, tools=tools or None)
fixed_context_tokens = estimate_request_tokens_rough([], system_prompt=prompt, tools=tools or None)
system_prompt_tokens = estimate_request_tokens_rough([], system_prompt=prompt, tools=None)
tool_tokens = max(0, int(fixed_context_tokens or 0) - int(system_prompt_tokens or 0))
return {
"token_count": int(token_count) if isinstance(token_count, (int, float)) and token_count > 0 else None,
"fixed_context_tokens": int(fixed_context_tokens) if isinstance(fixed_context_tokens, (int, float)) and fixed_context_tokens >= 0 else None,
"system_prompt_tokens": int(system_prompt_tokens) if isinstance(system_prompt_tokens, (int, float)) and system_prompt_tokens >= 0 else None,
"tool_tokens": tool_tokens,
"message_count": len(message_list),
"tool_count": len(tools) if isinstance(tools, list) else 0,
"tool_names": self._agent_tool_names(tools),
"system_prompt_chars": len(prompt),
}
except Exception:
return None
return {}
def _estimate_context_tokens(self, agent: Any, messages: Any, system_message: Any = None) -> int | None:
token_count = self._estimate_context_info(agent, messages, system_message).get("token_count")
return int(token_count) if isinstance(token_count, (int, float)) and token_count > 0 else None
def _bridge_context_ready_event(self, session: AgentSession, instructions: str | None, profile: str | None) -> dict[str, Any]:
info = self._estimate_context_info(session.agent, [], instructions)
event = {
"event": "bridge.context.ready",
"session_id": session.session_id,
"profile": profile or session.config.get("profile") or "default",
"model": session.config.get("model"),
"provider": session.config.get("provider"),
**info,
}
session.config["context_info"] = event
return event
def estimate_context(
self,
@@ -703,24 +752,23 @@ class AgentPool:
provider: str | None = None,
) -> dict[str, Any]:
session = self.get_or_create(session_id, profile=profile, model=model, provider=provider)
token_count = self._estimate_context_tokens(session.agent, messages or [], instructions)
tools = getattr(session.agent, "tools", None) or []
prompt = str(getattr(session.agent, "_cached_system_prompt", "") or "")
context_info = self._estimate_context_info(session.agent, messages or [], instructions)
print(
"[hermes_bridge] context estimate "
f"session={session_id} profile={profile or 'default'} "
f"messages={len(messages or [])} system_prompt_chars={len(prompt)} "
f"tools={len(tools) if isinstance(tools, list) else 0} "
f"tokens={token_count if token_count is not None else 'unknown'}",
f"messages={len(messages or [])} system_prompt_chars={context_info.get('system_prompt_chars') or 0} "
f"tools={context_info.get('tool_count') or 0} "
f"fixed_tokens={context_info.get('fixed_context_tokens') if context_info.get('fixed_context_tokens') is not None else 'unknown'} "
f"tokens={context_info.get('token_count') if context_info.get('token_count') is not None else 'unknown'}",
file=sys.stderr,
flush=True,
)
return {
"session_id": session_id,
"token_count": token_count,
"message_count": len(messages or []),
"tool_count": len(tools) if isinstance(tools, list) else 0,
"system_prompt_chars": len(prompt),
"profile": profile or session.config.get("profile") or "default",
"model": session.config.get("model"),
"provider": session.config.get("provider"),
**context_info,
}
def respond_compression(
@@ -1062,6 +1110,9 @@ class AgentPool:
session.running = True
session.current_run_id = run_id
session.last_used_at = time.time()
context_event = self._bridge_context_ready_event(session, instructions, profile)
if context_event:
record.events.append(_jsonable(context_event))
thread = threading.Thread(
target=self._run_chat,