fix compression context usage accounting (#924)

This commit is contained in:
ekko
2026-05-22 09:46:50 +08:00
committed by GitHub
parent b5f0215beb
commit c3538a6b44
11 changed files with 454 additions and 61 deletions
+15 -3
View File
@@ -569,14 +569,16 @@ export const useChatStore = defineStore('chat', () => {
compressed: null,
})
} else if (e.event === 'compression.completed') {
const afterTokens = e.contextTokens || e.afterTokens || 0
setCompressionState({
compressing: false,
messageCount: e.totalMessages || 0,
beforeTokens: e.beforeTokens || 0,
afterTokens: e.afterTokens || 0,
afterTokens,
compressed: e.compressed ?? false,
error: e.error,
})
if (e.contextTokens != null) activeSession.value!.contextTokens = e.contextTokens
} else if (e.event === 'abort.started') {
setAbortState({ aborting: true, synced: null })
} else if (e.event === 'abort.completed') {
@@ -1073,14 +1075,19 @@ export const useChatStore = defineStore('chat', () => {
}
case 'compression.completed': {
const afterTokens = (evt as any).contextTokens || (evt as any).afterTokens || 0
setCompressionState({
compressing: false,
messageCount: (evt as any).totalMessages || 0,
beforeTokens: (evt as any).beforeTokens || 0,
afterTokens: (evt as any).afterTokens || 0,
afterTokens,
compressed: (evt as any).compressed ?? false,
error: (evt as any).error,
})
if ((evt as any).contextTokens != null) {
const target = sessions.value.find(s => s.id === sid)
if (target) target.contextTokens = (evt as any).contextTokens
}
// Auto-clear after 5s
setTimeout(() => {
if (compressionState.value && !compressionState.value.compressing) {
@@ -1520,14 +1527,19 @@ export const useChatStore = defineStore('chat', () => {
}
case 'compression.completed': {
const afterTokens = (evt as any).contextTokens || (evt as any).afterTokens || 0
setCompressionState({
compressing: false,
messageCount: (evt as any).totalMessages || 0,
beforeTokens: (evt as any).beforeTokens || 0,
afterTokens: (evt as any).afterTokens || 0,
afterTokens,
compressed: (evt as any).compressed ?? false,
error: (evt as any).error,
})
if ((evt as any).contextTokens != null) {
const target = sessions.value.find(s => s.id === sid)
if (target) target.contextTokens = (evt as any).contextTokens
}
setTimeout(() => {
if (compressionState.value && !compressionState.value.compressing) {
setCompressionState(null)
@@ -93,9 +93,16 @@ export interface AgentBridgeRunResult extends AgentBridgeResponse {
/**
 * Response returned by the agent bridge for a context-size estimate.
 *
 * Keys are snake_case because they mirror the dictionary the Python bridge returns.
 *
 * NOTE(review): the bridge builds this from an info dict that is empty when its
 * token estimator cannot be imported or throws, so `message_count`, `tool_count`
 * and `system_prompt_chars` may be absent at runtime even though they are
 * declared required here — confirm and consider making them optional.
 */
export interface AgentBridgeContextEstimate extends AgentBridgeResponse {
session_id: string
/** Estimate for the whole request (messages + system prompt + tools); null/absent when unknown. */
token_count?: number | null
/** Estimate for system prompt + tools with an empty message list. */
fixed_context_tokens?: number | null
/** Estimate for the system prompt alone (no messages, no tools). */
system_prompt_tokens?: number | null
/** `fixed_context_tokens` minus `system_prompt_tokens`, floored at 0 by the bridge. */
tool_tokens?: number | null
/** Number of messages included in the estimate. */
message_count: number
/** Number of tools registered on the agent. */
tool_count: number
/** Names of the agent's tools, when the bridge could resolve them. */
tool_names?: string[]
/** Character length of the system prompt used for the estimate. */
system_prompt_chars: number
profile?: string
model?: string
provider?: string
}
export interface AgentBridgeCommandResult extends AgentBridgeResponse {
@@ -668,30 +668,79 @@ class AgentPool:
agent._compress_context = wrapped_compress_context
def _estimate_context_tokens(self, agent: Any, messages: Any, system_message: Any = None) -> int | None:
try:
from agent.model_metadata import estimate_request_tokens_rough
except Exception:
return None
def _agent_system_prompt(self, agent: Any, system_message: Any = None) -> str:
prompt = str(getattr(agent, "_cached_system_prompt", "") or "")
if not prompt:
if prompt:
return prompt
try:
build_prompt = getattr(agent, "_build_system_prompt", None)
if callable(build_prompt):
prompt = str(build_prompt(system_message) or "")
return str(build_prompt(system_message) or "")
except Exception:
prompt = str(system_message or "")
return str(system_message or "")
return str(system_message or "")
# Collect the non-empty names of the given tools, preserving input order.
# Returns an empty list when `tools` is not a list.
# NOTE(review): indentation is not visible in this view, so the nesting described
# in the comments below is inferred from the statement order — confirm against the file.
def _agent_tool_names(self, tools: Any) -> list[str]:
if not isinstance(tools, list):
return []
names: list[str] = []
for tool in tools:
name = ""
if isinstance(tool, dict):
# Dict-shaped tool: try the nested "function" -> "name" entry first.
function = tool.get("function")
if isinstance(function, dict):
name = str(function.get("name") or "")
# Presumably the fallback for dict tools without a nested function name: top-level "name".
if not name:
name = str(tool.get("name") or "")
else:
# Presumably the non-dict branch: read a `name` attribute, defaulting to "".
name = str(getattr(tool, "name", "") or "")
# Tools that yield no name are left out of the result.
if name:
names.append(name)
return names
def _estimate_context_info(self, agent: Any, messages: Any, system_message: Any = None) -> dict[str, Any]:
try:
estimate = estimate_request_tokens_rough(
messages if isinstance(messages, list) else [],
system_prompt=prompt,
tools=getattr(agent, "tools", None) or None,
)
return int(estimate) if isinstance(estimate, (int, float)) and estimate > 0 else None
from agent.model_metadata import estimate_request_tokens_rough
except Exception:
return None
return {}
prompt = self._agent_system_prompt(agent, system_message)
tools = getattr(agent, "tools", None) or []
message_list = messages if isinstance(messages, list) else []
try:
token_count = estimate_request_tokens_rough(message_list, system_prompt=prompt, tools=tools or None)
fixed_context_tokens = estimate_request_tokens_rough([], system_prompt=prompt, tools=tools or None)
system_prompt_tokens = estimate_request_tokens_rough([], system_prompt=prompt, tools=None)
tool_tokens = max(0, int(fixed_context_tokens or 0) - int(system_prompt_tokens or 0))
return {
"token_count": int(token_count) if isinstance(token_count, (int, float)) and token_count > 0 else None,
"fixed_context_tokens": int(fixed_context_tokens) if isinstance(fixed_context_tokens, (int, float)) and fixed_context_tokens >= 0 else None,
"system_prompt_tokens": int(system_prompt_tokens) if isinstance(system_prompt_tokens, (int, float)) and system_prompt_tokens >= 0 else None,
"tool_tokens": tool_tokens,
"message_count": len(message_list),
"tool_count": len(tools) if isinstance(tools, list) else 0,
"tool_names": self._agent_tool_names(tools),
"system_prompt_chars": len(prompt),
}
except Exception:
return {}
# Return only the overall token estimate from _estimate_context_info.
# Yields None when the "token_count" entry is missing, not an int/float, or not positive.
def _estimate_context_tokens(self, agent: Any, messages: Any, system_message: Any = None) -> int | None:
token_count = self._estimate_context_info(agent, messages, system_message).get("token_count")
return int(token_count) if isinstance(token_count, (int, float)) and token_count > 0 else None
# Build the "bridge.context.ready" event for a session, remember it on the
# session config under "context_info", and return it.
def _bridge_context_ready_event(self, session: AgentSession, instructions: str | None, profile: str | None) -> dict[str, Any]:
# Estimate with an empty message list, so the info covers prompt/tools but no conversation messages.
info = self._estimate_context_info(session.agent, [], instructions)
event = {
"event": "bridge.context.ready",
"session_id": session.session_id,
# Explicit profile wins, then the session's configured profile, then "default".
"profile": profile or session.config.get("profile") or "default",
"model": session.config.get("model"),
"provider": session.config.get("provider"),
# Spread last: any key in `info` that collides with one above would override it.
**info,
}
session.config["context_info"] = event
return event
def estimate_context(
self,
@@ -703,24 +752,23 @@ class AgentPool:
provider: str | None = None,
) -> dict[str, Any]:
session = self.get_or_create(session_id, profile=profile, model=model, provider=provider)
token_count = self._estimate_context_tokens(session.agent, messages or [], instructions)
tools = getattr(session.agent, "tools", None) or []
prompt = str(getattr(session.agent, "_cached_system_prompt", "") or "")
context_info = self._estimate_context_info(session.agent, messages or [], instructions)
print(
"[hermes_bridge] context estimate "
f"session={session_id} profile={profile or 'default'} "
f"messages={len(messages or [])} system_prompt_chars={len(prompt)} "
f"tools={len(tools) if isinstance(tools, list) else 0} "
f"tokens={token_count if token_count is not None else 'unknown'}",
f"messages={len(messages or [])} system_prompt_chars={context_info.get('system_prompt_chars') or 0} "
f"tools={context_info.get('tool_count') or 0} "
f"fixed_tokens={context_info.get('fixed_context_tokens') if context_info.get('fixed_context_tokens') is not None else 'unknown'} "
f"tokens={context_info.get('token_count') if context_info.get('token_count') is not None else 'unknown'}",
file=sys.stderr,
flush=True,
)
return {
"session_id": session_id,
"token_count": token_count,
"message_count": len(messages or []),
"tool_count": len(tools) if isinstance(tools, list) else 0,
"system_prompt_chars": len(prompt),
"profile": profile or session.config.get("profile") or "default",
"model": session.config.get("model"),
"provider": session.config.get("provider"),
**context_info,
}
def respond_compression(
@@ -1062,6 +1110,9 @@ class AgentPool:
session.running = True
session.current_run_id = run_id
session.last_used_at = time.time()
context_event = self._bridge_context_ready_event(session, instructions, profile)
if context_event:
record.events.append(_jsonable(context_event))
thread = threading.Thread(
target=self._run_chat,
@@ -13,7 +13,7 @@ import { getModelContextLength } from '../model-context'
import { readConfigYamlForProfile } from '../../config-helpers'
import { logger } from '../../logger'
import { bridgeLogger } from '../../logger'
import { calcAndUpdateUsage, estimateUsageTokensFromMessages } from './usage'
import { calcAndUpdateUsage, estimateUsageTokensFromMessages, updateMessageContextTokenUsage } from './usage'
import { isAssistantMessageSendable } from './message-format'
import type { ChatMessage, CompressionConfig as CompressorConfig } from '../../../lib/context-compressor'
import type { SessionState, BridgeCompressionResult } from './types'
@@ -69,6 +69,23 @@ function buildSnapshotHistory(
]
}
/**
 * Returns the history that should be used for a session, taking any stored
 * compression snapshot into account.
 *
 * When the session has no snapshot the given `history` is returned untouched.
 * Otherwise the compression config for the model's context length is loaded and
 * `buildSnapshotHistory` is asked to produce the history; if it returns a falsy
 * value the original `history` is used instead.
 *
 * @param sessionId - Session whose compression snapshot is looked up.
 * @param profile - Profile used to resolve context length and compression config.
 * @param history - History to fall back to.
 * @param modelContext - Optional model/provider used when resolving the context length.
 */
export async function buildSnapshotAwareHistory(
  sessionId: string,
  profile: string,
  history: ChatMessage[],
  modelContext: { model?: string | null; provider?: string | null } = {},
): Promise<ChatMessage[]> {
  const storedSnapshot = getCompressionSnapshot(sessionId)
  if (!storedSnapshot) {
    return history
  }

  const maxContext = getModelContextLength({
    profile,
    model: modelContext.model,
    provider: modelContext.provider,
  })
  const runConfig = await getRunChatCompressionConfig(profile, maxContext)
  const rebuilt = buildSnapshotHistory(storedSnapshot, history, runConfig.compressor)
  return rebuilt || history
}
function clampRatio(value: unknown, fallback: number, min: number, max: number): number {
const n = typeof value === 'number' && Number.isFinite(value) ? value : fallback
return Math.min(max, Math.max(min, n))
@@ -343,21 +360,26 @@ export async function compressHistory(
provider: modelContext.provider || session?.provider,
})
const afterTokens = await calcAndUpdateUsage(sessionId, cState, emit)
const compressedMeta = {
const compressedAfterTokens = afterTokens.inputTokens + afterTokens.outputTokens
const compressedMeta: any = {
event: 'compression.completed' as const,
compressed: result.meta.compressed,
llmCompressed: result.meta.llmCompressed,
totalMessages: result.meta.totalMessages,
resultMessages: result.messages.length,
beforeTokens: totalTokens,
afterTokens: afterTokens.inputTokens + afterTokens.outputTokens,
afterTokens: compressedAfterTokens,
summaryTokens: result.meta.summaryTokenEstimate,
verbatimCount: result.meta.verbatimCount,
compressedStartIndex: result.meta.compressedStartIndex,
}
replaceState(sessionMap, sessionId, 'compression.completed', compressedMeta)
logger.info('[context-compress] AFTER session=%s: %d messages, ~%d tokens (was %d)',
sessionId, result.messages.length, afterTokens.inputTokens + afterTokens.outputTokens, totalTokens)
sessionId, result.messages.length, compressedAfterTokens, totalTokens)
const compressedContextTokens = updateMessageContextTokenUsage(sessionId, cState, emit, compressedAfterTokens, afterTokens)
if (compressedContextTokens != null) {
compressedMeta.contextTokens = compressedContextTokens
}
emit('compression.completed', compressedMeta)
const compressed = result.messages.map(m => {
@@ -8,10 +8,17 @@ import { getSystemPrompt } from '../../../lib/llm-prompt'
import { getSession, createSession, addMessage, updateSession, updateSessionStats } from '../../../db/hermes/session-store'
import { updateUsage } from '../../../db/hermes/usage-store'
import { logger, bridgeLogger } from '../../logger'
import { AgentBridgeClient, type AgentBridgeMessage, type AgentBridgeOutput } from '../agent-bridge'
import { AgentBridgeClient, type AgentBridgeContextEstimate, type AgentBridgeMessage, type AgentBridgeOutput } from '../agent-bridge'
import { contentBlocksToString, convertContentBlocksForAgent, extractTextForPreview, isContentBlockArray } from './content-blocks'
import { buildCompressedHistory, buildDbHistory, forceCompressBridgeHistory, pushState, replaceState } from './compression'
import { calcAndUpdateUsage, estimateUsageTokensFromMessages } from './usage'
import { buildCompressedHistory, buildDbHistory, buildSnapshotAwareHistory, forceCompressBridgeHistory, pushState, replaceState } from './compression'
import {
calcAndUpdateUsage,
contextTokensWithCachedOverhead,
estimateUsageTokensFromMessages,
getCachedBridgeContextOverhead,
updateContextTokenUsage,
updateMessageContextTokenUsage,
} from './usage'
import {
flushBridgePendingToDb,
ensureOpenBridgeAssistantMessage,
@@ -62,6 +69,28 @@ export function bridgeTerminalError(chunk: Pick<AgentBridgeOutput, 'status' | 'e
return null
}
/**
 * Normalises an untyped value into a whole, non-negative token count.
 *
 * @param value - Anything; typically a field read from a bridge event or estimate.
 * @returns `Math.floor(value)` when `value` is a finite number >= 0, otherwise `undefined`.
 */
function finiteToken(value: unknown): number | undefined {
  if (typeof value !== 'number') return undefined
  if (!Number.isFinite(value) || value < 0) return undefined
  return Math.floor(value)
}
/**
 * Stores the bridge's fixed-context breakdown on the session state.
 *
 * Nothing is written unless `data.fixed_context_tokens` is a finite,
 * non-negative number. When it is, `state.bridgeContext` is replaced wholesale:
 * numeric fields are normalised through `finiteToken`, `tool_names` entries are
 * stringified, and `profile` / `model` / `provider` keep their previously cached
 * value whenever the incoming one is not a string.
 *
 * @param state - Session state that receives the cached context.
 * @param data - A bridge event or context estimate carrying snake_case fields.
 */
function cacheBridgeContext(state: SessionState, data: Record<string, unknown> | AgentBridgeContextEstimate) {
  const fixedContextTokens = finiteToken(data.fixed_context_tokens)
  if (fixedContextTokens === undefined) {
    return
  }

  // Read the previous cache before overwriting it so string fields can fall back to it.
  const previous = state.bridgeContext
  const systemPromptTokens = finiteToken(data.system_prompt_tokens)
  const toolTokens = finiteToken(data.tool_tokens)
  const systemPromptChars = finiteToken(data.system_prompt_chars)
  const toolCount = finiteToken(data.tool_count)

  let toolNames: string[] | undefined
  if (Array.isArray(data.tool_names)) {
    toolNames = data.tool_names.map(String)
  }

  let profile = previous?.profile
  if (typeof data.profile === 'string') profile = data.profile
  let model = previous?.model
  if (typeof data.model === 'string') model = data.model
  let provider = previous?.provider
  if (typeof data.provider === 'string') provider = data.provider

  state.bridgeContext = {
    fixedContextTokens,
    systemPromptTokens,
    toolTokens,
    systemPromptChars,
    toolCount,
    toolNames,
    profile,
    model,
    provider,
  }
}
export async function handleBridgeRun(
nsp: ReturnType<Server['of']>,
socket: Socket,
@@ -168,6 +197,11 @@ export async function handleBridgeRun(
sessionMap,
{ model: resolvedModel, provider: resolvedProvider },
async (messages) => {
const cachedOverhead = getCachedBridgeContextOverhead(state)
if (cachedOverhead != null) {
const messageUsage = estimateUsageTokensFromMessages(messages)
return cachedOverhead + messageUsage.inputTokens + messageUsage.outputTokens
}
const estimate = await bridge.contextEstimate(
session_id,
messages,
@@ -175,6 +209,7 @@ export async function handleBridgeRun(
profile,
{ model: resolvedModel, provider: resolvedProvider },
)
cacheBridgeContext(state, estimate)
bridgeLogger.info({
sessionId: session_id,
profile,
@@ -183,6 +218,7 @@ export async function handleBridgeRun(
messages: estimate.message_count,
toolCount: estimate.tool_count,
systemPromptChars: estimate.system_prompt_chars,
fixedContextTokens: estimate.fixed_context_tokens,
fullContextTokens: estimate.token_count,
}, '[chat-run-socket] full context estimate')
return estimate.token_count
@@ -308,7 +344,35 @@ async function refreshFinalContextUsage(args: {
bridge: AgentBridgeClient
}): Promise<number | undefined> {
try {
const finalHistory = await buildDbHistory(args.sessionId, { excludeLastUser: false })
const dbHistory = await buildDbHistory(args.sessionId, { excludeLastUser: false })
const finalHistory = await buildSnapshotAwareHistory(
args.sessionId,
args.profile,
dbHistory,
{ model: args.model, provider: args.provider },
)
const finalMessageUsage = estimateUsageTokensFromMessages(finalHistory)
const finalMessageTokens = finalMessageUsage.inputTokens + finalMessageUsage.outputTokens
if (getCachedBridgeContextOverhead(args.state) != null) {
const contextTokens = updateMessageContextTokenUsage(
args.sessionId,
args.state,
args.emit,
finalMessageTokens,
args.usage,
)
bridgeLogger.info({
sessionId: args.sessionId,
profile: args.profile,
model: args.model,
provider: args.provider,
messages: finalHistory.length,
fixedContextTokens: args.state.bridgeContext?.fixedContextTokens,
messageTokens: finalMessageTokens,
fullContextTokens: contextTokens,
}, '[chat-run-socket] final cached context estimate')
return contextTokens
}
const estimate = await args.bridge.contextEstimate(
args.sessionId,
finalHistory,
@@ -316,18 +380,13 @@ async function refreshFinalContextUsage(args: {
args.profile,
{ model: args.model ?? undefined, provider: args.provider ?? undefined },
)
cacheBridgeContext(args.state, estimate)
const contextTokens = typeof estimate.token_count === 'number' && Number.isFinite(estimate.token_count) && estimate.token_count > 0
? Math.floor(estimate.token_count)
: undefined
if (contextTokens == null) return args.state.contextTokens
args.state.contextTokens = contextTokens
args.emit('usage.updated', {
event: 'usage.updated',
inputTokens: args.usage.inputTokens,
outputTokens: args.usage.outputTokens,
contextTokens,
})
updateContextTokenUsage(args.sessionId, args.state, args.emit, contextTokens, args.usage)
bridgeLogger.info({
sessionId: args.sessionId,
profile: args.profile,
@@ -378,7 +437,17 @@ async function applyBridgeChunkAsync(
for (const ev of chunk.events || []) {
const evType = ev.event as string | undefined
if (evType === 'tool.started') {
if (evType === 'bridge.context.ready') {
cacheBridgeContext(state, ev)
const usage = await calcAndUpdateUsage(sessionId, state, emit)
updateMessageContextTokenUsage(
sessionId,
state,
emit,
usage.inputTokens + usage.outputTokens,
usage,
)
} else if (evType === 'tool.started') {
flushBridgePendingToDb(state, sessionId, runMarker)
const toolName = (ev.tool_name as string) || ''
const args = ev.args as Record<string, unknown> | undefined
@@ -498,6 +567,11 @@ async function applyBridgeChunkAsync(
const compressionResult = ev.request_id
? state.bridgeCompressionResults?.[String(ev.request_id)]
: undefined
const bridgeAfterContextTokens = finiteToken(ev.result_approx_tokens)
const messageAfterTokens = finiteToken(compressionResult?.afterTokens)
const afterContextTokens = messageAfterTokens != null && getCachedBridgeContextOverhead(state) != null
? contextTokensWithCachedOverhead(state, messageAfterTokens)
: bridgeAfterContextTokens ?? messageAfterTokens
const payload = {
event: 'compression.completed',
run_id: chunk.run_id,
@@ -507,9 +581,8 @@ async function applyBridgeChunkAsync(
totalMessages: compressionResult?.beforeMessages ?? ev.message_count,
resultMessages: compressionResult?.resultMessages ?? ev.result_messages,
beforeTokens: compressionResult?.beforeTokens ?? ev.approx_tokens,
afterTokens: typeof ev.result_approx_tokens === 'number' && Number.isFinite(ev.result_approx_tokens) && ev.result_approx_tokens > 0
? ev.result_approx_tokens
: compressionResult?.afterTokens,
afterTokens: messageAfterTokens ?? bridgeAfterContextTokens,
contextTokens: afterContextTokens,
summaryTokens: compressionResult?.summaryTokens,
verbatimCount: compressionResult?.verbatimCount,
compressedStartIndex: compressionResult?.compressedStartIndex,
@@ -520,7 +593,12 @@ async function applyBridgeChunkAsync(
}
replaceState(sessionMap, sessionId, 'compression.completed', payload)
emit('compression.completed', payload)
await calcAndUpdateUsage(sessionId, state, emit)
const usage = await calcAndUpdateUsage(sessionId, state, emit)
if (messageAfterTokens != null && getCachedBridgeContextOverhead(state) != null) {
updateMessageContextTokenUsage(sessionId, state, emit, messageAfterTokens, usage)
} else {
updateContextTokenUsage(sessionId, state, emit, afterContextTokens, usage)
}
} else if (evType === 'bridge.compression.failed') {
const payload = {
event: 'compression.completed',
@@ -5,7 +5,7 @@ import type { AgentBridgeClient } from '../agent-bridge'
import { flushBridgePendingToDb } from './bridge-message'
import { buildDbHistory, estimateSnapshotAwareHistoryUsage, forceCompressBridgeHistory, getOrCreateSession, replaceState } from './compression'
import { handleAbort } from './abort'
import { calcAndUpdateUsage } from './usage'
import { calcAndUpdateUsage, contextTokensWithCachedOverhead, updateMessageContextTokenUsage } from './usage'
import type { ContentBlock, QueuedRun, SessionState } from './types'
type CommandName =
@@ -232,10 +232,11 @@ export async function handleSessionCommand(
try {
const history = await buildDbHistory(sessionId, { excludeLastUser: true })
const usageEstimate = estimateSnapshotAwareHistoryUsage(sessionId, history)
const beforeContextTokens = contextTokensWithCachedOverhead(state, usageEstimate.tokenCount)
emit('compression.started', {
event: 'compression.started',
message_count: usageEstimate.messageCount,
token_count: usageEstimate.tokenCount,
token_count: beforeContextTokens,
source: 'command',
})
const result = await forceCompressBridgeHistory(
@@ -244,27 +245,32 @@ export async function handleSessionCommand(
[],
)
state.bridgeCompressionResults = state.bridgeCompressionResults || {}
await calcAndUpdateUsage(sessionId, state, emit)
const usage = await calcAndUpdateUsage(sessionId, state, emit)
const afterContextTokens = contextTokensWithCachedOverhead(state, result.afterTokens)
emit('compression.completed', {
event: 'compression.completed',
compressed: result.compressed,
llmCompressed: result.llmCompressed,
totalMessages: result.beforeMessages,
resultMessages: result.resultMessages,
beforeTokens: result.beforeTokens,
beforeTokens: beforeContextTokens,
afterTokens: result.afterTokens,
summaryTokens: result.summaryTokens,
verbatimCount: result.verbatimCount,
compressedStartIndex: result.compressedStartIndex,
contextTokens: afterContextTokens,
source: 'command',
})
updateMessageContextTokenUsage(sessionId, state, emit, result.afterTokens, usage)
emitCommand({
action: 'compress',
message: `Compression completed: ${result.beforeMessages} -> ${result.resultMessages} messages, ${result.beforeTokens} -> ${result.afterTokens} tokens.`,
message: `Compression completed: ${result.beforeMessages} -> ${result.resultMessages} messages, ${beforeContextTokens} -> ${afterContextTokens} tokens.`,
beforeMessages: result.beforeMessages,
resultMessages: result.resultMessages,
beforeTokens: result.beforeTokens,
afterTokens: result.afterTokens,
beforeTokens: beforeContextTokens,
afterTokens: afterContextTokens,
messageBeforeTokens: result.beforeTokens,
messageAfterTokens: result.afterTokens,
compressed: result.compressed,
})
} catch (err) {
@@ -47,6 +47,7 @@ export interface SessionState {
inputTokens?: number
outputTokens?: number
contextTokens?: number
bridgeContext?: BridgeContextState
isAborting?: boolean
queue: QueuedRun[]
responseRun?: ResponseRunState
@@ -72,6 +73,18 @@ export interface ResponseRunState {
toolCalls: Map<string, any>
}
/**
 * Per-session cache of the bridge's fixed (non-message) context breakdown.
 *
 * Populated from bridge context events/estimates; numeric fields are stored as
 * whole, non-negative numbers.
 */
export interface BridgeContextState {
/** Tokens taken by system prompt + tools; added to message tokens to get the full context size. */
fixedContextTokens?: number
/** Tokens attributed to the system prompt alone. */
systemPromptTokens?: number
/** Tokens attributed to the tool definitions. */
toolTokens?: number
/** Character length of the system prompt. */
systemPromptChars?: number
/** Number of tools reported by the bridge. */
toolCount?: number
/** Names of the tools reported by the bridge. */
toolNames?: string[]
profile?: string
model?: string
provider?: string
}
export type ChatRunSource = 'api_server' | 'cli'
export interface BridgeCompressionResult {
@@ -78,3 +78,60 @@ export async function calcAndUpdateUsage(
return { inputTokens: 0, outputTokens: 0 }
}
}
/**
 * Records a new full-context token count on the session and broadcasts it.
 *
 * @param sid - Session id, sent as `session_id` in the emitted payload.
 * @param state - Session state whose `contextTokens` is updated.
 * @param emit - Event emitter; receives a single `usage.updated` event on success.
 * @param contextTokens - Candidate value. Ignored unless it is a finite number >= 0;
 *   fractional values are floored.
 * @param usage - Input/output token counts to report; each falls back to the value
 *   on `state`, then to 0.
 * @returns The stored (floored) count, or the unchanged `state.contextTokens` when
 *   the candidate was rejected (nothing is emitted in that case).
 */
export function updateContextTokenUsage(
  sid: string,
  state: SessionState,
  emit: (event: string, payload: any) => void,
  contextTokens: number | null | undefined,
  usage?: { inputTokens: number; outputTokens: number },
): number | undefined {
  if (typeof contextTokens === 'number' && Number.isFinite(contextTokens) && contextTokens >= 0) {
    const wholeTokens = Math.floor(contextTokens)
    state.contextTokens = wholeTokens

    const reportedInput = usage?.inputTokens ?? state.inputTokens ?? 0
    const reportedOutput = usage?.outputTokens ?? state.outputTokens ?? 0
    const payload = {
      event: 'usage.updated',
      session_id: sid,
      inputTokens: reportedInput,
      outputTokens: reportedOutput,
      contextTokens: wholeTokens,
    }
    emit('usage.updated', payload)
    return wholeTokens
  }

  // Unusable candidate: leave the state alone and report what is already stored.
  return state.contextTokens
}
/**
 * Reads the cached fixed-context token overhead from the session state.
 *
 * @param state - Session state; `bridgeContext.fixedContextTokens` is consulted.
 * @returns The floored overhead when a finite, non-negative number is cached,
 *   otherwise `undefined`.
 */
export function getCachedBridgeContextOverhead(state: SessionState): number | undefined {
  const cached = state.bridgeContext?.fixedContextTokens
  if (typeof cached === 'number' && Number.isFinite(cached) && cached >= 0) {
    return Math.floor(cached)
  }
  return undefined
}
/**
 * Combines a message-token estimate with the cached fixed bridge overhead.
 *
 * @param state - Session state; its cached fixed-context overhead (if any) is added.
 * @param messageTokens - Tokens attributed to the conversation messages alone.
 *   Fractions are floored and negative values are clamped to 0. A non-finite
 *   value (NaN, ±Infinity, or a runtime `undefined` that slipped past the type)
 *   is treated as 0 message tokens.
 * @returns Message tokens plus the cached overhead, or the message tokens alone
 *   when no usable overhead is cached. The result is always a finite number.
 */
export function contextTokensWithCachedOverhead(state: SessionState, messageTokens: number): number {
  // Math.floor and Math.max both propagate NaN, so without this guard a bad
  // estimate would surface as NaN in event payloads and user-facing messages.
  const normalizedMessageTokens = Number.isFinite(messageTokens)
    ? Math.max(0, Math.floor(messageTokens))
    : 0
  const fixedContextTokens = getCachedBridgeContextOverhead(state)
  return fixedContextTokens == null
    ? normalizedMessageTokens
    : fixedContextTokens + normalizedMessageTokens
}
/**
 * Updates the session's full-context token count from a message-only estimate.
 *
 * The message tokens are combined with the cached bridge overhead via
 * `contextTokensWithCachedOverhead`, and the result is stored and emitted through
 * `updateContextTokenUsage`.
 *
 * @param sid - Session id forwarded to the emitted payload.
 * @param state - Session state holding the cached overhead and the current count.
 * @param emit - Event emitter forwarded to `updateContextTokenUsage`.
 * @param messageTokens - Message-only token estimate; ignored unless it is a
 *   finite number >= 0.
 * @param usage - Optional input/output token counts to report alongside.
 * @returns The value returned by `updateContextTokenUsage`, or the unchanged
 *   `state.contextTokens` when `messageTokens` was rejected.
 */
export function updateMessageContextTokenUsage(
  sid: string,
  state: SessionState,
  emit: (event: string, payload: any) => void,
  messageTokens: number | null | undefined,
  usage?: { inputTokens: number; outputTokens: number },
): number | undefined {
  if (typeof messageTokens === 'number' && Number.isFinite(messageTokens) && messageTokens >= 0) {
    const fullContextTokens = contextTokensWithCachedOverhead(state, messageTokens)
    return updateContextTokenUsage(sid, state, emit, fullContextTokens, usage)
  }

  // Unusable estimate: keep whatever is already recorded.
  return state.contextTokens
}
@@ -9,11 +9,26 @@ const updateSessionStatsMock = vi.fn()
const updateUsageMock = vi.fn()
const buildCompressedHistoryMock = vi.fn()
const buildDbHistoryMock = vi.fn()
const buildSnapshotAwareHistoryMock = vi.fn(async (_sessionId: string, _profile: string, history: any[]) => history)
const pushStateMock = vi.fn()
const replaceStateMock = vi.fn()
const forceCompressBridgeHistoryMock = vi.fn()
const calcAndUpdateUsageMock = vi.fn()
const estimateUsageTokensFromMessagesMock = vi.fn()
// Stand-in for updateContextTokenUsage: stores the given count on the state as-is
// (no validation or flooring) and emits the same `usage.updated` payload shape.
const updateContextTokenUsageMock = vi.fn((sid: string, state: any, emit: any, contextTokens: number, usage?: { inputTokens: number; outputTokens: number }) => {
state.contextTokens = contextTokens
emit('usage.updated', {
event: 'usage.updated',
session_id: sid,
inputTokens: usage?.inputTokens ?? state.inputTokens ?? 0,
outputTokens: usage?.outputTokens ?? state.outputTokens ?? 0,
contextTokens,
})
return contextTokens
})
// Default: no cached bridge overhead.
const getCachedBridgeContextOverheadMock = vi.fn(() => undefined)
// Default: message tokens pass through unchanged (zero overhead).
const contextTokensWithCachedOverheadMock = vi.fn((_state: any, messageTokens: number) => messageTokens)
// Default: forwards the message tokens to the context-usage mock as the full context count.
const updateMessageContextTokenUsageMock = vi.fn((sid: string, state: any, emit: any, messageTokens: number, usage?: { inputTokens: number; outputTokens: number }) => updateContextTokenUsageMock(sid, state, emit, messageTokens, usage))
const flushBridgePendingToDbMock = vi.fn()
const ensureOpenBridgeAssistantMessageMock = vi.fn()
const syncBridgeReasoningToMessageMock = vi.fn()
@@ -45,6 +60,7 @@ vi.mock('../../packages/server/src/services/logger', () => ({
vi.mock('../../packages/server/src/services/hermes/run-chat/compression', () => ({
buildCompressedHistory: buildCompressedHistoryMock,
buildDbHistory: buildDbHistoryMock,
buildSnapshotAwareHistory: buildSnapshotAwareHistoryMock,
pushState: pushStateMock,
replaceState: replaceStateMock,
forceCompressBridgeHistory: forceCompressBridgeHistoryMock,
@@ -53,6 +69,10 @@ vi.mock('../../packages/server/src/services/hermes/run-chat/compression', () =>
vi.mock('../../packages/server/src/services/hermes/run-chat/usage', () => ({
calcAndUpdateUsage: calcAndUpdateUsageMock,
estimateUsageTokensFromMessages: estimateUsageTokensFromMessagesMock,
getCachedBridgeContextOverhead: getCachedBridgeContextOverheadMock,
contextTokensWithCachedOverhead: contextTokensWithCachedOverheadMock,
updateContextTokenUsage: updateContextTokenUsageMock,
updateMessageContextTokenUsage: updateMessageContextTokenUsageMock,
}))
vi.mock('../../packages/server/src/services/hermes/run-chat/bridge-message', () => ({
@@ -103,7 +123,11 @@ describe('bridge run final context usage', () => {
{ role: 'user', content: 'hello' },
{ role: 'assistant', content: 'done' },
])
buildSnapshotAwareHistoryMock.mockImplementation(async (_sessionId: string, _profile: string, history: any[]) => history)
calcAndUpdateUsageMock.mockResolvedValue({ inputTokens: 11, outputTokens: 7 })
estimateUsageTokensFromMessagesMock.mockReturnValue({ inputTokens: 11, outputTokens: 7 })
getCachedBridgeContextOverheadMock.mockReturnValue(undefined)
contextTokensWithCachedOverheadMock.mockImplementation((_state: any, messageTokens: number) => messageTokens)
})
it('refreshes full context tokens when a bridge run completes', async () => {
@@ -161,6 +185,50 @@ describe('bridge run final context usage', () => {
}))
})
// With a cached fixed-context overhead, finishing a run must not call the bridge
// estimator; the final context size is overhead + message tokens.
it('uses cached fixed context instead of bridge estimate when available', async () => {
const emit = vi.fn()
const nsp = makeNamespace(emit)
const socket = makeSocket()
const state = makeState()
state.bridgeContext = { fixedContextTokens: 20_000 }
const sessionMap = new Map([['session-1', state]])
// Simulate the cached overhead in the mocked usage helpers.
getCachedBridgeContextOverheadMock.mockReturnValue(20_000)
updateMessageContextTokenUsageMock.mockImplementation((sid: string, targetState: any, targetEmit: any, messageTokens: number, usage?: { inputTokens: number; outputTokens: number }) => updateContextTokenUsageMock(sid, targetState, targetEmit, 20_000 + messageTokens, usage))
const bridge = {
chat: vi.fn().mockResolvedValue({ run_id: 'run-1', status: 'started' }),
contextEstimate: vi.fn(),
// The run completes immediately with a single terminal chunk.
streamOutput: vi.fn(async function* () {
yield { run_id: 'run-1', done: true, status: 'completed', output: 'done' }
}),
} as any
const { handleBridgeRun } = await import('../../packages/server/src/services/hermes/run-chat/handle-bridge-run')
await handleBridgeRun(
nsp,
socket,
{ input: 'hello', session_id: 'session-1' },
'default',
sessionMap,
bridge,
false,
vi.fn(),
vi.fn(),
)
expect(bridge.contextEstimate).not.toHaveBeenCalled()
// 18 = 11 input + 7 output tokens from the message-usage estimate mocked in the suite setup.
expect(updateMessageContextTokenUsageMock).toHaveBeenCalledWith(
'session-1',
state,
expect.any(Function),
18,
{ inputTokens: 11, outputTokens: 7 },
)
// 20_018 = 20_000 cached overhead + 18 message tokens.
expect(state.contextTokens).toBe(20_018)
expect(emit).toHaveBeenCalledWith('run.completed', expect.objectContaining({
contextTokens: 20_018,
}))
})
it('refreshes full context tokens when a bridge run fails', async () => {
const emit = vi.fn()
const nsp = makeNamespace(emit)
+22 -1
View File
@@ -6,6 +6,17 @@ const getCompressionSnapshotMock = vi.fn()
const getModelContextLengthMock = vi.fn()
const calcAndUpdateUsageMock = vi.fn()
const estimateUsageTokensFromMessagesMock = vi.fn()
// Stand-in for updateMessageContextTokenUsage with no bridge overhead: the message
// tokens are stored and emitted directly as the context token count.
const updateMessageContextTokenUsageMock = vi.fn((sid: string, state: any, emit: any, messageTokens: number, usage?: { inputTokens: number; outputTokens: number }) => {
state.contextTokens = messageTokens
emit('usage.updated', {
event: 'usage.updated',
session_id: sid,
inputTokens: usage?.inputTokens ?? state.inputTokens ?? 0,
outputTokens: usage?.outputTokens ?? state.outputTokens ?? 0,
contextTokens: messageTokens,
})
return messageTokens
})
const compressorCompressMock = vi.fn()
const readConfigYamlForProfileMock = vi.fn()
const compressorConstructorMock = vi.fn()
@@ -55,6 +66,7 @@ vi.mock('../../packages/server/src/services/logger', () => ({
vi.mock('../../packages/server/src/services/hermes/run-chat/usage', () => ({
calcAndUpdateUsage: calcAndUpdateUsageMock,
estimateUsageTokensFromMessages: estimateUsageTokensFromMessagesMock,
updateMessageContextTokenUsage: updateMessageContextTokenUsageMock,
}))
vi.mock('../../packages/server/src/services/hermes/run-chat/message-format', () => ({
@@ -69,6 +81,7 @@ describe('run chat compression trigger', () => {
getModelContextLengthMock.mockReset()
calcAndUpdateUsageMock.mockReset()
estimateUsageTokensFromMessagesMock.mockReset()
updateMessageContextTokenUsageMock.mockClear()
compressorCompressMock.mockReset()
compressorConstructorMock.mockReset()
readConfigYamlForProfileMock.mockReset()
@@ -189,13 +202,14 @@ describe('run chat compression trigger', () => {
},
})
const emit = vi.fn()
const { buildCompressedHistory } = await import('../../packages/server/src/services/hermes/run-chat/compression')
const history = await buildCompressedHistory(
'session-1',
'default',
'http://upstream',
undefined,
vi.fn(),
emit,
new Map(),
{},
vi.fn(async () => 120_000),
@@ -203,6 +217,13 @@ describe('run chat compression trigger', () => {
expect(history).toEqual([{ role: 'user', content: 'compressed by full context estimate' }])
expect(compressorCompressMock).toHaveBeenCalledTimes(1)
expect(updateMessageContextTokenUsageMock).toHaveBeenCalledWith(
'session-1',
expect.any(Object),
emit,
1_000,
{ inputTokens: 1_000, outputTokens: 0 },
)
})
it('emits full context token usage when the full estimate is under threshold', async () => {
+60 -2
View File
@@ -1,6 +1,10 @@
import { describe, expect, it } from 'vitest'
import { describe, expect, it, vi } from 'vitest'
import { countTokens } from '../../packages/server/src/lib/context-compressor'
import { estimateUsageTokensFromMessages } from '../../packages/server/src/services/hermes/run-chat/usage'
import {
contextTokensWithCachedOverhead,
estimateUsageTokensFromMessages,
updateMessageContextTokenUsage,
} from '../../packages/server/src/services/hermes/run-chat/usage'
describe('run-chat usage token estimates', () => {
it('counts message content instead of serialized message payloads', () => {
@@ -30,4 +34,58 @@ describe('run-chat usage token estimates', () => {
expect(usage.inputTokens).toBe(0)
expect(usage.outputTokens).toBe(countTokens('calling tool') + countTokens(String(messages[0].tool_calls || '')))
})
// The cached fixed overhead (20_000) is added to the message tokens (1_569)
// before the count is stored on the state and emitted.
it('adds cached bridge fixed context when updating full context usage', () => {
const emit = vi.fn()
const state = {
messages: [],
isWorking: false,
events: [],
queue: [],
bridgeContext: { fixedContextTokens: 20_000 },
} as any
const contextTokens = updateMessageContextTokenUsage(
'session-1',
state,
emit,
1_569,
{ inputTokens: 1_200, outputTokens: 369 },
)
// 21_569 = 20_000 + 1_569; input/output usage is reported unchanged.
expect(contextTokens).toBe(21_569)
expect(state.contextTokens).toBe(21_569)
expect(emit).toHaveBeenCalledWith('usage.updated', expect.objectContaining({
session_id: 'session-1',
inputTokens: 1_200,
outputTokens: 369,
contextTokens: 21_569,
}))
})
// Without a cached bridgeContext the helpers add no overhead: the context count
// equals the message tokens.
it('falls back to message tokens when bridge fixed context is missing', () => {
const emit = vi.fn()
const state = {
messages: [],
isWorking: false,
events: [],
queue: [],
} as any
expect(contextTokensWithCachedOverhead(state, 1_569)).toBe(1_569)
const contextTokens = updateMessageContextTokenUsage(
'session-1',
state,
emit,
1_569,
{ inputTokens: 1_200, outputTokens: 369 },
)
expect(contextTokens).toBe(1_569)
expect(state.contextTokens).toBe(1_569)
expect(emit).toHaveBeenCalledWith('usage.updated', expect.objectContaining({
contextTokens: 1_569,
}))
})
})