align compression token estimates (#749)
This commit is contained in:
@@ -845,6 +845,7 @@ export const useChatStore = defineStore('chat', () => {
|
|||||||
// Capture session ID at send time — all callbacks use this, not activeSessionId
|
// Capture session ID at send time — all callbacks use this, not activeSessionId
|
||||||
const sid = activeSessionId.value!
|
const sid = activeSessionId.value!
|
||||||
const isBridgeSlashCommand = activeSession.value?.source === 'cli' && content.trim().startsWith('/')
|
const isBridgeSlashCommand = activeSession.value?.source === 'cli' && content.trim().startsWith('/')
|
||||||
|
const isBridgeCompressCommand = isBridgeSlashCommand && /^\/compress(?:\s|$)/i.test(content.trim())
|
||||||
const wasLiveBeforeSend = isSessionLive(sid)
|
const wasLiveBeforeSend = isSessionLive(sid)
|
||||||
const shouldQueue = wasLiveBeforeSend && !isBridgeSlashCommand
|
const shouldQueue = wasLiveBeforeSend && !isBridgeSlashCommand
|
||||||
|
|
||||||
@@ -1348,7 +1349,7 @@ export const useChatStore = defineStore('chat', () => {
|
|||||||
undefined,
|
undefined,
|
||||||
)
|
)
|
||||||
|
|
||||||
if (!isBridgeSlashCommand || !wasLiveBeforeSend) {
|
if (!isBridgeSlashCommand || isBridgeCompressCommand) {
|
||||||
streamStates.value.set(sid, ctrl)
|
streamStates.value.set(sid, ctrl)
|
||||||
}
|
}
|
||||||
} catch (err: any) {
|
} catch (err: any) {
|
||||||
|
|||||||
@@ -7,11 +7,11 @@ import {
|
|||||||
getSessionDetail,
|
getSessionDetail,
|
||||||
} from '../../../db/hermes/session-store'
|
} from '../../../db/hermes/session-store'
|
||||||
import { getCompressionSnapshot } from '../../../db/hermes/compression-snapshot'
|
import { getCompressionSnapshot } from '../../../db/hermes/compression-snapshot'
|
||||||
import { ChatContextCompressor, countTokens, SUMMARY_PREFIX } from '../../../lib/context-compressor'
|
import { ChatContextCompressor, SUMMARY_PREFIX } from '../../../lib/context-compressor'
|
||||||
import { getModelContextLength } from '../model-context'
|
import { getModelContextLength } from '../model-context'
|
||||||
import { logger } from '../../logger'
|
import { logger } from '../../logger'
|
||||||
import { bridgeLogger } from '../../logger'
|
import { bridgeLogger } from '../../logger'
|
||||||
import { calcAndUpdateUsage } from './usage'
|
import { calcAndUpdateUsage, estimateUsageTokensFromMessages } from './usage'
|
||||||
import type { ChatMessage } from '../../../lib/context-compressor'
|
import type { ChatMessage } from '../../../lib/context-compressor'
|
||||||
import type { SessionState, BridgeCompressionResult } from './types'
|
import type { SessionState, BridgeCompressionResult } from './types'
|
||||||
|
|
||||||
@@ -210,7 +210,8 @@ export async function forceCompressBridgeHistory(
|
|||||||
|
|
||||||
const upstream = getUpstream(profile).replace(/\/$/, '')
|
const upstream = getUpstream(profile).replace(/\/$/, '')
|
||||||
const apiKey = getApiKey(profile) || undefined
|
const apiKey = getApiKey(profile) || undefined
|
||||||
const totalTokens = countTokens(JSON.stringify(history))
|
const beforeUsage = estimateUsageTokensFromMessages(history)
|
||||||
|
const totalTokens = beforeUsage.inputTokens + beforeUsage.outputTokens
|
||||||
bridgeLogger.info({
|
bridgeLogger.info({
|
||||||
sessionId,
|
sessionId,
|
||||||
profile,
|
profile,
|
||||||
@@ -234,7 +235,8 @@ export async function forceCompressBridgeHistory(
|
|||||||
if (m.name) msg.name = m.name
|
if (m.name) msg.name = m.name
|
||||||
return msg
|
return msg
|
||||||
})
|
})
|
||||||
const afterTokens = countTokens(JSON.stringify(compressedMessages))
|
const afterUsage = estimateUsageTokensFromMessages(compressedMessages)
|
||||||
|
const afterTokens = afterUsage.inputTokens + afterUsage.outputTokens
|
||||||
bridgeLogger.info({
|
bridgeLogger.info({
|
||||||
sessionId,
|
sessionId,
|
||||||
profile,
|
profile,
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ import { readSseFrames } from './sse-utils'
|
|||||||
import { extractResponseText } from './response-utils'
|
import { extractResponseText } from './response-utils'
|
||||||
import { applyResponseStreamEvent, flushResponseRunToDb } from './response-stream'
|
import { applyResponseStreamEvent, flushResponseRunToDb } from './response-stream'
|
||||||
import { buildCompressedHistory, getOrCreateSession } from './compression'
|
import { buildCompressedHistory, getOrCreateSession } from './compression'
|
||||||
import { calcAndUpdateUsage } from './usage'
|
import { calcAndUpdateUsage, estimateUsageTokensFromMessages } from './usage'
|
||||||
import { handleMessage } from './message-format'
|
import { handleMessage } from './message-format'
|
||||||
import { countTokens, SUMMARY_PREFIX } from '../../../lib/context-compressor'
|
import { countTokens, SUMMARY_PREFIX } from '../../../lib/context-compressor'
|
||||||
import { getCompressionSnapshot } from '../../../db/hermes/compression-snapshot'
|
import { getCompressionSnapshot } from '../../../db/hermes/compression-snapshot'
|
||||||
@@ -47,16 +47,14 @@ export async function loadSessionStateFromDb(sid: string, _sessionMap: Map<strin
|
|||||||
const snapshot = getCompressionSnapshot(sid)
|
const snapshot = getCompressionSnapshot(sid)
|
||||||
if (snapshot) {
|
if (snapshot) {
|
||||||
const newMessages = messages.slice(snapshot.lastMessageIndex + 1)
|
const newMessages = messages.slice(snapshot.lastMessageIndex + 1)
|
||||||
|
const newUsage = estimateUsageTokensFromMessages(newMessages)
|
||||||
inputTokens = countTokens(SUMMARY_PREFIX + snapshot.summary) +
|
inputTokens = countTokens(SUMMARY_PREFIX + snapshot.summary) +
|
||||||
newMessages.filter(m => m.role === 'user').reduce((sum, m) => sum + countTokens(m.content || ''), 0)
|
newUsage.inputTokens
|
||||||
outputTokens = newMessages
|
outputTokens = newUsage.outputTokens
|
||||||
.filter(m => m.role === 'assistant' || m.role === 'tool')
|
|
||||||
.reduce((sum, m) => sum + countTokens(m.content || '') + countTokens(m.tool_calls + '' || ''), 0)
|
|
||||||
} else {
|
} else {
|
||||||
inputTokens = messages.filter(m => m.role === 'user').reduce((sum, m) => sum + countTokens(m.content || ''), 0)
|
const usage = estimateUsageTokensFromMessages(messages)
|
||||||
outputTokens = messages
|
inputTokens = usage.inputTokens
|
||||||
.filter(m => m.role === 'assistant' || m.role === 'tool')
|
outputTokens = usage.outputTokens
|
||||||
.reduce((sum, m) => sum + countTokens(m.content || '') + countTokens(m.tool_calls + '' || ''), 0)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.info('[chat-run-socket] loaded session %s from DB (%d messages)', sid, messages.length)
|
logger.info('[chat-run-socket] loaded session %s from DB (%d messages)', sid, messages.length)
|
||||||
|
|||||||
@@ -6,13 +6,12 @@
|
|||||||
import type { Server, Socket } from 'socket.io'
|
import type { Server, Socket } from 'socket.io'
|
||||||
import { getSession, createSession, addMessage, updateSessionStats } from '../../../db/hermes/session-store'
|
import { getSession, createSession, addMessage, updateSessionStats } from '../../../db/hermes/session-store'
|
||||||
import { updateUsage } from '../../../db/hermes/usage-store'
|
import { updateUsage } from '../../../db/hermes/usage-store'
|
||||||
import { countTokens } from '../../../lib/context-compressor'
|
|
||||||
import { logger, bridgeLogger } from '../../logger'
|
import { logger, bridgeLogger } from '../../logger'
|
||||||
import { AgentBridgeClient, type AgentBridgeMessage, type AgentBridgeOutput } from '../agent-bridge'
|
import { AgentBridgeClient, type AgentBridgeMessage, type AgentBridgeOutput } from '../agent-bridge'
|
||||||
import { contentBlocksToString, extractTextForPreview } from './content-blocks'
|
import { contentBlocksToString, extractTextForPreview } from './content-blocks'
|
||||||
import { buildCompressedHistory } from './compression'
|
import { buildCompressedHistory } from './compression'
|
||||||
import { pushState, replaceState } from './compression'
|
import { pushState, replaceState } from './compression'
|
||||||
import { calcAndUpdateUsage } from './usage'
|
import { calcAndUpdateUsage, estimateUsageTokensFromMessages } from './usage'
|
||||||
import {
|
import {
|
||||||
flushBridgePendingToDb,
|
flushBridgePendingToDb,
|
||||||
ensureOpenBridgeAssistantMessage,
|
ensureOpenBridgeAssistantMessage,
|
||||||
@@ -267,8 +266,9 @@ async function applyBridgeChunkAsync(
|
|||||||
emit('approval.resolved', payload)
|
emit('approval.resolved', payload)
|
||||||
} else if (evType === 'bridge.compression.requested') {
|
} else if (evType === 'bridge.compression.requested') {
|
||||||
const bridgeHistory = await buildDbHistory(sessionId, { excludeLastUser: true })
|
const bridgeHistory = await buildDbHistory(sessionId, { excludeLastUser: true })
|
||||||
|
const bridgeUsage = estimateUsageTokensFromMessages(bridgeHistory)
|
||||||
const tokenCount = bridgeHistory.length > 0
|
const tokenCount = bridgeHistory.length > 0
|
||||||
? countTokens(JSON.stringify(bridgeHistory))
|
? bridgeUsage.inputTokens + bridgeUsage.outputTokens
|
||||||
: ev.approx_tokens
|
: ev.approx_tokens
|
||||||
const payload = {
|
const payload = {
|
||||||
event: 'compression.started',
|
event: 'compression.started',
|
||||||
|
|||||||
@@ -5,8 +5,7 @@ import type { AgentBridgeClient } from '../agent-bridge'
|
|||||||
import { flushBridgePendingToDb } from './bridge-message'
|
import { flushBridgePendingToDb } from './bridge-message'
|
||||||
import { buildDbHistory, forceCompressBridgeHistory, getOrCreateSession, replaceState } from './compression'
|
import { buildDbHistory, forceCompressBridgeHistory, getOrCreateSession, replaceState } from './compression'
|
||||||
import { handleAbort } from './abort'
|
import { handleAbort } from './abort'
|
||||||
import { calcAndUpdateUsage } from './usage'
|
import { calcAndUpdateUsage, estimateUsageTokensFromMessages } from './usage'
|
||||||
import { countTokens } from '../../../lib/context-compressor'
|
|
||||||
import type { ContentBlock, QueuedRun, SessionState } from './types'
|
import type { ContentBlock, QueuedRun, SessionState } from './types'
|
||||||
|
|
||||||
type CommandName =
|
type CommandName =
|
||||||
@@ -233,7 +232,8 @@ export async function handleSessionCommand(
|
|||||||
const emit = (event: string, payload: any) => emitToSession(ctx.nsp, ctx.socket, sessionId, event, payload)
|
const emit = (event: string, payload: any) => emitToSession(ctx.nsp, ctx.socket, sessionId, event, payload)
|
||||||
try {
|
try {
|
||||||
const history = await buildDbHistory(sessionId, { excludeLastUser: true })
|
const history = await buildDbHistory(sessionId, { excludeLastUser: true })
|
||||||
const tokenEstimate = history.length > 0 ? countTokens(JSON.stringify(history)) : 0
|
const usageEstimate = estimateUsageTokensFromMessages(history)
|
||||||
|
const tokenEstimate = usageEstimate.inputTokens + usageEstimate.outputTokens
|
||||||
emit('compression.started', {
|
emit('compression.started', {
|
||||||
event: 'compression.started',
|
event: 'compression.started',
|
||||||
message_count: history.length,
|
message_count: history.length,
|
||||||
|
|||||||
@@ -11,6 +11,35 @@ import { countTokens, SUMMARY_PREFIX } from '../../../lib/context-compressor'
|
|||||||
import { logger } from '../../logger'
|
import { logger } from '../../logger'
|
||||||
import type { SessionState } from './types'
|
import type { SessionState } from './types'
|
||||||
|
|
||||||
|
type UsageTokenMessage = {
|
||||||
|
role?: string
|
||||||
|
content?: unknown
|
||||||
|
tool_calls?: unknown
|
||||||
|
}
|
||||||
|
|
||||||
|
function contentToUsageText(content: unknown): string {
|
||||||
|
if (typeof content === 'string') return content
|
||||||
|
if (!content) return ''
|
||||||
|
if (Array.isArray(content)) {
|
||||||
|
return content.map((block: any) => {
|
||||||
|
if (typeof block?.text === 'string') return block.text
|
||||||
|
if (typeof block?.type === 'string') return `[${block.type}]`
|
||||||
|
return String(block || '')
|
||||||
|
}).join('\n')
|
||||||
|
}
|
||||||
|
return String(content)
|
||||||
|
}
|
||||||
|
|
||||||
|
export function estimateUsageTokensFromMessages(messages: UsageTokenMessage[]): { inputTokens: number; outputTokens: number } {
|
||||||
|
const inputTokens = messages
|
||||||
|
.filter(m => m.role === 'user')
|
||||||
|
.reduce((sum, m) => sum + countTokens(contentToUsageText(m.content)), 0)
|
||||||
|
const outputTokens = messages
|
||||||
|
.filter(m => m.role === 'assistant' || m.role === 'tool')
|
||||||
|
.reduce((sum, m) => sum + countTokens(contentToUsageText(m.content)) + countTokens(String(m.tool_calls || '')), 0)
|
||||||
|
return { inputTokens, outputTokens }
|
||||||
|
}
|
||||||
|
|
||||||
export async function calcAndUpdateUsage(
|
export async function calcAndUpdateUsage(
|
||||||
sid: string,
|
sid: string,
|
||||||
state: SessionState,
|
state: SessionState,
|
||||||
@@ -26,16 +55,14 @@ export async function calcAndUpdateUsage(
|
|||||||
let outputTokens: number
|
let outputTokens: number
|
||||||
if (snapshot && msgs.length) {
|
if (snapshot && msgs.length) {
|
||||||
const newMessages = msgs.slice(snapshot.lastMessageIndex + 1)
|
const newMessages = msgs.slice(snapshot.lastMessageIndex + 1)
|
||||||
|
const newUsage = estimateUsageTokensFromMessages(newMessages)
|
||||||
inputTokens = countTokens(SUMMARY_PREFIX + snapshot.summary) +
|
inputTokens = countTokens(SUMMARY_PREFIX + snapshot.summary) +
|
||||||
newMessages.filter(m => m.role === 'user').reduce((sum, m) => sum + countTokens(m.content || ''), 0)
|
newUsage.inputTokens
|
||||||
outputTokens = newMessages
|
outputTokens = newUsage.outputTokens
|
||||||
.filter(m => m.role === 'assistant' || m.role === 'tool')
|
|
||||||
.reduce((sum, m) => sum + countTokens(m.content || '') + countTokens(m.tool_calls + '' || ''), 0)
|
|
||||||
} else {
|
} else {
|
||||||
inputTokens = msgs.filter(m => m.role === 'user').reduce((sum, m) => sum + countTokens(m.content || ''), 0)
|
const usage = estimateUsageTokensFromMessages(msgs)
|
||||||
outputTokens = msgs
|
inputTokens = usage.inputTokens
|
||||||
.filter(m => m.role === 'assistant' || m.role === 'tool')
|
outputTokens = usage.outputTokens
|
||||||
.reduce((sum, m) => sum + countTokens(m.content || '') + countTokens(m.tool_calls + '' || ''), 0)
|
|
||||||
}
|
}
|
||||||
state.inputTokens = inputTokens
|
state.inputTokens = inputTokens
|
||||||
state.outputTokens = outputTokens
|
state.outputTokens = outputTokens
|
||||||
|
|||||||
@@ -0,0 +1,33 @@
|
|||||||
|
import { describe, expect, it } from 'vitest'
|
||||||
|
import { countTokens } from '../../packages/server/src/lib/context-compressor'
|
||||||
|
import { estimateUsageTokensFromMessages } from '../../packages/server/src/services/hermes/run-chat/usage'
|
||||||
|
|
||||||
|
describe('run-chat usage token estimates', () => {
|
||||||
|
it('counts message content instead of serialized message payloads', () => {
|
||||||
|
const messages = [
|
||||||
|
{ role: 'user', content: 'hello from user' },
|
||||||
|
{ role: 'assistant', content: 'hello from assistant' },
|
||||||
|
]
|
||||||
|
|
||||||
|
const usage = estimateUsageTokensFromMessages(messages)
|
||||||
|
|
||||||
|
expect(usage.inputTokens).toBe(countTokens('hello from user'))
|
||||||
|
expect(usage.outputTokens).toBe(countTokens('hello from assistant'))
|
||||||
|
expect(usage.inputTokens + usage.outputTokens).toBeLessThan(countTokens(JSON.stringify(messages)))
|
||||||
|
})
|
||||||
|
|
||||||
|
it('keeps assistant tool call tokens on the output side', () => {
|
||||||
|
const messages = [
|
||||||
|
{
|
||||||
|
role: 'assistant',
|
||||||
|
content: 'calling tool',
|
||||||
|
tool_calls: [{ id: 'call_1', type: 'function', function: { name: 'lookup', arguments: '{"q":"x"}' } }],
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
const usage = estimateUsageTokensFromMessages(messages)
|
||||||
|
|
||||||
|
expect(usage.inputTokens).toBe(0)
|
||||||
|
expect(usage.outputTokens).toBe(countTokens('calling tool') + countTokens(String(messages[0].tool_calls || '')))
|
||||||
|
})
|
||||||
|
})
|
||||||
Reference in New Issue
Block a user