align compression token estimates (#749)

This commit is contained in:
ekko
2026-05-15 13:50:27 +08:00
committed by GitHub
parent 6c80254dd3
commit 3d49f778fb
7 changed files with 89 additions and 28 deletions
@@ -7,11 +7,11 @@ import {
getSessionDetail,
} from '../../../db/hermes/session-store'
import { getCompressionSnapshot } from '../../../db/hermes/compression-snapshot'
import { ChatContextCompressor, countTokens, SUMMARY_PREFIX } from '../../../lib/context-compressor'
import { ChatContextCompressor, SUMMARY_PREFIX } from '../../../lib/context-compressor'
import { getModelContextLength } from '../model-context'
import { logger } from '../../logger'
import { bridgeLogger } from '../../logger'
import { calcAndUpdateUsage } from './usage'
import { calcAndUpdateUsage, estimateUsageTokensFromMessages } from './usage'
import type { ChatMessage } from '../../../lib/context-compressor'
import type { SessionState, BridgeCompressionResult } from './types'
@@ -210,7 +210,8 @@ export async function forceCompressBridgeHistory(
const upstream = getUpstream(profile).replace(/\/$/, '')
const apiKey = getApiKey(profile) || undefined
const totalTokens = countTokens(JSON.stringify(history))
const beforeUsage = estimateUsageTokensFromMessages(history)
const totalTokens = beforeUsage.inputTokens + beforeUsage.outputTokens
bridgeLogger.info({
sessionId,
profile,
@@ -234,7 +235,8 @@ export async function forceCompressBridgeHistory(
if (m.name) msg.name = m.name
return msg
})
const afterTokens = countTokens(JSON.stringify(compressedMessages))
const afterUsage = estimateUsageTokensFromMessages(compressedMessages)
const afterTokens = afterUsage.inputTokens + afterUsage.outputTokens
bridgeLogger.info({
sessionId,
profile,
@@ -19,7 +19,7 @@ import { readSseFrames } from './sse-utils'
import { extractResponseText } from './response-utils'
import { applyResponseStreamEvent, flushResponseRunToDb } from './response-stream'
import { buildCompressedHistory, getOrCreateSession } from './compression'
import { calcAndUpdateUsage } from './usage'
import { calcAndUpdateUsage, estimateUsageTokensFromMessages } from './usage'
import { handleMessage } from './message-format'
import { countTokens, SUMMARY_PREFIX } from '../../../lib/context-compressor'
import { getCompressionSnapshot } from '../../../db/hermes/compression-snapshot'
@@ -47,16 +47,14 @@ export async function loadSessionStateFromDb(sid: string, _sessionMap: Map<strin
const snapshot = getCompressionSnapshot(sid)
if (snapshot) {
const newMessages = messages.slice(snapshot.lastMessageIndex + 1)
const newUsage = estimateUsageTokensFromMessages(newMessages)
inputTokens = countTokens(SUMMARY_PREFIX + snapshot.summary) +
newMessages.filter(m => m.role === 'user').reduce((sum, m) => sum + countTokens(m.content || ''), 0)
outputTokens = newMessages
.filter(m => m.role === 'assistant' || m.role === 'tool')
.reduce((sum, m) => sum + countTokens(m.content || '') + countTokens(m.tool_calls + '' || ''), 0)
newUsage.inputTokens
outputTokens = newUsage.outputTokens
} else {
inputTokens = messages.filter(m => m.role === 'user').reduce((sum, m) => sum + countTokens(m.content || ''), 0)
outputTokens = messages
.filter(m => m.role === 'assistant' || m.role === 'tool')
.reduce((sum, m) => sum + countTokens(m.content || '') + countTokens(m.tool_calls + '' || ''), 0)
const usage = estimateUsageTokensFromMessages(messages)
inputTokens = usage.inputTokens
outputTokens = usage.outputTokens
}
logger.info('[chat-run-socket] loaded session %s from DB (%d messages)', sid, messages.length)
@@ -6,13 +6,12 @@
import type { Server, Socket } from 'socket.io'
import { getSession, createSession, addMessage, updateSessionStats } from '../../../db/hermes/session-store'
import { updateUsage } from '../../../db/hermes/usage-store'
import { countTokens } from '../../../lib/context-compressor'
import { logger, bridgeLogger } from '../../logger'
import { AgentBridgeClient, type AgentBridgeMessage, type AgentBridgeOutput } from '../agent-bridge'
import { contentBlocksToString, extractTextForPreview } from './content-blocks'
import { buildCompressedHistory } from './compression'
import { pushState, replaceState } from './compression'
import { calcAndUpdateUsage } from './usage'
import { calcAndUpdateUsage, estimateUsageTokensFromMessages } from './usage'
import {
flushBridgePendingToDb,
ensureOpenBridgeAssistantMessage,
@@ -267,8 +266,9 @@ async function applyBridgeChunkAsync(
emit('approval.resolved', payload)
} else if (evType === 'bridge.compression.requested') {
const bridgeHistory = await buildDbHistory(sessionId, { excludeLastUser: true })
const bridgeUsage = estimateUsageTokensFromMessages(bridgeHistory)
const tokenCount = bridgeHistory.length > 0
? countTokens(JSON.stringify(bridgeHistory))
? bridgeUsage.inputTokens + bridgeUsage.outputTokens
: ev.approx_tokens
const payload = {
event: 'compression.started',
@@ -5,8 +5,7 @@ import type { AgentBridgeClient } from '../agent-bridge'
import { flushBridgePendingToDb } from './bridge-message'
import { buildDbHistory, forceCompressBridgeHistory, getOrCreateSession, replaceState } from './compression'
import { handleAbort } from './abort'
import { calcAndUpdateUsage } from './usage'
import { countTokens } from '../../../lib/context-compressor'
import { calcAndUpdateUsage, estimateUsageTokensFromMessages } from './usage'
import type { ContentBlock, QueuedRun, SessionState } from './types'
type CommandName =
@@ -233,7 +232,8 @@ export async function handleSessionCommand(
const emit = (event: string, payload: any) => emitToSession(ctx.nsp, ctx.socket, sessionId, event, payload)
try {
const history = await buildDbHistory(sessionId, { excludeLastUser: true })
const tokenEstimate = history.length > 0 ? countTokens(JSON.stringify(history)) : 0
const usageEstimate = estimateUsageTokensFromMessages(history)
const tokenEstimate = usageEstimate.inputTokens + usageEstimate.outputTokens
emit('compression.started', {
event: 'compression.started',
message_count: history.length,
@@ -11,6 +11,35 @@ import { countTokens, SUMMARY_PREFIX } from '../../../lib/context-compressor'
import { logger } from '../../logger'
import type { SessionState } from './types'
/**
 * Minimal message shape read by the usage token estimator in this file.
 * Every field is optional, and `content` / `tool_calls` are `unknown`, so
 * callers do not need to match one exact message type.
 */
type UsageTokenMessage = {
/** Message role; the estimator only counts `user`, `assistant` and `tool` messages. */
role?: string
/** Message body; strings are used verbatim, arrays are flattened block by block. */
content?: unknown
/** Tool-call payload; only counted for `assistant` and `tool` messages. */
tool_calls?: unknown
}
/**
 * Flattens a message `content` value into plain text for token estimation.
 *
 * - Strings are returned unchanged.
 * - Falsy values (`null`, `undefined`, `0`, `false`, `''`) yield `''`.
 * - Arrays are flattened element by element and joined with `'\n'`:
 *   an element with a string `text` contributes that text, an element with a
 *   string `type` contributes a `[type]` placeholder, anything else is serialized.
 * - Any other value is serialized.
 *
 * Objects are serialized with `JSON.stringify` instead of `String()`, because
 * `String()` collapses every plain object to the constant `"[object Object]"`
 * and the estimate would then ignore the real payload size.
 *
 * @param content - Raw message content of unknown shape.
 * @returns Text representation suitable for passing to a token counter.
 */
function contentToUsageText(content: unknown): string {
  // Serialize a value that has no usable text/type shape.
  const serialize = (value: unknown): string => {
    if (typeof value === 'object' && value !== null) {
      try {
        // JSON.stringify can return undefined at runtime (e.g. toJSON() returning undefined).
        return JSON.stringify(value) ?? ''
      } catch {
        // Circular references or BigInt members cannot be serialized;
        // fall back to plain string coercion below.
      }
    }
    return String(value)
  }

  if (typeof content === 'string') return content
  if (!content) return ''
  if (!Array.isArray(content)) return serialize(content)

  return content
    .map((block: unknown): string => {
      const shaped = block as { text?: unknown; type?: unknown } | null | undefined
      if (typeof shaped?.text === 'string') return shaped.text
      if (typeof shaped?.type === 'string') return `[${shaped.type}]`
      // Falsy entries (null, undefined, 0, '') contribute an empty segment.
      if (!block) return ''
      return serialize(block)
    })
    .join('\n')
}
/**
 * Estimates input and output token counts for a list of messages.
 *
 * - `user` messages add the tokens of their content to `inputTokens`.
 * - `assistant` and `tool` messages add the tokens of their content plus the
 *   tokens of their `tool_calls` to `outputTokens`.
 * - Messages with any other role are not counted.
 *
 * `tool_calls` that are already strings are counted verbatim. Structured values
 * are JSON-serialized first, because `String()` would reduce an array of
 * tool-call objects to `"[object Object]"` and under-count them.
 *
 * @param messages - Messages to estimate; only `role`, `content` and `tool_calls` are read.
 * @returns The summed `inputTokens` and `outputTokens` (both `0` for an empty list).
 */
export function estimateUsageTokensFromMessages(messages: UsageTokenMessage[]): { inputTokens: number; outputTokens: number } {
  // Convert a tool_calls payload of unknown shape into countable text.
  const toolCallsToText = (toolCalls: unknown): string => {
    if (typeof toolCalls === 'string') return toolCalls
    if (!toolCalls) return ''
    // An empty list carries no tool calls, so it adds nothing to the estimate.
    if (Array.isArray(toolCalls) && toolCalls.length === 0) return ''
    if (typeof toolCalls === 'object') {
      try {
        return JSON.stringify(toolCalls) ?? ''
      } catch {
        // Unserializable payload (circular reference, BigInt); fall back to coercion.
      }
    }
    return String(toolCalls)
  }

  let inputTokens = 0
  let outputTokens = 0
  for (const message of messages) {
    if (message.role === 'user') {
      inputTokens += countTokens(contentToUsageText(message.content))
    } else if (message.role === 'assistant' || message.role === 'tool') {
      outputTokens +=
        countTokens(contentToUsageText(message.content)) +
        countTokens(toolCallsToText(message.tool_calls))
    }
  }
  return { inputTokens, outputTokens }
}
export async function calcAndUpdateUsage(
sid: string,
state: SessionState,
@@ -26,16 +55,14 @@ export async function calcAndUpdateUsage(
let outputTokens: number
if (snapshot && msgs.length) {
const newMessages = msgs.slice(snapshot.lastMessageIndex + 1)
const newUsage = estimateUsageTokensFromMessages(newMessages)
inputTokens = countTokens(SUMMARY_PREFIX + snapshot.summary) +
newMessages.filter(m => m.role === 'user').reduce((sum, m) => sum + countTokens(m.content || ''), 0)
outputTokens = newMessages
.filter(m => m.role === 'assistant' || m.role === 'tool')
.reduce((sum, m) => sum + countTokens(m.content || '') + countTokens(m.tool_calls + '' || ''), 0)
newUsage.inputTokens
outputTokens = newUsage.outputTokens
} else {
inputTokens = msgs.filter(m => m.role === 'user').reduce((sum, m) => sum + countTokens(m.content || ''), 0)
outputTokens = msgs
.filter(m => m.role === 'assistant' || m.role === 'tool')
.reduce((sum, m) => sum + countTokens(m.content || '') + countTokens(m.tool_calls + '' || ''), 0)
const usage = estimateUsageTokensFromMessages(msgs)
inputTokens = usage.inputTokens
outputTokens = usage.outputTokens
}
state.inputTokens = inputTokens
state.outputTokens = outputTokens