fix compression context usage accounting (#924)
This commit is contained in:
@@ -8,10 +8,17 @@ import { getSystemPrompt } from '../../../lib/llm-prompt'
|
||||
import { getSession, createSession, addMessage, updateSession, updateSessionStats } from '../../../db/hermes/session-store'
|
||||
import { updateUsage } from '../../../db/hermes/usage-store'
|
||||
import { logger, bridgeLogger } from '../../logger'
|
||||
import { AgentBridgeClient, type AgentBridgeMessage, type AgentBridgeOutput } from '../agent-bridge'
|
||||
import { AgentBridgeClient, type AgentBridgeContextEstimate, type AgentBridgeMessage, type AgentBridgeOutput } from '../agent-bridge'
|
||||
import { contentBlocksToString, convertContentBlocksForAgent, extractTextForPreview, isContentBlockArray } from './content-blocks'
|
||||
import { buildCompressedHistory, buildDbHistory, forceCompressBridgeHistory, pushState, replaceState } from './compression'
|
||||
import { calcAndUpdateUsage, estimateUsageTokensFromMessages } from './usage'
|
||||
import { buildCompressedHistory, buildDbHistory, buildSnapshotAwareHistory, forceCompressBridgeHistory, pushState, replaceState } from './compression'
|
||||
import {
|
||||
calcAndUpdateUsage,
|
||||
contextTokensWithCachedOverhead,
|
||||
estimateUsageTokensFromMessages,
|
||||
getCachedBridgeContextOverhead,
|
||||
updateContextTokenUsage,
|
||||
updateMessageContextTokenUsage,
|
||||
} from './usage'
|
||||
import {
|
||||
flushBridgePendingToDb,
|
||||
ensureOpenBridgeAssistantMessage,
|
||||
@@ -62,6 +69,28 @@ export function bridgeTerminalError(chunk: Pick<AgentBridgeOutput, 'status' | 'e
|
||||
return null
|
||||
}
|
||||
|
||||
function finiteToken(value: unknown): number | undefined {
|
||||
return typeof value === 'number' && Number.isFinite(value) && value >= 0
|
||||
? Math.floor(value)
|
||||
: undefined
|
||||
}
|
||||
|
||||
function cacheBridgeContext(state: SessionState, data: Record<string, unknown> | AgentBridgeContextEstimate) {
|
||||
const fixedContextTokens = finiteToken(data.fixed_context_tokens)
|
||||
if (fixedContextTokens == null) return
|
||||
state.bridgeContext = {
|
||||
fixedContextTokens,
|
||||
systemPromptTokens: finiteToken(data.system_prompt_tokens),
|
||||
toolTokens: finiteToken(data.tool_tokens),
|
||||
systemPromptChars: finiteToken(data.system_prompt_chars),
|
||||
toolCount: finiteToken(data.tool_count),
|
||||
toolNames: Array.isArray(data.tool_names) ? data.tool_names.map(String) : undefined,
|
||||
profile: typeof data.profile === 'string' ? data.profile : state.bridgeContext?.profile,
|
||||
model: typeof data.model === 'string' ? data.model : state.bridgeContext?.model,
|
||||
provider: typeof data.provider === 'string' ? data.provider : state.bridgeContext?.provider,
|
||||
}
|
||||
}
|
||||
|
||||
export async function handleBridgeRun(
|
||||
nsp: ReturnType<Server['of']>,
|
||||
socket: Socket,
|
||||
@@ -168,6 +197,11 @@ export async function handleBridgeRun(
|
||||
sessionMap,
|
||||
{ model: resolvedModel, provider: resolvedProvider },
|
||||
async (messages) => {
|
||||
const cachedOverhead = getCachedBridgeContextOverhead(state)
|
||||
if (cachedOverhead != null) {
|
||||
const messageUsage = estimateUsageTokensFromMessages(messages)
|
||||
return cachedOverhead + messageUsage.inputTokens + messageUsage.outputTokens
|
||||
}
|
||||
const estimate = await bridge.contextEstimate(
|
||||
session_id,
|
||||
messages,
|
||||
@@ -175,6 +209,7 @@ export async function handleBridgeRun(
|
||||
profile,
|
||||
{ model: resolvedModel, provider: resolvedProvider },
|
||||
)
|
||||
cacheBridgeContext(state, estimate)
|
||||
bridgeLogger.info({
|
||||
sessionId: session_id,
|
||||
profile,
|
||||
@@ -183,6 +218,7 @@ export async function handleBridgeRun(
|
||||
messages: estimate.message_count,
|
||||
toolCount: estimate.tool_count,
|
||||
systemPromptChars: estimate.system_prompt_chars,
|
||||
fixedContextTokens: estimate.fixed_context_tokens,
|
||||
fullContextTokens: estimate.token_count,
|
||||
}, '[chat-run-socket] full context estimate')
|
||||
return estimate.token_count
|
||||
@@ -308,7 +344,35 @@ async function refreshFinalContextUsage(args: {
|
||||
bridge: AgentBridgeClient
|
||||
}): Promise<number | undefined> {
|
||||
try {
|
||||
const finalHistory = await buildDbHistory(args.sessionId, { excludeLastUser: false })
|
||||
const dbHistory = await buildDbHistory(args.sessionId, { excludeLastUser: false })
|
||||
const finalHistory = await buildSnapshotAwareHistory(
|
||||
args.sessionId,
|
||||
args.profile,
|
||||
dbHistory,
|
||||
{ model: args.model, provider: args.provider },
|
||||
)
|
||||
const finalMessageUsage = estimateUsageTokensFromMessages(finalHistory)
|
||||
const finalMessageTokens = finalMessageUsage.inputTokens + finalMessageUsage.outputTokens
|
||||
if (getCachedBridgeContextOverhead(args.state) != null) {
|
||||
const contextTokens = updateMessageContextTokenUsage(
|
||||
args.sessionId,
|
||||
args.state,
|
||||
args.emit,
|
||||
finalMessageTokens,
|
||||
args.usage,
|
||||
)
|
||||
bridgeLogger.info({
|
||||
sessionId: args.sessionId,
|
||||
profile: args.profile,
|
||||
model: args.model,
|
||||
provider: args.provider,
|
||||
messages: finalHistory.length,
|
||||
fixedContextTokens: args.state.bridgeContext?.fixedContextTokens,
|
||||
messageTokens: finalMessageTokens,
|
||||
fullContextTokens: contextTokens,
|
||||
}, '[chat-run-socket] final cached context estimate')
|
||||
return contextTokens
|
||||
}
|
||||
const estimate = await args.bridge.contextEstimate(
|
||||
args.sessionId,
|
||||
finalHistory,
|
||||
@@ -316,18 +380,13 @@ async function refreshFinalContextUsage(args: {
|
||||
args.profile,
|
||||
{ model: args.model ?? undefined, provider: args.provider ?? undefined },
|
||||
)
|
||||
cacheBridgeContext(args.state, estimate)
|
||||
const contextTokens = typeof estimate.token_count === 'number' && Number.isFinite(estimate.token_count) && estimate.token_count > 0
|
||||
? Math.floor(estimate.token_count)
|
||||
: undefined
|
||||
if (contextTokens == null) return args.state.contextTokens
|
||||
|
||||
args.state.contextTokens = contextTokens
|
||||
args.emit('usage.updated', {
|
||||
event: 'usage.updated',
|
||||
inputTokens: args.usage.inputTokens,
|
||||
outputTokens: args.usage.outputTokens,
|
||||
contextTokens,
|
||||
})
|
||||
updateContextTokenUsage(args.sessionId, args.state, args.emit, contextTokens, args.usage)
|
||||
bridgeLogger.info({
|
||||
sessionId: args.sessionId,
|
||||
profile: args.profile,
|
||||
@@ -378,7 +437,17 @@ async function applyBridgeChunkAsync(
|
||||
|
||||
for (const ev of chunk.events || []) {
|
||||
const evType = ev.event as string | undefined
|
||||
if (evType === 'tool.started') {
|
||||
if (evType === 'bridge.context.ready') {
|
||||
cacheBridgeContext(state, ev)
|
||||
const usage = await calcAndUpdateUsage(sessionId, state, emit)
|
||||
updateMessageContextTokenUsage(
|
||||
sessionId,
|
||||
state,
|
||||
emit,
|
||||
usage.inputTokens + usage.outputTokens,
|
||||
usage,
|
||||
)
|
||||
} else if (evType === 'tool.started') {
|
||||
flushBridgePendingToDb(state, sessionId, runMarker)
|
||||
const toolName = (ev.tool_name as string) || ''
|
||||
const args = ev.args as Record<string, unknown> | undefined
|
||||
@@ -498,6 +567,11 @@ async function applyBridgeChunkAsync(
|
||||
const compressionResult = ev.request_id
|
||||
? state.bridgeCompressionResults?.[String(ev.request_id)]
|
||||
: undefined
|
||||
const bridgeAfterContextTokens = finiteToken(ev.result_approx_tokens)
|
||||
const messageAfterTokens = finiteToken(compressionResult?.afterTokens)
|
||||
const afterContextTokens = messageAfterTokens != null && getCachedBridgeContextOverhead(state) != null
|
||||
? contextTokensWithCachedOverhead(state, messageAfterTokens)
|
||||
: bridgeAfterContextTokens ?? messageAfterTokens
|
||||
const payload = {
|
||||
event: 'compression.completed',
|
||||
run_id: chunk.run_id,
|
||||
@@ -507,9 +581,8 @@ async function applyBridgeChunkAsync(
|
||||
totalMessages: compressionResult?.beforeMessages ?? ev.message_count,
|
||||
resultMessages: compressionResult?.resultMessages ?? ev.result_messages,
|
||||
beforeTokens: compressionResult?.beforeTokens ?? ev.approx_tokens,
|
||||
afterTokens: typeof ev.result_approx_tokens === 'number' && Number.isFinite(ev.result_approx_tokens) && ev.result_approx_tokens > 0
|
||||
? ev.result_approx_tokens
|
||||
: compressionResult?.afterTokens,
|
||||
afterTokens: messageAfterTokens ?? bridgeAfterContextTokens,
|
||||
contextTokens: afterContextTokens,
|
||||
summaryTokens: compressionResult?.summaryTokens,
|
||||
verbatimCount: compressionResult?.verbatimCount,
|
||||
compressedStartIndex: compressionResult?.compressedStartIndex,
|
||||
@@ -520,7 +593,12 @@ async function applyBridgeChunkAsync(
|
||||
}
|
||||
replaceState(sessionMap, sessionId, 'compression.completed', payload)
|
||||
emit('compression.completed', payload)
|
||||
await calcAndUpdateUsage(sessionId, state, emit)
|
||||
const usage = await calcAndUpdateUsage(sessionId, state, emit)
|
||||
if (messageAfterTokens != null && getCachedBridgeContextOverhead(state) != null) {
|
||||
updateMessageContextTokenUsage(sessionId, state, emit, messageAfterTokens, usage)
|
||||
} else {
|
||||
updateContextTokenUsage(sessionId, state, emit, afterContextTokens, usage)
|
||||
}
|
||||
} else if (evType === 'bridge.compression.failed') {
|
||||
const payload = {
|
||||
event: 'compression.completed',
|
||||
|
||||
Reference in New Issue
Block a user