fix compression context usage accounting (#924)

This commit is contained in:
ekko
2026-05-22 09:46:50 +08:00
committed by GitHub
parent b5f0215beb
commit c3538a6b44
11 changed files with 454 additions and 61 deletions
@@ -8,10 +8,17 @@ import { getSystemPrompt } from '../../../lib/llm-prompt'
import { getSession, createSession, addMessage, updateSession, updateSessionStats } from '../../../db/hermes/session-store'
import { updateUsage } from '../../../db/hermes/usage-store'
import { logger, bridgeLogger } from '../../logger'
import { AgentBridgeClient, type AgentBridgeMessage, type AgentBridgeOutput } from '../agent-bridge'
import { AgentBridgeClient, type AgentBridgeContextEstimate, type AgentBridgeMessage, type AgentBridgeOutput } from '../agent-bridge'
import { contentBlocksToString, convertContentBlocksForAgent, extractTextForPreview, isContentBlockArray } from './content-blocks'
import { buildCompressedHistory, buildDbHistory, forceCompressBridgeHistory, pushState, replaceState } from './compression'
import { calcAndUpdateUsage, estimateUsageTokensFromMessages } from './usage'
import { buildCompressedHistory, buildDbHistory, buildSnapshotAwareHistory, forceCompressBridgeHistory, pushState, replaceState } from './compression'
import {
calcAndUpdateUsage,
contextTokensWithCachedOverhead,
estimateUsageTokensFromMessages,
getCachedBridgeContextOverhead,
updateContextTokenUsage,
updateMessageContextTokenUsage,
} from './usage'
import {
flushBridgePendingToDb,
ensureOpenBridgeAssistantMessage,
@@ -62,6 +69,28 @@ export function bridgeTerminalError(chunk: Pick<AgentBridgeOutput, 'status' | 'e
return null
}
/**
 * Normalizes a loosely-typed token count.
 *
 * @param value - Candidate count of unknown type.
 * @returns The value rounded down to an integer when it is a finite,
 *   non-negative number; `undefined` for anything else (non-numbers, `NaN`,
 *   infinities, negative numbers).
 */
function finiteToken(value: unknown): number | undefined {
  if (typeof value !== 'number') return undefined
  // Reject NaN / ±Infinity and negative counts before flooring.
  if (!Number.isFinite(value) || value < 0) return undefined
  return Math.floor(value)
}
/**
 * Records the bridge's context measurements on the session state.
 *
 * Nothing is written unless `data.fixed_context_tokens` passes `finiteToken`;
 * in that case `state.bridgeContext` is replaced with a fresh object.
 * Numeric fields are each normalized through `finiteToken`, `tool_names` is
 * stringified element by element when it is an array, and `profile`, `model`
 * and `provider` keep the previously cached value when the payload does not
 * carry a string for them.
 *
 * @param state - Session state whose `bridgeContext` is overwritten.
 * @param data - Bridge payload (a context estimate or a raw event record)
 *   using snake_case field names.
 */
function cacheBridgeContext(state: SessionState, data: Record<string, unknown> | AgentBridgeContextEstimate) {
  const fixedContextTokens = finiteToken(data.fixed_context_tokens)
  if (fixedContextTokens === undefined) return

  // Capture the old snapshot first: identity fields fall back to it below.
  const previous = state.bridgeContext
  const toolNames = Array.isArray(data.tool_names) ? data.tool_names.map(String) : undefined
  const profile = typeof data.profile === 'string' ? data.profile : previous?.profile
  const model = typeof data.model === 'string' ? data.model : previous?.model
  const provider = typeof data.provider === 'string' ? data.provider : previous?.provider

  state.bridgeContext = {
    fixedContextTokens,
    systemPromptTokens: finiteToken(data.system_prompt_tokens),
    toolTokens: finiteToken(data.tool_tokens),
    systemPromptChars: finiteToken(data.system_prompt_chars),
    toolCount: finiteToken(data.tool_count),
    toolNames,
    profile,
    model,
    provider,
  }
}
export async function handleBridgeRun(
nsp: ReturnType<Server['of']>,
socket: Socket,
@@ -168,6 +197,11 @@ export async function handleBridgeRun(
sessionMap,
{ model: resolvedModel, provider: resolvedProvider },
async (messages) => {
const cachedOverhead = getCachedBridgeContextOverhead(state)
if (cachedOverhead != null) {
const messageUsage = estimateUsageTokensFromMessages(messages)
return cachedOverhead + messageUsage.inputTokens + messageUsage.outputTokens
}
const estimate = await bridge.contextEstimate(
session_id,
messages,
@@ -175,6 +209,7 @@ export async function handleBridgeRun(
profile,
{ model: resolvedModel, provider: resolvedProvider },
)
cacheBridgeContext(state, estimate)
bridgeLogger.info({
sessionId: session_id,
profile,
@@ -183,6 +218,7 @@ export async function handleBridgeRun(
messages: estimate.message_count,
toolCount: estimate.tool_count,
systemPromptChars: estimate.system_prompt_chars,
fixedContextTokens: estimate.fixed_context_tokens,
fullContextTokens: estimate.token_count,
}, '[chat-run-socket] full context estimate')
return estimate.token_count
@@ -308,7 +344,35 @@ async function refreshFinalContextUsage(args: {
bridge: AgentBridgeClient
}): Promise<number | undefined> {
try {
const finalHistory = await buildDbHistory(args.sessionId, { excludeLastUser: false })
const dbHistory = await buildDbHistory(args.sessionId, { excludeLastUser: false })
const finalHistory = await buildSnapshotAwareHistory(
args.sessionId,
args.profile,
dbHistory,
{ model: args.model, provider: args.provider },
)
const finalMessageUsage = estimateUsageTokensFromMessages(finalHistory)
const finalMessageTokens = finalMessageUsage.inputTokens + finalMessageUsage.outputTokens
if (getCachedBridgeContextOverhead(args.state) != null) {
const contextTokens = updateMessageContextTokenUsage(
args.sessionId,
args.state,
args.emit,
finalMessageTokens,
args.usage,
)
bridgeLogger.info({
sessionId: args.sessionId,
profile: args.profile,
model: args.model,
provider: args.provider,
messages: finalHistory.length,
fixedContextTokens: args.state.bridgeContext?.fixedContextTokens,
messageTokens: finalMessageTokens,
fullContextTokens: contextTokens,
}, '[chat-run-socket] final cached context estimate')
return contextTokens
}
const estimate = await args.bridge.contextEstimate(
args.sessionId,
finalHistory,
@@ -316,18 +380,13 @@ async function refreshFinalContextUsage(args: {
args.profile,
{ model: args.model ?? undefined, provider: args.provider ?? undefined },
)
cacheBridgeContext(args.state, estimate)
const contextTokens = typeof estimate.token_count === 'number' && Number.isFinite(estimate.token_count) && estimate.token_count > 0
? Math.floor(estimate.token_count)
: undefined
if (contextTokens == null) return args.state.contextTokens
args.state.contextTokens = contextTokens
args.emit('usage.updated', {
event: 'usage.updated',
inputTokens: args.usage.inputTokens,
outputTokens: args.usage.outputTokens,
contextTokens,
})
updateContextTokenUsage(args.sessionId, args.state, args.emit, contextTokens, args.usage)
bridgeLogger.info({
sessionId: args.sessionId,
profile: args.profile,
@@ -378,7 +437,17 @@ async function applyBridgeChunkAsync(
for (const ev of chunk.events || []) {
const evType = ev.event as string | undefined
if (evType === 'tool.started') {
if (evType === 'bridge.context.ready') {
cacheBridgeContext(state, ev)
const usage = await calcAndUpdateUsage(sessionId, state, emit)
updateMessageContextTokenUsage(
sessionId,
state,
emit,
usage.inputTokens + usage.outputTokens,
usage,
)
} else if (evType === 'tool.started') {
flushBridgePendingToDb(state, sessionId, runMarker)
const toolName = (ev.tool_name as string) || ''
const args = ev.args as Record<string, unknown> | undefined
@@ -498,6 +567,11 @@ async function applyBridgeChunkAsync(
const compressionResult = ev.request_id
? state.bridgeCompressionResults?.[String(ev.request_id)]
: undefined
const bridgeAfterContextTokens = finiteToken(ev.result_approx_tokens)
const messageAfterTokens = finiteToken(compressionResult?.afterTokens)
const afterContextTokens = messageAfterTokens != null && getCachedBridgeContextOverhead(state) != null
? contextTokensWithCachedOverhead(state, messageAfterTokens)
: bridgeAfterContextTokens ?? messageAfterTokens
const payload = {
event: 'compression.completed',
run_id: chunk.run_id,
@@ -507,9 +581,8 @@ async function applyBridgeChunkAsync(
totalMessages: compressionResult?.beforeMessages ?? ev.message_count,
resultMessages: compressionResult?.resultMessages ?? ev.result_messages,
beforeTokens: compressionResult?.beforeTokens ?? ev.approx_tokens,
afterTokens: typeof ev.result_approx_tokens === 'number' && Number.isFinite(ev.result_approx_tokens) && ev.result_approx_tokens > 0
? ev.result_approx_tokens
: compressionResult?.afterTokens,
afterTokens: messageAfterTokens ?? bridgeAfterContextTokens,
contextTokens: afterContextTokens,
summaryTokens: compressionResult?.summaryTokens,
verbatimCount: compressionResult?.verbatimCount,
compressedStartIndex: compressionResult?.compressedStartIndex,
@@ -520,7 +593,12 @@ async function applyBridgeChunkAsync(
}
replaceState(sessionMap, sessionId, 'compression.completed', payload)
emit('compression.completed', payload)
await calcAndUpdateUsage(sessionId, state, emit)
const usage = await calcAndUpdateUsage(sessionId, state, emit)
if (messageAfterTokens != null && getCachedBridgeContextOverhead(state) != null) {
updateMessageContextTokenUsage(sessionId, state, emit, messageAfterTokens, usage)
} else {
updateContextTokenUsage(sessionId, state, emit, afterContextTokens, usage)
}
} else if (evType === 'bridge.compression.failed') {
const payload = {
event: 'compression.completed',