diff --git a/packages/client/src/stores/hermes/chat.ts b/packages/client/src/stores/hermes/chat.ts
index 846c5be..35fe2ea 100644
--- a/packages/client/src/stores/hermes/chat.ts
+++ b/packages/client/src/stores/hermes/chat.ts
@@ -569,14 +569,20 @@ export const useChatStore = defineStore('chat', () => {
             compressed: null,
           })
         } else if (e.event === 'compression.completed') {
+          // Show the full-context total when the server sends one, else the message-only total.
+          const afterTokens = e.contextTokens || e.afterTokens || 0
           setCompressionState({
             compressing: false,
             messageCount: e.totalMessages || 0,
             beforeTokens: e.beforeTokens || 0,
-            afterTokens: e.afterTokens || 0,
+            afterTokens,
             compressed: e.compressed ?? false,
             error: e.error,
           })
+          // The event can arrive when no session is active; skip the update instead of throwing.
+          if (e.contextTokens != null && activeSession.value) {
+            activeSession.value.contextTokens = e.contextTokens
+          }
         } else if (e.event === 'abort.started') {
           setAbortState({ aborting: true, synced: null })
         } else if (e.event === 'abort.completed') {
@@ -1073,14 +1079,20 @@ export const useChatStore = defineStore('chat', () => {
         }
 
         case 'compression.completed': {
+          // Show the full-context total when the server sends one, else the message-only total.
+          const afterTokens = (evt as any).contextTokens || (evt as any).afterTokens || 0
           setCompressionState({
             compressing: false,
             messageCount: (evt as any).totalMessages || 0,
             beforeTokens: (evt as any).beforeTokens || 0,
-            afterTokens: (evt as any).afterTokens || 0,
+            afterTokens,
             compressed: (evt as any).compressed ?? false,
             error: (evt as any).error,
           })
+          if ((evt as any).contextTokens != null) {
+            const target = sessions.value.find(s => s.id === sid)
+            if (target) target.contextTokens = (evt as any).contextTokens
+          }
           // Auto-clear after 5s
           setTimeout(() => {
             if (compressionState.value && !compressionState.value.compressing) {
@@ -1520,14 +1532,20 @@ export const useChatStore = defineStore('chat', () => {
         }
 
         case 'compression.completed': {
+          // Show the full-context total when the server sends one, else the message-only total.
+          const afterTokens = (evt as any).contextTokens || (evt as any).afterTokens || 0
           setCompressionState({
             compressing: false,
             messageCount: (evt as any).totalMessages || 0,
             beforeTokens: (evt as any).beforeTokens || 0,
-            afterTokens: (evt as any).afterTokens || 0,
+            afterTokens,
             compressed: (evt as any).compressed ?? false,
             error: (evt as any).error,
           })
+          if ((evt as any).contextTokens != null) {
+            const target = sessions.value.find(s => s.id === sid)
+            if (target) target.contextTokens = (evt as any).contextTokens
+          }
           setTimeout(() => {
             if (compressionState.value && !compressionState.value.compressing) {
               setCompressionState(null)
diff --git a/packages/server/src/services/hermes/agent-bridge/client.ts b/packages/server/src/services/hermes/agent-bridge/client.ts
index 6965682..70aabd7 100644
--- a/packages/server/src/services/hermes/agent-bridge/client.ts
+++ b/packages/server/src/services/hermes/agent-bridge/client.ts
@@ -93,9 +93,16 @@ export interface AgentBridgeRunResult extends AgentBridgeResponse {
 export interface AgentBridgeContextEstimate extends AgentBridgeResponse {
   session_id: string
   token_count?: number | null
+  fixed_context_tokens?: number | null
+  system_prompt_tokens?: number | null
+  tool_tokens?: number | null
   message_count: number
   tool_count: number
+  tool_names?: string[]
   system_prompt_chars: number
+  profile?: string
+  model?: string | null
+  provider?: string | null
 }
 
 export interface AgentBridgeCommandResult extends AgentBridgeResponse {
diff --git a/packages/server/src/services/hermes/agent-bridge/hermes_bridge.py
b/packages/server/src/services/hermes/agent-bridge/hermes_bridge.py
index ba7bd75..63a38c6 100755
--- a/packages/server/src/services/hermes/agent-bridge/hermes_bridge.py
+++ b/packages/server/src/services/hermes/agent-bridge/hermes_bridge.py
@@ -668,30 +668,111 @@ class AgentPool:
 
         agent._compress_context = wrapped_compress_context
 
-    def _estimate_context_tokens(self, agent: Any, messages: Any, system_message: Any = None) -> int | None:
+    def _agent_system_prompt(self, agent: Any, system_message: Any = None) -> str:
+        """Return the system prompt text used for context estimation.
+
+        Uses the agent's cached prompt when set, otherwise the agent's own prompt
+        builder; ``system_message`` is used only when no builder exists or it raises.
+        """
+        prompt = str(getattr(agent, "_cached_system_prompt", "") or "")
+        if prompt:
+            return prompt
+        try:
+            build_prompt = getattr(agent, "_build_system_prompt", None)
+            if callable(build_prompt):
+                return str(build_prompt(system_message) or "")
+        except Exception:
+            return str(system_message or "")
+        return str(system_message or "")
+
+    def _agent_tool_names(self, tools: Any) -> list[str]:
+        """Collect the non-empty names of ``tools``.
+
+        Dict entries are read from ``function.name`` and then ``name``; any other
+        entry is read from its ``name`` attribute. A non-list input yields ``[]``.
+        """
+        if not isinstance(tools, list):
+            return []
+        names: list[str] = []
+        for tool in tools:
+            name = ""
+            if isinstance(tool, dict):
+                function = tool.get("function")
+                if isinstance(function, dict):
+                    name = str(function.get("name") or "")
+                if not name:
+                    name = str(tool.get("name") or "")
+            else:
+                name = str(getattr(tool, "name", "") or "")
+            if name:
+                names.append(name)
+        return names
+
+    def _estimate_context_info(self, agent: Any, messages: Any, system_message: Any = None) -> dict[str, Any]:
+        """Describe the context an agent would send for ``messages``.
+
+        The structural fields (``message_count``, ``tool_count``, ``tool_names``,
+        ``system_prompt_chars``) are always present. The token fields stay ``None``
+        when the rough estimator cannot be imported or raises.
+        """
+        prompt = self._agent_system_prompt(agent, system_message)
+        tools = getattr(agent, "tools", None) or []
+        message_list = messages if isinstance(messages, list) else []
+        info: dict[str, Any] = {
+            "token_count": None,
+            "fixed_context_tokens": None,
+            "system_prompt_tokens": None,
+            "tool_tokens": None,
+            "message_count": len(message_list),
+            "tool_count": len(tools) if isinstance(tools, list) else 0,
+            "tool_names": self._agent_tool_names(tools),
+            "system_prompt_chars": len(prompt),
+        }
         try:
             from agent.model_metadata import estimate_request_tokens_rough
         except Exception:
-            return None
-
-        prompt = str(getattr(agent, "_cached_system_prompt", "") or "")
-        if not prompt:
-            try:
-                build_prompt = getattr(agent, "_build_system_prompt", None)
-                if callable(build_prompt):
-                    prompt = str(build_prompt(system_message) or "")
-            except Exception:
-                prompt = str(system_message or "")
+            return info
 
         try:
-            estimate = estimate_request_tokens_rough(
-                messages if isinstance(messages, list) else [],
-                system_prompt=prompt,
-                tools=getattr(agent, "tools", None) or None,
-            )
-            return int(estimate) if isinstance(estimate, (int, float)) and estimate > 0 else None
+            token_count = estimate_request_tokens_rough(message_list, system_prompt=prompt, tools=tools or None)
+            fixed_context_tokens = estimate_request_tokens_rough([], system_prompt=prompt, tools=tools or None)
+            system_prompt_tokens = estimate_request_tokens_rough([], system_prompt=prompt, tools=None)
+            tool_tokens = max(0, int(fixed_context_tokens or 0) - int(system_prompt_tokens or 0))
+            info.update(
+                {
+                    "token_count": int(token_count) if isinstance(token_count, (int, float)) and token_count > 0 else None,
+                    "fixed_context_tokens": int(fixed_context_tokens) if isinstance(fixed_context_tokens, (int, float)) and fixed_context_tokens >= 0 else None,
+                    "system_prompt_tokens": int(system_prompt_tokens) if isinstance(system_prompt_tokens, (int, float)) and system_prompt_tokens >= 0 else None,
+                    "tool_tokens": tool_tokens,
+                }
+            )
         except Exception:
-            return None
+            # Best-effort estimate: keep the structural fields and leave the token fields as None.
+            pass
+        return info
+
+    def _estimate_context_tokens(self, agent: Any, messages: Any, system_message: Any = None) -> int | None:
+        """Return only the positive full-context token estimate, or ``None`` when unknown."""
+        token_count = self._estimate_context_info(agent, messages, system_message).get("token_count")
+        return int(token_count) if isinstance(token_count, (int, float)) and token_count > 0 else None
+
+    def _bridge_context_ready_event(self, session: AgentSession, instructions: str | None, profile: str | None) -> dict[str, Any]:
+        """Build the ``bridge.context.ready`` event for ``session`` and cache it in ``session.config``.
+
+        The estimate is taken with an empty message list, so its token fields
+        describe only the system prompt and tool definitions.
+        """
+        info = self._estimate_context_info(session.agent, [], instructions)
+        event = {
+            "event": "bridge.context.ready",
+            "session_id": session.session_id,
+            "profile": profile or session.config.get("profile") or "default",
+            "model": session.config.get("model"),
+            "provider": session.config.get("provider"),
+            **info,
+        }
+        session.config["context_info"] = event
+        return event
 
     def estimate_context(
         self,
@@ -703,24 +784,23 @@ class AgentPool:
         provider: str | None = None,
     ) -> dict[str, Any]:
         session = self.get_or_create(session_id, profile=profile, model=model, provider=provider)
-        token_count = self._estimate_context_tokens(session.agent, messages or [], instructions)
-        tools = getattr(session.agent, "tools", None) or []
-        prompt = str(getattr(session.agent, "_cached_system_prompt", "") or "")
+        context_info = self._estimate_context_info(session.agent, messages or [], instructions)
         print(
             "[hermes_bridge] context estimate "
             f"session={session_id} profile={profile or 'default'} "
-            f"messages={len(messages or [])} system_prompt_chars={len(prompt)} "
-            f"tools={len(tools) if isinstance(tools, list) else 0} "
-            f"tokens={token_count if token_count is not None else 'unknown'}",
+            f"messages={len(messages or [])} system_prompt_chars={context_info.get('system_prompt_chars') or 0} "
+            f"tools={context_info.get('tool_count') or 0} "
+            f"fixed_tokens={context_info.get('fixed_context_tokens') if context_info.get('fixed_context_tokens') is not None else 'unknown'} "
+            f"tokens={context_info.get('token_count') if context_info.get('token_count') is not None else 'unknown'}",
             file=sys.stderr,
             flush=True,
         )
         return {
             "session_id": session_id,
-            "token_count": token_count,
-            "message_count": len(messages or []),
-            "tool_count": len(tools) if isinstance(tools, list) else 0,
-            "system_prompt_chars": len(prompt),
+            "profile": profile or session.config.get("profile") or "default",
+            "model": session.config.get("model"),
+            "provider": session.config.get("provider"),
+            **context_info,
         }
 
     def respond_compression(
@@ -1062,6 +1142,9 @@ class AgentPool:
             session.running = True
             session.current_run_id = run_id
             session.last_used_at = time.time()
+            context_event = self._bridge_context_ready_event(session, instructions,
profile) + if context_event: + record.events.append(_jsonable(context_event)) thread = threading.Thread( target=self._run_chat, diff --git a/packages/server/src/services/hermes/run-chat/compression.ts b/packages/server/src/services/hermes/run-chat/compression.ts index 9e8b8a0..ec7d262 100644 --- a/packages/server/src/services/hermes/run-chat/compression.ts +++ b/packages/server/src/services/hermes/run-chat/compression.ts @@ -13,7 +13,7 @@ import { getModelContextLength } from '../model-context' import { readConfigYamlForProfile } from '../../config-helpers' import { logger } from '../../logger' import { bridgeLogger } from '../../logger' -import { calcAndUpdateUsage, estimateUsageTokensFromMessages } from './usage' +import { calcAndUpdateUsage, estimateUsageTokensFromMessages, updateMessageContextTokenUsage } from './usage' import { isAssistantMessageSendable } from './message-format' import type { ChatMessage, CompressionConfig as CompressorConfig } from '../../../lib/context-compressor' import type { SessionState, BridgeCompressionResult } from './types' @@ -69,6 +69,23 @@ function buildSnapshotHistory( ] } +export async function buildSnapshotAwareHistory( + sessionId: string, + profile: string, + history: ChatMessage[], + modelContext: { model?: string | null; provider?: string | null } = {}, +): Promise { + const snapshot = getCompressionSnapshot(sessionId) + if (!snapshot) return history + const contextLength = getModelContextLength({ + profile, + model: modelContext.model, + provider: modelContext.provider, + }) + const compressionConfig = await getRunChatCompressionConfig(profile, contextLength) + return buildSnapshotHistory(snapshot, history, compressionConfig.compressor) || history +} + function clampRatio(value: unknown, fallback: number, min: number, max: number): number { const n = typeof value === 'number' && Number.isFinite(value) ? 
value : fallback return Math.min(max, Math.max(min, n)) @@ -343,21 +360,26 @@ export async function compressHistory( provider: modelContext.provider || session?.provider, }) const afterTokens = await calcAndUpdateUsage(sessionId, cState, emit) - const compressedMeta = { + const compressedAfterTokens = afterTokens.inputTokens + afterTokens.outputTokens + const compressedMeta: any = { event: 'compression.completed' as const, compressed: result.meta.compressed, llmCompressed: result.meta.llmCompressed, totalMessages: result.meta.totalMessages, resultMessages: result.messages.length, beforeTokens: totalTokens, - afterTokens: afterTokens.inputTokens + afterTokens.outputTokens, + afterTokens: compressedAfterTokens, summaryTokens: result.meta.summaryTokenEstimate, verbatimCount: result.meta.verbatimCount, compressedStartIndex: result.meta.compressedStartIndex, } replaceState(sessionMap, sessionId, 'compression.completed', compressedMeta) logger.info('[context-compress] AFTER session=%s: %d messages, ~%d tokens (was %d)', - sessionId, result.messages.length, afterTokens.inputTokens + afterTokens.outputTokens, totalTokens) + sessionId, result.messages.length, compressedAfterTokens, totalTokens) + const compressedContextTokens = updateMessageContextTokenUsage(sessionId, cState, emit, compressedAfterTokens, afterTokens) + if (compressedContextTokens != null) { + compressedMeta.contextTokens = compressedContextTokens + } emit('compression.completed', compressedMeta) const compressed = result.messages.map(m => { diff --git a/packages/server/src/services/hermes/run-chat/handle-bridge-run.ts b/packages/server/src/services/hermes/run-chat/handle-bridge-run.ts index 94c9a5f..72db4ff 100644 --- a/packages/server/src/services/hermes/run-chat/handle-bridge-run.ts +++ b/packages/server/src/services/hermes/run-chat/handle-bridge-run.ts @@ -8,10 +8,17 @@ import { getSystemPrompt } from '../../../lib/llm-prompt' import { getSession, createSession, addMessage, updateSession, 
updateSessionStats } from '../../../db/hermes/session-store' import { updateUsage } from '../../../db/hermes/usage-store' import { logger, bridgeLogger } from '../../logger' -import { AgentBridgeClient, type AgentBridgeMessage, type AgentBridgeOutput } from '../agent-bridge' +import { AgentBridgeClient, type AgentBridgeContextEstimate, type AgentBridgeMessage, type AgentBridgeOutput } from '../agent-bridge' import { contentBlocksToString, convertContentBlocksForAgent, extractTextForPreview, isContentBlockArray } from './content-blocks' -import { buildCompressedHistory, buildDbHistory, forceCompressBridgeHistory, pushState, replaceState } from './compression' -import { calcAndUpdateUsage, estimateUsageTokensFromMessages } from './usage' +import { buildCompressedHistory, buildDbHistory, buildSnapshotAwareHistory, forceCompressBridgeHistory, pushState, replaceState } from './compression' +import { + calcAndUpdateUsage, + contextTokensWithCachedOverhead, + estimateUsageTokensFromMessages, + getCachedBridgeContextOverhead, + updateContextTokenUsage, + updateMessageContextTokenUsage, +} from './usage' import { flushBridgePendingToDb, ensureOpenBridgeAssistantMessage, @@ -62,6 +69,28 @@ export function bridgeTerminalError(chunk: Pick= 0 + ? Math.floor(value) + : undefined +} + +function cacheBridgeContext(state: SessionState, data: Record | AgentBridgeContextEstimate) { + const fixedContextTokens = finiteToken(data.fixed_context_tokens) + if (fixedContextTokens == null) return + state.bridgeContext = { + fixedContextTokens, + systemPromptTokens: finiteToken(data.system_prompt_tokens), + toolTokens: finiteToken(data.tool_tokens), + systemPromptChars: finiteToken(data.system_prompt_chars), + toolCount: finiteToken(data.tool_count), + toolNames: Array.isArray(data.tool_names) ? data.tool_names.map(String) : undefined, + profile: typeof data.profile === 'string' ? data.profile : state.bridgeContext?.profile, + model: typeof data.model === 'string' ? 
data.model : state.bridgeContext?.model, + provider: typeof data.provider === 'string' ? data.provider : state.bridgeContext?.provider, + } +} + export async function handleBridgeRun( nsp: ReturnType, socket: Socket, @@ -168,6 +197,11 @@ export async function handleBridgeRun( sessionMap, { model: resolvedModel, provider: resolvedProvider }, async (messages) => { + const cachedOverhead = getCachedBridgeContextOverhead(state) + if (cachedOverhead != null) { + const messageUsage = estimateUsageTokensFromMessages(messages) + return cachedOverhead + messageUsage.inputTokens + messageUsage.outputTokens + } const estimate = await bridge.contextEstimate( session_id, messages, @@ -175,6 +209,7 @@ export async function handleBridgeRun( profile, { model: resolvedModel, provider: resolvedProvider }, ) + cacheBridgeContext(state, estimate) bridgeLogger.info({ sessionId: session_id, profile, @@ -183,6 +218,7 @@ export async function handleBridgeRun( messages: estimate.message_count, toolCount: estimate.tool_count, systemPromptChars: estimate.system_prompt_chars, + fixedContextTokens: estimate.fixed_context_tokens, fullContextTokens: estimate.token_count, }, '[chat-run-socket] full context estimate') return estimate.token_count @@ -308,7 +344,35 @@ async function refreshFinalContextUsage(args: { bridge: AgentBridgeClient }): Promise { try { - const finalHistory = await buildDbHistory(args.sessionId, { excludeLastUser: false }) + const dbHistory = await buildDbHistory(args.sessionId, { excludeLastUser: false }) + const finalHistory = await buildSnapshotAwareHistory( + args.sessionId, + args.profile, + dbHistory, + { model: args.model, provider: args.provider }, + ) + const finalMessageUsage = estimateUsageTokensFromMessages(finalHistory) + const finalMessageTokens = finalMessageUsage.inputTokens + finalMessageUsage.outputTokens + if (getCachedBridgeContextOverhead(args.state) != null) { + const contextTokens = updateMessageContextTokenUsage( + args.sessionId, + args.state, + 
args.emit, + finalMessageTokens, + args.usage, + ) + bridgeLogger.info({ + sessionId: args.sessionId, + profile: args.profile, + model: args.model, + provider: args.provider, + messages: finalHistory.length, + fixedContextTokens: args.state.bridgeContext?.fixedContextTokens, + messageTokens: finalMessageTokens, + fullContextTokens: contextTokens, + }, '[chat-run-socket] final cached context estimate') + return contextTokens + } const estimate = await args.bridge.contextEstimate( args.sessionId, finalHistory, @@ -316,18 +380,13 @@ async function refreshFinalContextUsage(args: { args.profile, { model: args.model ?? undefined, provider: args.provider ?? undefined }, ) + cacheBridgeContext(args.state, estimate) const contextTokens = typeof estimate.token_count === 'number' && Number.isFinite(estimate.token_count) && estimate.token_count > 0 ? Math.floor(estimate.token_count) : undefined if (contextTokens == null) return args.state.contextTokens - args.state.contextTokens = contextTokens - args.emit('usage.updated', { - event: 'usage.updated', - inputTokens: args.usage.inputTokens, - outputTokens: args.usage.outputTokens, - contextTokens, - }) + updateContextTokenUsage(args.sessionId, args.state, args.emit, contextTokens, args.usage) bridgeLogger.info({ sessionId: args.sessionId, profile: args.profile, @@ -378,7 +437,17 @@ async function applyBridgeChunkAsync( for (const ev of chunk.events || []) { const evType = ev.event as string | undefined - if (evType === 'tool.started') { + if (evType === 'bridge.context.ready') { + cacheBridgeContext(state, ev) + const usage = await calcAndUpdateUsage(sessionId, state, emit) + updateMessageContextTokenUsage( + sessionId, + state, + emit, + usage.inputTokens + usage.outputTokens, + usage, + ) + } else if (evType === 'tool.started') { flushBridgePendingToDb(state, sessionId, runMarker) const toolName = (ev.tool_name as string) || '' const args = ev.args as Record | undefined @@ -498,6 +567,11 @@ async function 
applyBridgeChunkAsync( const compressionResult = ev.request_id ? state.bridgeCompressionResults?.[String(ev.request_id)] : undefined + const bridgeAfterContextTokens = finiteToken(ev.result_approx_tokens) + const messageAfterTokens = finiteToken(compressionResult?.afterTokens) + const afterContextTokens = messageAfterTokens != null && getCachedBridgeContextOverhead(state) != null + ? contextTokensWithCachedOverhead(state, messageAfterTokens) + : bridgeAfterContextTokens ?? messageAfterTokens const payload = { event: 'compression.completed', run_id: chunk.run_id, @@ -507,9 +581,8 @@ async function applyBridgeChunkAsync( totalMessages: compressionResult?.beforeMessages ?? ev.message_count, resultMessages: compressionResult?.resultMessages ?? ev.result_messages, beforeTokens: compressionResult?.beforeTokens ?? ev.approx_tokens, - afterTokens: typeof ev.result_approx_tokens === 'number' && Number.isFinite(ev.result_approx_tokens) && ev.result_approx_tokens > 0 - ? ev.result_approx_tokens - : compressionResult?.afterTokens, + afterTokens: messageAfterTokens ?? 
bridgeAfterContextTokens, + contextTokens: afterContextTokens, summaryTokens: compressionResult?.summaryTokens, verbatimCount: compressionResult?.verbatimCount, compressedStartIndex: compressionResult?.compressedStartIndex, @@ -520,7 +593,12 @@ async function applyBridgeChunkAsync( } replaceState(sessionMap, sessionId, 'compression.completed', payload) emit('compression.completed', payload) - await calcAndUpdateUsage(sessionId, state, emit) + const usage = await calcAndUpdateUsage(sessionId, state, emit) + if (messageAfterTokens != null && getCachedBridgeContextOverhead(state) != null) { + updateMessageContextTokenUsage(sessionId, state, emit, messageAfterTokens, usage) + } else { + updateContextTokenUsage(sessionId, state, emit, afterContextTokens, usage) + } } else if (evType === 'bridge.compression.failed') { const payload = { event: 'compression.completed', diff --git a/packages/server/src/services/hermes/run-chat/session-command.ts b/packages/server/src/services/hermes/run-chat/session-command.ts index ed58663..794fa61 100644 --- a/packages/server/src/services/hermes/run-chat/session-command.ts +++ b/packages/server/src/services/hermes/run-chat/session-command.ts @@ -5,7 +5,7 @@ import type { AgentBridgeClient } from '../agent-bridge' import { flushBridgePendingToDb } from './bridge-message' import { buildDbHistory, estimateSnapshotAwareHistoryUsage, forceCompressBridgeHistory, getOrCreateSession, replaceState } from './compression' import { handleAbort } from './abort' -import { calcAndUpdateUsage } from './usage' +import { calcAndUpdateUsage, contextTokensWithCachedOverhead, updateMessageContextTokenUsage } from './usage' import type { ContentBlock, QueuedRun, SessionState } from './types' type CommandName = @@ -232,10 +232,11 @@ export async function handleSessionCommand( try { const history = await buildDbHistory(sessionId, { excludeLastUser: true }) const usageEstimate = estimateSnapshotAwareHistoryUsage(sessionId, history) + const beforeContextTokens = 
contextTokensWithCachedOverhead(state, usageEstimate.tokenCount) emit('compression.started', { event: 'compression.started', message_count: usageEstimate.messageCount, - token_count: usageEstimate.tokenCount, + token_count: beforeContextTokens, source: 'command', }) const result = await forceCompressBridgeHistory( @@ -244,27 +245,32 @@ export async function handleSessionCommand( [], ) state.bridgeCompressionResults = state.bridgeCompressionResults || {} - await calcAndUpdateUsage(sessionId, state, emit) + const usage = await calcAndUpdateUsage(sessionId, state, emit) + const afterContextTokens = contextTokensWithCachedOverhead(state, result.afterTokens) emit('compression.completed', { event: 'compression.completed', compressed: result.compressed, llmCompressed: result.llmCompressed, totalMessages: result.beforeMessages, resultMessages: result.resultMessages, - beforeTokens: result.beforeTokens, + beforeTokens: beforeContextTokens, afterTokens: result.afterTokens, summaryTokens: result.summaryTokens, verbatimCount: result.verbatimCount, compressedStartIndex: result.compressedStartIndex, + contextTokens: afterContextTokens, source: 'command', }) + updateMessageContextTokenUsage(sessionId, state, emit, result.afterTokens, usage) emitCommand({ action: 'compress', - message: `Compression completed: ${result.beforeMessages} -> ${result.resultMessages} messages, ${result.beforeTokens} -> ${result.afterTokens} tokens.`, + message: `Compression completed: ${result.beforeMessages} -> ${result.resultMessages} messages, ${beforeContextTokens} -> ${afterContextTokens} tokens.`, beforeMessages: result.beforeMessages, resultMessages: result.resultMessages, - beforeTokens: result.beforeTokens, - afterTokens: result.afterTokens, + beforeTokens: beforeContextTokens, + afterTokens: afterContextTokens, + messageBeforeTokens: result.beforeTokens, + messageAfterTokens: result.afterTokens, compressed: result.compressed, }) } catch (err) { diff --git 
a/packages/server/src/services/hermes/run-chat/types.ts b/packages/server/src/services/hermes/run-chat/types.ts
index a843b22..2a87f96 100644
--- a/packages/server/src/services/hermes/run-chat/types.ts
+++ b/packages/server/src/services/hermes/run-chat/types.ts
@@ -47,6 +47,7 @@ export interface SessionState {
   inputTokens?: number
   outputTokens?: number
   contextTokens?: number
+  bridgeContext?: BridgeContextState
   isAborting?: boolean
   queue: QueuedRun[]
   responseRun?: ResponseRunState
@@ -72,6 +73,19 @@ export interface ResponseRunState {
   toolCalls: Map
 }
 
+/** Message-independent context figures (system prompt and tools) cached from the agent bridge. */
+export interface BridgeContextState {
+  fixedContextTokens?: number
+  systemPromptTokens?: number
+  toolTokens?: number
+  systemPromptChars?: number
+  toolCount?: number
+  toolNames?: string[]
+  profile?: string
+  model?: string
+  provider?: string
+}
+
 export type ChatRunSource = 'api_server' | 'cli'
 
 export interface BridgeCompressionResult {
diff --git a/packages/server/src/services/hermes/run-chat/usage.ts b/packages/server/src/services/hermes/run-chat/usage.ts
index cb152ce..65222d8 100644
--- a/packages/server/src/services/hermes/run-chat/usage.ts
+++ b/packages/server/src/services/hermes/run-chat/usage.ts
@@ -78,3 +78,95 @@ export async function calcAndUpdateUsage(
     return { inputTokens: 0, outputTokens: 0 }
   }
 }
+
+/**
+ * Store a full-context token count on the session state and emit `usage.updated`.
+ *
+ * A value that is not a finite, non-negative number is ignored and the
+ * previously stored `state.contextTokens` is returned unchanged.
+ *
+ * @param sid - Session id echoed in the emitted payload.
+ * @param state - Session state whose `contextTokens` is overwritten.
+ * @param emit - Receives one `usage.updated` event when the value is accepted.
+ * @param contextTokens - Full-context token count; floored before use.
+ * @param usage - Input/output counts for the payload; falls back to the state's counts, then 0.
+ * @returns The stored count, or the previous `state.contextTokens` when the value was ignored.
+ */
+export function updateContextTokenUsage(
+  sid: string,
+  state: SessionState,
+  emit: (event: string, payload: any) => void,
+  contextTokens: number | null | undefined,
+  usage?: { inputTokens: number; outputTokens: number },
+): number | undefined {
+  if (typeof contextTokens !== 'number' || !Number.isFinite(contextTokens) || contextTokens < 0) {
+    return state.contextTokens
+  }
+  const normalizedContextTokens = Math.floor(contextTokens)
+  state.contextTokens = normalizedContextTokens
+  emit('usage.updated', {
+    event: 'usage.updated',
+    session_id: sid,
+    inputTokens: usage?.inputTokens ?? state.inputTokens ?? 0,
+    outputTokens: usage?.outputTokens ?? state.outputTokens ?? 0,
+    contextTokens: normalizedContextTokens,
+  })
+  return normalizedContextTokens
+}
+
+/**
+ * Read the fixed context overhead cached in `state.bridgeContext`.
+ *
+ * @returns The floored `fixedContextTokens`, or `undefined` when it is missing, non-finite or negative.
+ */
+export function getCachedBridgeContextOverhead(state: SessionState): number | undefined {
+  const fixedContextTokens = state.bridgeContext?.fixedContextTokens
+  if (typeof fixedContextTokens !== 'number' || !Number.isFinite(fixedContextTokens) || fixedContextTokens < 0) {
+    return undefined
+  }
+  return Math.floor(fixedContextTokens)
+}
+
+/**
+ * Add the cached fixed overhead (if any) to a message-only token count.
+ *
+ * `messageTokens` is floored and clamped to 0; a non-finite value (NaN,
+ * Infinity, or an `undefined` that slipped past the types) counts as 0 so it
+ * cannot turn the total into NaN.
+ *
+ * @returns The overhead plus the normalized message tokens, or just the latter without a cached overhead.
+ */
+export function contextTokensWithCachedOverhead(state: SessionState, messageTokens: number): number {
+  const normalizedMessageTokens = Number.isFinite(messageTokens)
+    ? Math.max(0, Math.floor(messageTokens))
+    : 0
+  const fixedContextTokens = getCachedBridgeContextOverhead(state)
+  return fixedContextTokens == null
+    ? normalizedMessageTokens
+    : fixedContextTokens + normalizedMessageTokens
+}
+
+/**
+ * Like {@link updateContextTokenUsage}, but takes a message-only token count
+ * and adds the cached fixed overhead before storing and emitting it.
+ *
+ * @returns The stored full-context count, or the previous `state.contextTokens` when `messageTokens` is invalid.
+ */
+export function updateMessageContextTokenUsage(
+  sid: string,
+  state: SessionState,
+  emit: (event: string, payload: any) => void,
+  messageTokens: number | null | undefined,
+  usage?: { inputTokens: number; outputTokens: number },
+): number | undefined {
+  if (typeof messageTokens !== 'number' || !Number.isFinite(messageTokens) || messageTokens < 0) {
+    return state.contextTokens
+  }
+  return updateContextTokenUsage(
+    sid,
+    state,
+    emit,
+    contextTokensWithCachedOverhead(state, messageTokens),
+    usage,
+  )
+}
diff --git a/tests/server/run-chat-bridge-final-context.test.ts b/tests/server/run-chat-bridge-final-context.test.ts
index 5137796..c9ddd1c 100644
--- a/tests/server/run-chat-bridge-final-context.test.ts
+++ b/tests/server/run-chat-bridge-final-context.test.ts
@@ -9,11 +9,26 @@ const updateSessionStatsMock = vi.fn()
 const updateUsageMock = vi.fn()
 const buildCompressedHistoryMock = vi.fn()
 const buildDbHistoryMock = vi.fn()
+const buildSnapshotAwareHistoryMock = vi.fn(async (_sessionId: string, _profile: string, history: any[]) => history)
 const pushStateMock = vi.fn()
 const replaceStateMock = vi.fn()
 const forceCompressBridgeHistoryMock = vi.fn()
 const calcAndUpdateUsageMock
= vi.fn() const estimateUsageTokensFromMessagesMock = vi.fn() +const updateContextTokenUsageMock = vi.fn((sid: string, state: any, emit: any, contextTokens: number, usage?: { inputTokens: number; outputTokens: number }) => { + state.contextTokens = contextTokens + emit('usage.updated', { + event: 'usage.updated', + session_id: sid, + inputTokens: usage?.inputTokens ?? state.inputTokens ?? 0, + outputTokens: usage?.outputTokens ?? state.outputTokens ?? 0, + contextTokens, + }) + return contextTokens +}) +const getCachedBridgeContextOverheadMock = vi.fn(() => undefined) +const contextTokensWithCachedOverheadMock = vi.fn((_state: any, messageTokens: number) => messageTokens) +const updateMessageContextTokenUsageMock = vi.fn((sid: string, state: any, emit: any, messageTokens: number, usage?: { inputTokens: number; outputTokens: number }) => updateContextTokenUsageMock(sid, state, emit, messageTokens, usage)) const flushBridgePendingToDbMock = vi.fn() const ensureOpenBridgeAssistantMessageMock = vi.fn() const syncBridgeReasoningToMessageMock = vi.fn() @@ -45,6 +60,7 @@ vi.mock('../../packages/server/src/services/logger', () => ({ vi.mock('../../packages/server/src/services/hermes/run-chat/compression', () => ({ buildCompressedHistory: buildCompressedHistoryMock, buildDbHistory: buildDbHistoryMock, + buildSnapshotAwareHistory: buildSnapshotAwareHistoryMock, pushState: pushStateMock, replaceState: replaceStateMock, forceCompressBridgeHistory: forceCompressBridgeHistoryMock, @@ -53,6 +69,10 @@ vi.mock('../../packages/server/src/services/hermes/run-chat/compression', () => vi.mock('../../packages/server/src/services/hermes/run-chat/usage', () => ({ calcAndUpdateUsage: calcAndUpdateUsageMock, estimateUsageTokensFromMessages: estimateUsageTokensFromMessagesMock, + getCachedBridgeContextOverhead: getCachedBridgeContextOverheadMock, + contextTokensWithCachedOverhead: contextTokensWithCachedOverheadMock, + updateContextTokenUsage: updateContextTokenUsageMock, + 
updateMessageContextTokenUsage: updateMessageContextTokenUsageMock, })) vi.mock('../../packages/server/src/services/hermes/run-chat/bridge-message', () => ({ @@ -103,7 +123,11 @@ describe('bridge run final context usage', () => { { role: 'user', content: 'hello' }, { role: 'assistant', content: 'done' }, ]) + buildSnapshotAwareHistoryMock.mockImplementation(async (_sessionId: string, _profile: string, history: any[]) => history) calcAndUpdateUsageMock.mockResolvedValue({ inputTokens: 11, outputTokens: 7 }) + estimateUsageTokensFromMessagesMock.mockReturnValue({ inputTokens: 11, outputTokens: 7 }) + getCachedBridgeContextOverheadMock.mockReturnValue(undefined) + contextTokensWithCachedOverheadMock.mockImplementation((_state: any, messageTokens: number) => messageTokens) }) it('refreshes full context tokens when a bridge run completes', async () => { @@ -161,6 +185,50 @@ describe('bridge run final context usage', () => { })) }) + it('uses cached fixed context instead of bridge estimate when available', async () => { + const emit = vi.fn() + const nsp = makeNamespace(emit) + const socket = makeSocket() + const state = makeState() + state.bridgeContext = { fixedContextTokens: 20_000 } + const sessionMap = new Map([['session-1', state]]) + getCachedBridgeContextOverheadMock.mockReturnValue(20_000) + updateMessageContextTokenUsageMock.mockImplementation((sid: string, targetState: any, targetEmit: any, messageTokens: number, usage?: { inputTokens: number; outputTokens: number }) => updateContextTokenUsageMock(sid, targetState, targetEmit, 20_000 + messageTokens, usage)) + const bridge = { + chat: vi.fn().mockResolvedValue({ run_id: 'run-1', status: 'started' }), + contextEstimate: vi.fn(), + streamOutput: vi.fn(async function* () { + yield { run_id: 'run-1', done: true, status: 'completed', output: 'done' } + }), + } as any + + const { handleBridgeRun } = await import('../../packages/server/src/services/hermes/run-chat/handle-bridge-run') + await handleBridgeRun( + nsp, + 
socket, + { input: 'hello', session_id: 'session-1' }, + 'default', + sessionMap, + bridge, + false, + vi.fn(), + vi.fn(), + ) + + expect(bridge.contextEstimate).not.toHaveBeenCalled() + expect(updateMessageContextTokenUsageMock).toHaveBeenCalledWith( + 'session-1', + state, + expect.any(Function), + 18, + { inputTokens: 11, outputTokens: 7 }, + ) + expect(state.contextTokens).toBe(20_018) + expect(emit).toHaveBeenCalledWith('run.completed', expect.objectContaining({ + contextTokens: 20_018, + })) + }) + it('refreshes full context tokens when a bridge run fails', async () => { const emit = vi.fn() const nsp = makeNamespace(emit) diff --git a/tests/server/run-chat-compression.test.ts b/tests/server/run-chat-compression.test.ts index a702945..256c73f 100644 --- a/tests/server/run-chat-compression.test.ts +++ b/tests/server/run-chat-compression.test.ts @@ -6,6 +6,17 @@ const getCompressionSnapshotMock = vi.fn() const getModelContextLengthMock = vi.fn() const calcAndUpdateUsageMock = vi.fn() const estimateUsageTokensFromMessagesMock = vi.fn() +const updateMessageContextTokenUsageMock = vi.fn((sid: string, state: any, emit: any, messageTokens: number, usage?: { inputTokens: number; outputTokens: number }) => { + state.contextTokens = messageTokens + emit('usage.updated', { + event: 'usage.updated', + session_id: sid, + inputTokens: usage?.inputTokens ?? state.inputTokens ?? 0, + outputTokens: usage?.outputTokens ?? state.outputTokens ?? 
0, + contextTokens: messageTokens, + }) + return messageTokens +}) const compressorCompressMock = vi.fn() const readConfigYamlForProfileMock = vi.fn() const compressorConstructorMock = vi.fn() @@ -55,6 +66,7 @@ vi.mock('../../packages/server/src/services/logger', () => ({ vi.mock('../../packages/server/src/services/hermes/run-chat/usage', () => ({ calcAndUpdateUsage: calcAndUpdateUsageMock, estimateUsageTokensFromMessages: estimateUsageTokensFromMessagesMock, + updateMessageContextTokenUsage: updateMessageContextTokenUsageMock, })) vi.mock('../../packages/server/src/services/hermes/run-chat/message-format', () => ({ @@ -69,6 +81,7 @@ describe('run chat compression trigger', () => { getModelContextLengthMock.mockReset() calcAndUpdateUsageMock.mockReset() estimateUsageTokensFromMessagesMock.mockReset() + updateMessageContextTokenUsageMock.mockClear() compressorCompressMock.mockReset() compressorConstructorMock.mockReset() readConfigYamlForProfileMock.mockReset() @@ -189,13 +202,14 @@ describe('run chat compression trigger', () => { }, }) + const emit = vi.fn() const { buildCompressedHistory } = await import('../../packages/server/src/services/hermes/run-chat/compression') const history = await buildCompressedHistory( 'session-1', 'default', 'http://upstream', undefined, - vi.fn(), + emit, new Map(), {}, vi.fn(async () => 120_000), @@ -203,6 +217,13 @@ describe('run chat compression trigger', () => { expect(history).toEqual([{ role: 'user', content: 'compressed by full context estimate' }]) expect(compressorCompressMock).toHaveBeenCalledTimes(1) + expect(updateMessageContextTokenUsageMock).toHaveBeenCalledWith( + 'session-1', + expect.any(Object), + emit, + 1_000, + { inputTokens: 1_000, outputTokens: 0 }, + ) }) it('emits full context token usage when the full estimate is under threshold', async () => { diff --git a/tests/server/run-chat-usage.test.ts b/tests/server/run-chat-usage.test.ts index b508dcf..fd91540 100644 --- a/tests/server/run-chat-usage.test.ts +++ 
b/tests/server/run-chat-usage.test.ts @@ -1,6 +1,10 @@ -import { describe, expect, it } from 'vitest' +import { describe, expect, it, vi } from 'vitest' import { countTokens } from '../../packages/server/src/lib/context-compressor' -import { estimateUsageTokensFromMessages } from '../../packages/server/src/services/hermes/run-chat/usage' +import { + contextTokensWithCachedOverhead, + estimateUsageTokensFromMessages, + updateMessageContextTokenUsage, +} from '../../packages/server/src/services/hermes/run-chat/usage' describe('run-chat usage token estimates', () => { it('counts message content instead of serialized message payloads', () => { @@ -30,4 +34,58 @@ describe('run-chat usage token estimates', () => { expect(usage.inputTokens).toBe(0) expect(usage.outputTokens).toBe(countTokens('calling tool') + countTokens(String(messages[0].tool_calls || ''))) }) + + it('adds cached bridge fixed context when updating full context usage', () => { + const emit = vi.fn() + const state = { + messages: [], + isWorking: false, + events: [], + queue: [], + bridgeContext: { fixedContextTokens: 20_000 }, + } as any + + const contextTokens = updateMessageContextTokenUsage( + 'session-1', + state, + emit, + 1_569, + { inputTokens: 1_200, outputTokens: 369 }, + ) + + expect(contextTokens).toBe(21_569) + expect(state.contextTokens).toBe(21_569) + expect(emit).toHaveBeenCalledWith('usage.updated', expect.objectContaining({ + session_id: 'session-1', + inputTokens: 1_200, + outputTokens: 369, + contextTokens: 21_569, + })) + }) + + it('falls back to message tokens when bridge fixed context is missing', () => { + const emit = vi.fn() + const state = { + messages: [], + isWorking: false, + events: [], + queue: [], + } as any + + expect(contextTokensWithCachedOverhead(state, 1_569)).toBe(1_569) + + const contextTokens = updateMessageContextTokenUsage( + 'session-1', + state, + emit, + 1_569, + { inputTokens: 1_200, outputTokens: 369 }, + ) + + expect(contextTokens).toBe(1_569) + 
expect(state.contextTokens).toBe(1_569) + expect(emit).toHaveBeenCalledWith('usage.updated', expect.objectContaining({ + contextTokens: 1_569, + })) + }) })