fix compression context usage accounting (#924)
This commit is contained in:
@@ -668,30 +668,79 @@ class AgentPool:
|
||||
|
||||
agent._compress_context = wrapped_compress_context
|
||||
|
||||
def _estimate_context_tokens(self, agent: Any, messages: Any, system_message: Any = None) -> int | None:
|
||||
def _agent_system_prompt(self, agent: Any, system_message: Any = None) -> str:
|
||||
prompt = str(getattr(agent, "_cached_system_prompt", "") or "")
|
||||
if prompt:
|
||||
return prompt
|
||||
try:
|
||||
build_prompt = getattr(agent, "_build_system_prompt", None)
|
||||
if callable(build_prompt):
|
||||
return str(build_prompt(system_message) or "")
|
||||
except Exception:
|
||||
return str(system_message or "")
|
||||
return str(system_message or "")
|
||||
|
||||
def _agent_tool_names(self, tools: Any) -> list[str]:
|
||||
if not isinstance(tools, list):
|
||||
return []
|
||||
names: list[str] = []
|
||||
for tool in tools:
|
||||
name = ""
|
||||
if isinstance(tool, dict):
|
||||
function = tool.get("function")
|
||||
if isinstance(function, dict):
|
||||
name = str(function.get("name") or "")
|
||||
if not name:
|
||||
name = str(tool.get("name") or "")
|
||||
else:
|
||||
name = str(getattr(tool, "name", "") or "")
|
||||
if name:
|
||||
names.append(name)
|
||||
return names
|
||||
|
||||
def _estimate_context_info(self, agent: Any, messages: Any, system_message: Any = None) -> dict[str, Any]:
|
||||
try:
|
||||
from agent.model_metadata import estimate_request_tokens_rough
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
prompt = str(getattr(agent, "_cached_system_prompt", "") or "")
|
||||
if not prompt:
|
||||
try:
|
||||
build_prompt = getattr(agent, "_build_system_prompt", None)
|
||||
if callable(build_prompt):
|
||||
prompt = str(build_prompt(system_message) or "")
|
||||
except Exception:
|
||||
prompt = str(system_message or "")
|
||||
return {}
|
||||
|
||||
prompt = self._agent_system_prompt(agent, system_message)
|
||||
tools = getattr(agent, "tools", None) or []
|
||||
message_list = messages if isinstance(messages, list) else []
|
||||
try:
|
||||
estimate = estimate_request_tokens_rough(
|
||||
messages if isinstance(messages, list) else [],
|
||||
system_prompt=prompt,
|
||||
tools=getattr(agent, "tools", None) or None,
|
||||
)
|
||||
return int(estimate) if isinstance(estimate, (int, float)) and estimate > 0 else None
|
||||
token_count = estimate_request_tokens_rough(message_list, system_prompt=prompt, tools=tools or None)
|
||||
fixed_context_tokens = estimate_request_tokens_rough([], system_prompt=prompt, tools=tools or None)
|
||||
system_prompt_tokens = estimate_request_tokens_rough([], system_prompt=prompt, tools=None)
|
||||
tool_tokens = max(0, int(fixed_context_tokens or 0) - int(system_prompt_tokens or 0))
|
||||
return {
|
||||
"token_count": int(token_count) if isinstance(token_count, (int, float)) and token_count > 0 else None,
|
||||
"fixed_context_tokens": int(fixed_context_tokens) if isinstance(fixed_context_tokens, (int, float)) and fixed_context_tokens >= 0 else None,
|
||||
"system_prompt_tokens": int(system_prompt_tokens) if isinstance(system_prompt_tokens, (int, float)) and system_prompt_tokens >= 0 else None,
|
||||
"tool_tokens": tool_tokens,
|
||||
"message_count": len(message_list),
|
||||
"tool_count": len(tools) if isinstance(tools, list) else 0,
|
||||
"tool_names": self._agent_tool_names(tools),
|
||||
"system_prompt_chars": len(prompt),
|
||||
}
|
||||
except Exception:
|
||||
return None
|
||||
return {}
|
||||
|
||||
def _estimate_context_tokens(self, agent: Any, messages: Any, system_message: Any = None) -> int | None:
    """Return only the total token estimate for a request.

    Delegates to ``_estimate_context_info`` and extracts ``token_count``.

    Returns:
        The estimate as an ``int`` when it is a positive number,
        otherwise ``None``.
    """
    info = self._estimate_context_info(agent, messages, system_message)
    estimate = info.get("token_count")
    if isinstance(estimate, (int, float)) and estimate > 0:
        return int(estimate)
    return None
|
||||
|
||||
def _bridge_context_ready_event(self, session: AgentSession, instructions: str | None, profile: str | None) -> dict[str, Any]:
    """Build the ``bridge.context.ready`` event for a session and cache it.

    The event combines session identity (id, profile, model, provider)
    with the context estimate computed for an empty message list, i.e.
    the cost of the system prompt and tools alone. The same dict is
    stored under ``session.config["context_info"]``.

    Args:
        session: Session whose agent and config are described.
        instructions: Instructions forwarded to the context estimate.
        profile: Explicit profile name; falls back to the session's
            configured profile, then to ``"default"``.

    Returns:
        The event dict that was stored on the session.
    """
    info = self._estimate_context_info(session.agent, [], instructions)
    event: dict[str, Any] = {
        "event": "bridge.context.ready",
        "session_id": session.session_id,
        "profile": profile or session.config.get("profile") or "default",
        "model": session.config.get("model"),
        "provider": session.config.get("provider"),
    }
    # Estimate fields are merged last, so they win on any key collision.
    event.update(info)
    session.config["context_info"] = event
    return event
|
||||
|
||||
def estimate_context(
|
||||
self,
|
||||
@@ -703,24 +752,23 @@ class AgentPool:
|
||||
provider: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
session = self.get_or_create(session_id, profile=profile, model=model, provider=provider)
|
||||
token_count = self._estimate_context_tokens(session.agent, messages or [], instructions)
|
||||
tools = getattr(session.agent, "tools", None) or []
|
||||
prompt = str(getattr(session.agent, "_cached_system_prompt", "") or "")
|
||||
context_info = self._estimate_context_info(session.agent, messages or [], instructions)
|
||||
print(
|
||||
"[hermes_bridge] context estimate "
|
||||
f"session={session_id} profile={profile or 'default'} "
|
||||
f"messages={len(messages or [])} system_prompt_chars={len(prompt)} "
|
||||
f"tools={len(tools) if isinstance(tools, list) else 0} "
|
||||
f"tokens={token_count if token_count is not None else 'unknown'}",
|
||||
f"messages={len(messages or [])} system_prompt_chars={context_info.get('system_prompt_chars') or 0} "
|
||||
f"tools={context_info.get('tool_count') or 0} "
|
||||
f"fixed_tokens={context_info.get('fixed_context_tokens') if context_info.get('fixed_context_tokens') is not None else 'unknown'} "
|
||||
f"tokens={context_info.get('token_count') if context_info.get('token_count') is not None else 'unknown'}",
|
||||
file=sys.stderr,
|
||||
flush=True,
|
||||
)
|
||||
return {
|
||||
"session_id": session_id,
|
||||
"token_count": token_count,
|
||||
"message_count": len(messages or []),
|
||||
"tool_count": len(tools) if isinstance(tools, list) else 0,
|
||||
"system_prompt_chars": len(prompt),
|
||||
"profile": profile or session.config.get("profile") or "default",
|
||||
"model": session.config.get("model"),
|
||||
"provider": session.config.get("provider"),
|
||||
**context_info,
|
||||
}
|
||||
|
||||
def respond_compression(
|
||||
@@ -1062,6 +1110,9 @@ class AgentPool:
|
||||
session.running = True
|
||||
session.current_run_id = run_id
|
||||
session.last_used_at = time.time()
|
||||
context_event = self._bridge_context_ready_event(session, instructions, profile)
|
||||
if context_event:
|
||||
record.events.append(_jsonable(context_event))
|
||||
|
||||
thread = threading.Thread(
|
||||
target=self._run_chat,
|
||||
|
||||
Reference in New Issue
Block a user