fix compression context usage accounting (#924)
This commit is contained in:
@@ -569,14 +569,16 @@ export const useChatStore = defineStore('chat', () => {
|
||||
compressed: null,
|
||||
})
|
||||
} else if (e.event === 'compression.completed') {
|
||||
const afterTokens = e.contextTokens || e.afterTokens || 0
|
||||
setCompressionState({
|
||||
compressing: false,
|
||||
messageCount: e.totalMessages || 0,
|
||||
beforeTokens: e.beforeTokens || 0,
|
||||
afterTokens: e.afterTokens || 0,
|
||||
afterTokens,
|
||||
compressed: e.compressed ?? false,
|
||||
error: e.error,
|
||||
})
|
||||
if (e.contextTokens != null) activeSession.value!.contextTokens = e.contextTokens
|
||||
} else if (e.event === 'abort.started') {
|
||||
setAbortState({ aborting: true, synced: null })
|
||||
} else if (e.event === 'abort.completed') {
|
||||
@@ -1073,14 +1075,19 @@ export const useChatStore = defineStore('chat', () => {
|
||||
}
|
||||
|
||||
case 'compression.completed': {
|
||||
const afterTokens = (evt as any).contextTokens || (evt as any).afterTokens || 0
|
||||
setCompressionState({
|
||||
compressing: false,
|
||||
messageCount: (evt as any).totalMessages || 0,
|
||||
beforeTokens: (evt as any).beforeTokens || 0,
|
||||
afterTokens: (evt as any).afterTokens || 0,
|
||||
afterTokens,
|
||||
compressed: (evt as any).compressed ?? false,
|
||||
error: (evt as any).error,
|
||||
})
|
||||
if ((evt as any).contextTokens != null) {
|
||||
const target = sessions.value.find(s => s.id === sid)
|
||||
if (target) target.contextTokens = (evt as any).contextTokens
|
||||
}
|
||||
// Auto-clear after 5s
|
||||
setTimeout(() => {
|
||||
if (compressionState.value && !compressionState.value.compressing) {
|
||||
@@ -1520,14 +1527,19 @@ export const useChatStore = defineStore('chat', () => {
|
||||
}
|
||||
|
||||
case 'compression.completed': {
|
||||
const afterTokens = (evt as any).contextTokens || (evt as any).afterTokens || 0
|
||||
setCompressionState({
|
||||
compressing: false,
|
||||
messageCount: (evt as any).totalMessages || 0,
|
||||
beforeTokens: (evt as any).beforeTokens || 0,
|
||||
afterTokens: (evt as any).afterTokens || 0,
|
||||
afterTokens,
|
||||
compressed: (evt as any).compressed ?? false,
|
||||
error: (evt as any).error,
|
||||
})
|
||||
if ((evt as any).contextTokens != null) {
|
||||
const target = sessions.value.find(s => s.id === sid)
|
||||
if (target) target.contextTokens = (evt as any).contextTokens
|
||||
}
|
||||
setTimeout(() => {
|
||||
if (compressionState.value && !compressionState.value.compressing) {
|
||||
setCompressionState(null)
|
||||
|
||||
@@ -93,9 +93,16 @@ export interface AgentBridgeRunResult extends AgentBridgeResponse {
|
||||
/**
 * Context-size estimate returned by the agent bridge for a session.
 *
 * The token fields are optional and nullable: the bridge reports `null`
 * (or omits the field) when it could not compute an estimate.
 */
export interface AgentBridgeContextEstimate extends AgentBridgeResponse {
  session_id: string
  /** Estimated tokens for the whole request (messages + system prompt + tools). */
  token_count?: number | null
  /** Estimated tokens with an empty message list (system prompt + tools only). */
  fixed_context_tokens?: number | null
  /** Estimated tokens for the system prompt alone (no tools, no messages). */
  system_prompt_tokens?: number | null
  /** Tool-schema share of the fixed context (fixed minus system prompt, never negative). */
  tool_tokens?: number | null
  message_count: number
  tool_count: number
  /** Names of the tools registered on the agent, when the bridge reports them. */
  tool_names?: string[]
  system_prompt_chars: number
  profile?: string
  model?: string
  provider?: string
}
|
||||
|
||||
export interface AgentBridgeCommandResult extends AgentBridgeResponse {
|
||||
|
||||
@@ -668,30 +668,79 @@ class AgentPool:
|
||||
|
||||
agent._compress_context = wrapped_compress_context
|
||||
|
||||
def _estimate_context_tokens(self, agent: Any, messages: Any, system_message: Any = None) -> int | None:
|
||||
def _agent_system_prompt(self, agent: Any, system_message: Any = None) -> str:
|
||||
prompt = str(getattr(agent, "_cached_system_prompt", "") or "")
|
||||
if prompt:
|
||||
return prompt
|
||||
try:
|
||||
build_prompt = getattr(agent, "_build_system_prompt", None)
|
||||
if callable(build_prompt):
|
||||
return str(build_prompt(system_message) or "")
|
||||
except Exception:
|
||||
return str(system_message or "")
|
||||
return str(system_message or "")
|
||||
|
||||
def _agent_tool_names(self, tools: Any) -> list[str]:
|
||||
if not isinstance(tools, list):
|
||||
return []
|
||||
names: list[str] = []
|
||||
for tool in tools:
|
||||
name = ""
|
||||
if isinstance(tool, dict):
|
||||
function = tool.get("function")
|
||||
if isinstance(function, dict):
|
||||
name = str(function.get("name") or "")
|
||||
if not name:
|
||||
name = str(tool.get("name") or "")
|
||||
else:
|
||||
name = str(getattr(tool, "name", "") or "")
|
||||
if name:
|
||||
names.append(name)
|
||||
return names
|
||||
|
||||
def _estimate_context_info(self, agent: Any, messages: Any, system_message: Any = None) -> dict[str, Any]:
|
||||
try:
|
||||
from agent.model_metadata import estimate_request_tokens_rough
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
prompt = str(getattr(agent, "_cached_system_prompt", "") or "")
|
||||
if not prompt:
|
||||
try:
|
||||
build_prompt = getattr(agent, "_build_system_prompt", None)
|
||||
if callable(build_prompt):
|
||||
prompt = str(build_prompt(system_message) or "")
|
||||
except Exception:
|
||||
prompt = str(system_message or "")
|
||||
return {}
|
||||
|
||||
prompt = self._agent_system_prompt(agent, system_message)
|
||||
tools = getattr(agent, "tools", None) or []
|
||||
message_list = messages if isinstance(messages, list) else []
|
||||
try:
|
||||
estimate = estimate_request_tokens_rough(
|
||||
messages if isinstance(messages, list) else [],
|
||||
system_prompt=prompt,
|
||||
tools=getattr(agent, "tools", None) or None,
|
||||
)
|
||||
return int(estimate) if isinstance(estimate, (int, float)) and estimate > 0 else None
|
||||
token_count = estimate_request_tokens_rough(message_list, system_prompt=prompt, tools=tools or None)
|
||||
fixed_context_tokens = estimate_request_tokens_rough([], system_prompt=prompt, tools=tools or None)
|
||||
system_prompt_tokens = estimate_request_tokens_rough([], system_prompt=prompt, tools=None)
|
||||
tool_tokens = max(0, int(fixed_context_tokens or 0) - int(system_prompt_tokens or 0))
|
||||
return {
|
||||
"token_count": int(token_count) if isinstance(token_count, (int, float)) and token_count > 0 else None,
|
||||
"fixed_context_tokens": int(fixed_context_tokens) if isinstance(fixed_context_tokens, (int, float)) and fixed_context_tokens >= 0 else None,
|
||||
"system_prompt_tokens": int(system_prompt_tokens) if isinstance(system_prompt_tokens, (int, float)) and system_prompt_tokens >= 0 else None,
|
||||
"tool_tokens": tool_tokens,
|
||||
"message_count": len(message_list),
|
||||
"tool_count": len(tools) if isinstance(tools, list) else 0,
|
||||
"tool_names": self._agent_tool_names(tools),
|
||||
"system_prompt_chars": len(prompt),
|
||||
}
|
||||
except Exception:
|
||||
return None
|
||||
return {}
|
||||
|
||||
def _estimate_context_tokens(self, agent: Any, messages: Any, system_message: Any = None) -> int | None:
|
||||
token_count = self._estimate_context_info(agent, messages, system_message).get("token_count")
|
||||
return int(token_count) if isinstance(token_count, (int, float)) and token_count > 0 else None
|
||||
|
||||
def _bridge_context_ready_event(self, session: AgentSession, instructions: str | None, profile: str | None) -> dict[str, Any]:
    """Build the ``bridge.context.ready`` event describing a session's fixed context.

    The estimate is taken with an empty message list, so the reported sizes
    cover only the system prompt and tools. The event carries the session id,
    the effective profile (argument, then session config, then ``"default"``),
    the configured model/provider, and every key of the estimate dict.

    Side effect: the event is also stored on ``session.config["context_info"]``.
    """
    info = self._estimate_context_info(session.agent, [], instructions)
    event = {
        "event": "bridge.context.ready",
        "session_id": session.session_id,
        "profile": profile or session.config.get("profile") or "default",
        "model": session.config.get("model"),
        "provider": session.config.get("provider"),
        **info,
    }
    session.config["context_info"] = event
    return event
|
||||
|
||||
def estimate_context(
|
||||
self,
|
||||
@@ -703,24 +752,23 @@ class AgentPool:
|
||||
provider: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
session = self.get_or_create(session_id, profile=profile, model=model, provider=provider)
|
||||
token_count = self._estimate_context_tokens(session.agent, messages or [], instructions)
|
||||
tools = getattr(session.agent, "tools", None) or []
|
||||
prompt = str(getattr(session.agent, "_cached_system_prompt", "") or "")
|
||||
context_info = self._estimate_context_info(session.agent, messages or [], instructions)
|
||||
print(
|
||||
"[hermes_bridge] context estimate "
|
||||
f"session={session_id} profile={profile or 'default'} "
|
||||
f"messages={len(messages or [])} system_prompt_chars={len(prompt)} "
|
||||
f"tools={len(tools) if isinstance(tools, list) else 0} "
|
||||
f"tokens={token_count if token_count is not None else 'unknown'}",
|
||||
f"messages={len(messages or [])} system_prompt_chars={context_info.get('system_prompt_chars') or 0} "
|
||||
f"tools={context_info.get('tool_count') or 0} "
|
||||
f"fixed_tokens={context_info.get('fixed_context_tokens') if context_info.get('fixed_context_tokens') is not None else 'unknown'} "
|
||||
f"tokens={context_info.get('token_count') if context_info.get('token_count') is not None else 'unknown'}",
|
||||
file=sys.stderr,
|
||||
flush=True,
|
||||
)
|
||||
return {
|
||||
"session_id": session_id,
|
||||
"token_count": token_count,
|
||||
"message_count": len(messages or []),
|
||||
"tool_count": len(tools) if isinstance(tools, list) else 0,
|
||||
"system_prompt_chars": len(prompt),
|
||||
"profile": profile or session.config.get("profile") or "default",
|
||||
"model": session.config.get("model"),
|
||||
"provider": session.config.get("provider"),
|
||||
**context_info,
|
||||
}
|
||||
|
||||
def respond_compression(
|
||||
@@ -1062,6 +1110,9 @@ class AgentPool:
|
||||
session.running = True
|
||||
session.current_run_id = run_id
|
||||
session.last_used_at = time.time()
|
||||
context_event = self._bridge_context_ready_event(session, instructions, profile)
|
||||
if context_event:
|
||||
record.events.append(_jsonable(context_event))
|
||||
|
||||
thread = threading.Thread(
|
||||
target=self._run_chat,
|
||||
|
||||
@@ -13,7 +13,7 @@ import { getModelContextLength } from '../model-context'
|
||||
import { readConfigYamlForProfile } from '../../config-helpers'
|
||||
import { logger } from '../../logger'
|
||||
import { bridgeLogger } from '../../logger'
|
||||
import { calcAndUpdateUsage, estimateUsageTokensFromMessages } from './usage'
|
||||
import { calcAndUpdateUsage, estimateUsageTokensFromMessages, updateMessageContextTokenUsage } from './usage'
|
||||
import { isAssistantMessageSendable } from './message-format'
|
||||
import type { ChatMessage, CompressionConfig as CompressorConfig } from '../../../lib/context-compressor'
|
||||
import type { SessionState, BridgeCompressionResult } from './types'
|
||||
@@ -69,6 +69,23 @@ function buildSnapshotHistory(
|
||||
]
|
||||
}
|
||||
|
||||
/**
 * Returns the history to use for a session, taking a stored compression
 * snapshot into account.
 *
 * When the session has no compression snapshot the given `history` is returned
 * unchanged. Otherwise the compression config is resolved for the profile and
 * the model's context length, and the snapshot-based history is built from it;
 * if that yields a falsy result the original `history` is returned.
 *
 * @param sessionId    Session whose snapshot is looked up.
 * @param profile      Profile used to resolve context length and compression config.
 * @param history      Full history (e.g. as loaded from the database).
 * @param modelContext Optional model/provider used to resolve the context length.
 */
export async function buildSnapshotAwareHistory(
  sessionId: string,
  profile: string,
  history: ChatMessage[],
  modelContext: { model?: string | null; provider?: string | null } = {},
): Promise<ChatMessage[]> {
  const snapshot = getCompressionSnapshot(sessionId)
  if (!snapshot) return history

  const contextLength = getModelContextLength({
    profile,
    model: modelContext.model,
    provider: modelContext.provider,
  })
  const compressionConfig = await getRunChatCompressionConfig(profile, contextLength)
  return buildSnapshotHistory(snapshot, history, compressionConfig.compressor) || history
}
|
||||
|
||||
function clampRatio(value: unknown, fallback: number, min: number, max: number): number {
|
||||
const n = typeof value === 'number' && Number.isFinite(value) ? value : fallback
|
||||
return Math.min(max, Math.max(min, n))
|
||||
@@ -343,21 +360,26 @@ export async function compressHistory(
|
||||
provider: modelContext.provider || session?.provider,
|
||||
})
|
||||
const afterTokens = await calcAndUpdateUsage(sessionId, cState, emit)
|
||||
const compressedMeta = {
|
||||
const compressedAfterTokens = afterTokens.inputTokens + afterTokens.outputTokens
|
||||
const compressedMeta: any = {
|
||||
event: 'compression.completed' as const,
|
||||
compressed: result.meta.compressed,
|
||||
llmCompressed: result.meta.llmCompressed,
|
||||
totalMessages: result.meta.totalMessages,
|
||||
resultMessages: result.messages.length,
|
||||
beforeTokens: totalTokens,
|
||||
afterTokens: afterTokens.inputTokens + afterTokens.outputTokens,
|
||||
afterTokens: compressedAfterTokens,
|
||||
summaryTokens: result.meta.summaryTokenEstimate,
|
||||
verbatimCount: result.meta.verbatimCount,
|
||||
compressedStartIndex: result.meta.compressedStartIndex,
|
||||
}
|
||||
replaceState(sessionMap, sessionId, 'compression.completed', compressedMeta)
|
||||
logger.info('[context-compress] AFTER session=%s: %d messages, ~%d tokens (was %d)',
|
||||
sessionId, result.messages.length, afterTokens.inputTokens + afterTokens.outputTokens, totalTokens)
|
||||
sessionId, result.messages.length, compressedAfterTokens, totalTokens)
|
||||
const compressedContextTokens = updateMessageContextTokenUsage(sessionId, cState, emit, compressedAfterTokens, afterTokens)
|
||||
if (compressedContextTokens != null) {
|
||||
compressedMeta.contextTokens = compressedContextTokens
|
||||
}
|
||||
emit('compression.completed', compressedMeta)
|
||||
|
||||
const compressed = result.messages.map(m => {
|
||||
|
||||
@@ -8,10 +8,17 @@ import { getSystemPrompt } from '../../../lib/llm-prompt'
|
||||
import { getSession, createSession, addMessage, updateSession, updateSessionStats } from '../../../db/hermes/session-store'
|
||||
import { updateUsage } from '../../../db/hermes/usage-store'
|
||||
import { logger, bridgeLogger } from '../../logger'
|
||||
import { AgentBridgeClient, type AgentBridgeMessage, type AgentBridgeOutput } from '../agent-bridge'
|
||||
import { AgentBridgeClient, type AgentBridgeContextEstimate, type AgentBridgeMessage, type AgentBridgeOutput } from '../agent-bridge'
|
||||
import { contentBlocksToString, convertContentBlocksForAgent, extractTextForPreview, isContentBlockArray } from './content-blocks'
|
||||
import { buildCompressedHistory, buildDbHistory, forceCompressBridgeHistory, pushState, replaceState } from './compression'
|
||||
import { calcAndUpdateUsage, estimateUsageTokensFromMessages } from './usage'
|
||||
import { buildCompressedHistory, buildDbHistory, buildSnapshotAwareHistory, forceCompressBridgeHistory, pushState, replaceState } from './compression'
|
||||
import {
|
||||
calcAndUpdateUsage,
|
||||
contextTokensWithCachedOverhead,
|
||||
estimateUsageTokensFromMessages,
|
||||
getCachedBridgeContextOverhead,
|
||||
updateContextTokenUsage,
|
||||
updateMessageContextTokenUsage,
|
||||
} from './usage'
|
||||
import {
|
||||
flushBridgePendingToDb,
|
||||
ensureOpenBridgeAssistantMessage,
|
||||
@@ -62,6 +69,28 @@ export function bridgeTerminalError(chunk: Pick<AgentBridgeOutput, 'status' | 'e
|
||||
return null
|
||||
}
|
||||
|
||||
/**
 * Normalises an untrusted value into a token count.
 *
 * @returns the value rounded down to an integer when it is a finite,
 *          non-negative number; `undefined` for anything else (wrong type,
 *          `NaN`, infinities, negatives).
 */
function finiteToken(value: unknown): number | undefined {
  if (typeof value !== 'number' || !Number.isFinite(value) || value < 0) {
    return undefined
  }
  return Math.floor(value)
}
|
||||
|
||||
/**
 * Caches the bridge's fixed-context breakdown on the session state.
 *
 * Nothing is cached (and any previously cached value is kept) when the payload
 * has no valid `fixed_context_tokens`. Otherwise `state.bridgeContext` is
 * replaced: numeric fields are normalised with `finiteToken`, `tool_names` is
 * stringified when it is an array, and `profile` / `model` / `provider` fall
 * back to the previously cached values when the payload does not carry a string.
 *
 * @param state Session state to update in place.
 * @param data  A `bridge.context.ready` event or a context-estimate response.
 */
function cacheBridgeContext(state: SessionState, data: Record<string, unknown> | AgentBridgeContextEstimate) {
  const fixedContextTokens = finiteToken(data.fixed_context_tokens)
  if (fixedContextTokens == null) return

  // Read the previous cache before it is overwritten below.
  const previous = state.bridgeContext
  state.bridgeContext = {
    fixedContextTokens,
    systemPromptTokens: finiteToken(data.system_prompt_tokens),
    toolTokens: finiteToken(data.tool_tokens),
    systemPromptChars: finiteToken(data.system_prompt_chars),
    toolCount: finiteToken(data.tool_count),
    toolNames: Array.isArray(data.tool_names) ? data.tool_names.map(String) : undefined,
    profile: typeof data.profile === 'string' ? data.profile : previous?.profile,
    model: typeof data.model === 'string' ? data.model : previous?.model,
    provider: typeof data.provider === 'string' ? data.provider : previous?.provider,
  }
}
|
||||
|
||||
export async function handleBridgeRun(
|
||||
nsp: ReturnType<Server['of']>,
|
||||
socket: Socket,
|
||||
@@ -168,6 +197,11 @@ export async function handleBridgeRun(
|
||||
sessionMap,
|
||||
{ model: resolvedModel, provider: resolvedProvider },
|
||||
async (messages) => {
|
||||
const cachedOverhead = getCachedBridgeContextOverhead(state)
|
||||
if (cachedOverhead != null) {
|
||||
const messageUsage = estimateUsageTokensFromMessages(messages)
|
||||
return cachedOverhead + messageUsage.inputTokens + messageUsage.outputTokens
|
||||
}
|
||||
const estimate = await bridge.contextEstimate(
|
||||
session_id,
|
||||
messages,
|
||||
@@ -175,6 +209,7 @@ export async function handleBridgeRun(
|
||||
profile,
|
||||
{ model: resolvedModel, provider: resolvedProvider },
|
||||
)
|
||||
cacheBridgeContext(state, estimate)
|
||||
bridgeLogger.info({
|
||||
sessionId: session_id,
|
||||
profile,
|
||||
@@ -183,6 +218,7 @@ export async function handleBridgeRun(
|
||||
messages: estimate.message_count,
|
||||
toolCount: estimate.tool_count,
|
||||
systemPromptChars: estimate.system_prompt_chars,
|
||||
fixedContextTokens: estimate.fixed_context_tokens,
|
||||
fullContextTokens: estimate.token_count,
|
||||
}, '[chat-run-socket] full context estimate')
|
||||
return estimate.token_count
|
||||
@@ -308,7 +344,35 @@ async function refreshFinalContextUsage(args: {
|
||||
bridge: AgentBridgeClient
|
||||
}): Promise<number | undefined> {
|
||||
try {
|
||||
const finalHistory = await buildDbHistory(args.sessionId, { excludeLastUser: false })
|
||||
const dbHistory = await buildDbHistory(args.sessionId, { excludeLastUser: false })
|
||||
const finalHistory = await buildSnapshotAwareHistory(
|
||||
args.sessionId,
|
||||
args.profile,
|
||||
dbHistory,
|
||||
{ model: args.model, provider: args.provider },
|
||||
)
|
||||
const finalMessageUsage = estimateUsageTokensFromMessages(finalHistory)
|
||||
const finalMessageTokens = finalMessageUsage.inputTokens + finalMessageUsage.outputTokens
|
||||
if (getCachedBridgeContextOverhead(args.state) != null) {
|
||||
const contextTokens = updateMessageContextTokenUsage(
|
||||
args.sessionId,
|
||||
args.state,
|
||||
args.emit,
|
||||
finalMessageTokens,
|
||||
args.usage,
|
||||
)
|
||||
bridgeLogger.info({
|
||||
sessionId: args.sessionId,
|
||||
profile: args.profile,
|
||||
model: args.model,
|
||||
provider: args.provider,
|
||||
messages: finalHistory.length,
|
||||
fixedContextTokens: args.state.bridgeContext?.fixedContextTokens,
|
||||
messageTokens: finalMessageTokens,
|
||||
fullContextTokens: contextTokens,
|
||||
}, '[chat-run-socket] final cached context estimate')
|
||||
return contextTokens
|
||||
}
|
||||
const estimate = await args.bridge.contextEstimate(
|
||||
args.sessionId,
|
||||
finalHistory,
|
||||
@@ -316,18 +380,13 @@ async function refreshFinalContextUsage(args: {
|
||||
args.profile,
|
||||
{ model: args.model ?? undefined, provider: args.provider ?? undefined },
|
||||
)
|
||||
cacheBridgeContext(args.state, estimate)
|
||||
const contextTokens = typeof estimate.token_count === 'number' && Number.isFinite(estimate.token_count) && estimate.token_count > 0
|
||||
? Math.floor(estimate.token_count)
|
||||
: undefined
|
||||
if (contextTokens == null) return args.state.contextTokens
|
||||
|
||||
args.state.contextTokens = contextTokens
|
||||
args.emit('usage.updated', {
|
||||
event: 'usage.updated',
|
||||
inputTokens: args.usage.inputTokens,
|
||||
outputTokens: args.usage.outputTokens,
|
||||
contextTokens,
|
||||
})
|
||||
updateContextTokenUsage(args.sessionId, args.state, args.emit, contextTokens, args.usage)
|
||||
bridgeLogger.info({
|
||||
sessionId: args.sessionId,
|
||||
profile: args.profile,
|
||||
@@ -378,7 +437,17 @@ async function applyBridgeChunkAsync(
|
||||
|
||||
for (const ev of chunk.events || []) {
|
||||
const evType = ev.event as string | undefined
|
||||
if (evType === 'tool.started') {
|
||||
if (evType === 'bridge.context.ready') {
|
||||
cacheBridgeContext(state, ev)
|
||||
const usage = await calcAndUpdateUsage(sessionId, state, emit)
|
||||
updateMessageContextTokenUsage(
|
||||
sessionId,
|
||||
state,
|
||||
emit,
|
||||
usage.inputTokens + usage.outputTokens,
|
||||
usage,
|
||||
)
|
||||
} else if (evType === 'tool.started') {
|
||||
flushBridgePendingToDb(state, sessionId, runMarker)
|
||||
const toolName = (ev.tool_name as string) || ''
|
||||
const args = ev.args as Record<string, unknown> | undefined
|
||||
@@ -498,6 +567,11 @@ async function applyBridgeChunkAsync(
|
||||
const compressionResult = ev.request_id
|
||||
? state.bridgeCompressionResults?.[String(ev.request_id)]
|
||||
: undefined
|
||||
const bridgeAfterContextTokens = finiteToken(ev.result_approx_tokens)
|
||||
const messageAfterTokens = finiteToken(compressionResult?.afterTokens)
|
||||
const afterContextTokens = messageAfterTokens != null && getCachedBridgeContextOverhead(state) != null
|
||||
? contextTokensWithCachedOverhead(state, messageAfterTokens)
|
||||
: bridgeAfterContextTokens ?? messageAfterTokens
|
||||
const payload = {
|
||||
event: 'compression.completed',
|
||||
run_id: chunk.run_id,
|
||||
@@ -507,9 +581,8 @@ async function applyBridgeChunkAsync(
|
||||
totalMessages: compressionResult?.beforeMessages ?? ev.message_count,
|
||||
resultMessages: compressionResult?.resultMessages ?? ev.result_messages,
|
||||
beforeTokens: compressionResult?.beforeTokens ?? ev.approx_tokens,
|
||||
afterTokens: typeof ev.result_approx_tokens === 'number' && Number.isFinite(ev.result_approx_tokens) && ev.result_approx_tokens > 0
|
||||
? ev.result_approx_tokens
|
||||
: compressionResult?.afterTokens,
|
||||
afterTokens: messageAfterTokens ?? bridgeAfterContextTokens,
|
||||
contextTokens: afterContextTokens,
|
||||
summaryTokens: compressionResult?.summaryTokens,
|
||||
verbatimCount: compressionResult?.verbatimCount,
|
||||
compressedStartIndex: compressionResult?.compressedStartIndex,
|
||||
@@ -520,7 +593,12 @@ async function applyBridgeChunkAsync(
|
||||
}
|
||||
replaceState(sessionMap, sessionId, 'compression.completed', payload)
|
||||
emit('compression.completed', payload)
|
||||
await calcAndUpdateUsage(sessionId, state, emit)
|
||||
const usage = await calcAndUpdateUsage(sessionId, state, emit)
|
||||
if (messageAfterTokens != null && getCachedBridgeContextOverhead(state) != null) {
|
||||
updateMessageContextTokenUsage(sessionId, state, emit, messageAfterTokens, usage)
|
||||
} else {
|
||||
updateContextTokenUsage(sessionId, state, emit, afterContextTokens, usage)
|
||||
}
|
||||
} else if (evType === 'bridge.compression.failed') {
|
||||
const payload = {
|
||||
event: 'compression.completed',
|
||||
|
||||
@@ -5,7 +5,7 @@ import type { AgentBridgeClient } from '../agent-bridge'
|
||||
import { flushBridgePendingToDb } from './bridge-message'
|
||||
import { buildDbHistory, estimateSnapshotAwareHistoryUsage, forceCompressBridgeHistory, getOrCreateSession, replaceState } from './compression'
|
||||
import { handleAbort } from './abort'
|
||||
import { calcAndUpdateUsage } from './usage'
|
||||
import { calcAndUpdateUsage, contextTokensWithCachedOverhead, updateMessageContextTokenUsage } from './usage'
|
||||
import type { ContentBlock, QueuedRun, SessionState } from './types'
|
||||
|
||||
type CommandName =
|
||||
@@ -232,10 +232,11 @@ export async function handleSessionCommand(
|
||||
try {
|
||||
const history = await buildDbHistory(sessionId, { excludeLastUser: true })
|
||||
const usageEstimate = estimateSnapshotAwareHistoryUsage(sessionId, history)
|
||||
const beforeContextTokens = contextTokensWithCachedOverhead(state, usageEstimate.tokenCount)
|
||||
emit('compression.started', {
|
||||
event: 'compression.started',
|
||||
message_count: usageEstimate.messageCount,
|
||||
token_count: usageEstimate.tokenCount,
|
||||
token_count: beforeContextTokens,
|
||||
source: 'command',
|
||||
})
|
||||
const result = await forceCompressBridgeHistory(
|
||||
@@ -244,27 +245,32 @@ export async function handleSessionCommand(
|
||||
[],
|
||||
)
|
||||
state.bridgeCompressionResults = state.bridgeCompressionResults || {}
|
||||
await calcAndUpdateUsage(sessionId, state, emit)
|
||||
const usage = await calcAndUpdateUsage(sessionId, state, emit)
|
||||
const afterContextTokens = contextTokensWithCachedOverhead(state, result.afterTokens)
|
||||
emit('compression.completed', {
|
||||
event: 'compression.completed',
|
||||
compressed: result.compressed,
|
||||
llmCompressed: result.llmCompressed,
|
||||
totalMessages: result.beforeMessages,
|
||||
resultMessages: result.resultMessages,
|
||||
beforeTokens: result.beforeTokens,
|
||||
beforeTokens: beforeContextTokens,
|
||||
afterTokens: result.afterTokens,
|
||||
summaryTokens: result.summaryTokens,
|
||||
verbatimCount: result.verbatimCount,
|
||||
compressedStartIndex: result.compressedStartIndex,
|
||||
contextTokens: afterContextTokens,
|
||||
source: 'command',
|
||||
})
|
||||
updateMessageContextTokenUsage(sessionId, state, emit, result.afterTokens, usage)
|
||||
emitCommand({
|
||||
action: 'compress',
|
||||
message: `Compression completed: ${result.beforeMessages} -> ${result.resultMessages} messages, ${result.beforeTokens} -> ${result.afterTokens} tokens.`,
|
||||
message: `Compression completed: ${result.beforeMessages} -> ${result.resultMessages} messages, ${beforeContextTokens} -> ${afterContextTokens} tokens.`,
|
||||
beforeMessages: result.beforeMessages,
|
||||
resultMessages: result.resultMessages,
|
||||
beforeTokens: result.beforeTokens,
|
||||
afterTokens: result.afterTokens,
|
||||
beforeTokens: beforeContextTokens,
|
||||
afterTokens: afterContextTokens,
|
||||
messageBeforeTokens: result.beforeTokens,
|
||||
messageAfterTokens: result.afterTokens,
|
||||
compressed: result.compressed,
|
||||
})
|
||||
} catch (err) {
|
||||
|
||||
@@ -47,6 +47,7 @@ export interface SessionState {
|
||||
inputTokens?: number
|
||||
outputTokens?: number
|
||||
contextTokens?: number
|
||||
bridgeContext?: BridgeContextState
|
||||
isAborting?: boolean
|
||||
queue: QueuedRun[]
|
||||
responseRun?: ResponseRunState
|
||||
@@ -72,6 +73,18 @@ export interface ResponseRunState {
|
||||
toolCalls: Map<string, any>
|
||||
}
|
||||
|
||||
/**
 * Per-session cache of the bridge's fixed (message-independent) context
 * breakdown. All fields are optional because the bridge may omit any of them.
 */
export interface BridgeContextState {
  /** Tokens taken by system prompt + tools, independent of the messages. */
  fixedContextTokens?: number
  /** Tokens taken by the system prompt alone. */
  systemPromptTokens?: number
  /** Tokens taken by the tool definitions. */
  toolTokens?: number
  /** Length of the system prompt in characters. */
  systemPromptChars?: number
  /** Number of tools registered on the agent. */
  toolCount?: number
  /** Names of the registered tools. */
  toolNames?: string[]
  profile?: string
  model?: string
  provider?: string
}
|
||||
|
||||
export type ChatRunSource = 'api_server' | 'cli'
|
||||
|
||||
export interface BridgeCompressionResult {
|
||||
|
||||
@@ -78,3 +78,60 @@ export async function calcAndUpdateUsage(
|
||||
return { inputTokens: 0, outputTokens: 0 }
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Stores a new full-context token count on the session and broadcasts it.
 *
 * Invalid input (not a number, non-finite, or negative) is ignored: nothing is
 * emitted and the currently stored `state.contextTokens` is returned.
 * Otherwise the value is rounded down, written to `state.contextTokens`, and a
 * `usage.updated` event is emitted. Input/output token counts in the event come
 * from `usage` when given, then from `state`, then default to 0.
 *
 * @returns the stored context token count after the call (may be `undefined`).
 */
export function updateContextTokenUsage(
  sid: string,
  state: SessionState,
  emit: (event: string, payload: any) => void,
  contextTokens: number | null | undefined,
  usage?: { inputTokens: number; outputTokens: number },
): number | undefined {
  const isValid = typeof contextTokens === 'number' && Number.isFinite(contextTokens) && contextTokens >= 0
  if (!isValid) return state.contextTokens

  const normalizedContextTokens = Math.floor(contextTokens)
  state.contextTokens = normalizedContextTokens
  emit('usage.updated', {
    event: 'usage.updated',
    session_id: sid,
    inputTokens: usage?.inputTokens ?? state.inputTokens ?? 0,
    outputTokens: usage?.outputTokens ?? state.outputTokens ?? 0,
    contextTokens: normalizedContextTokens,
  })
  return normalizedContextTokens
}
|
||||
|
||||
/**
 * Reads the cached fixed-context token overhead for a session.
 *
 * @returns `state.bridgeContext.fixedContextTokens` rounded down, or
 *          `undefined` when nothing valid is cached (missing, non-finite or
 *          negative).
 */
export function getCachedBridgeContextOverhead(state: SessionState): number | undefined {
  const fixedContextTokens = state.bridgeContext?.fixedContextTokens
  const isUsable = typeof fixedContextTokens === 'number'
    && Number.isFinite(fixedContextTokens)
    && fixedContextTokens >= 0
  return isUsable ? Math.floor(fixedContextTokens) : undefined
}
|
||||
|
||||
/**
 * Converts a message-only token count into a full-context token count by
 * adding the cached fixed overhead (system prompt + tools), when one is cached.
 *
 * `messageTokens` is rounded down and clamped to be non-negative. A non-finite
 * value (`NaN`, infinities, or `undefined` slipping through a loose cast) is
 * treated as 0 so it cannot turn the result into `NaN` and leak into emitted
 * events or user-facing messages.
 *
 * @returns the overhead plus the normalised message tokens, or just the
 *          normalised message tokens when no overhead is cached.
 */
export function contextTokensWithCachedOverhead(state: SessionState, messageTokens: number): number {
  const normalizedMessageTokens = Number.isFinite(messageTokens)
    ? Math.max(0, Math.floor(messageTokens))
    : 0
  const fixedContextTokens = getCachedBridgeContextOverhead(state)
  return fixedContextTokens == null
    ? normalizedMessageTokens
    : fixedContextTokens + normalizedMessageTokens
}
|
||||
|
||||
/**
 * Updates the session's context token usage from a message-only token count.
 *
 * The cached fixed overhead (if any) is added to `messageTokens` and the result
 * is stored and broadcast via `updateContextTokenUsage`. Invalid input (not a
 * number, non-finite, or negative) is ignored and the currently stored
 * `state.contextTokens` is returned without emitting anything.
 *
 * @returns the stored context token count after the call (may be `undefined`).
 */
export function updateMessageContextTokenUsage(
  sid: string,
  state: SessionState,
  emit: (event: string, payload: any) => void,
  messageTokens: number | null | undefined,
  usage?: { inputTokens: number; outputTokens: number },
): number | undefined {
  const isValid = typeof messageTokens === 'number' && Number.isFinite(messageTokens) && messageTokens >= 0
  if (!isValid) return state.contextTokens

  const fullContextTokens = contextTokensWithCachedOverhead(state, messageTokens)
  return updateContextTokenUsage(sid, state, emit, fullContextTokens, usage)
}
|
||||
|
||||
@@ -9,11 +9,26 @@ const updateSessionStatsMock = vi.fn()
|
||||
const updateUsageMock = vi.fn()
|
||||
const buildCompressedHistoryMock = vi.fn()
|
||||
const buildDbHistoryMock = vi.fn()
|
||||
const buildSnapshotAwareHistoryMock = vi.fn(async (_sessionId: string, _profile: string, history: any[]) => history)
|
||||
const pushStateMock = vi.fn()
|
||||
const replaceStateMock = vi.fn()
|
||||
const forceCompressBridgeHistoryMock = vi.fn()
|
||||
const calcAndUpdateUsageMock = vi.fn()
|
||||
const estimateUsageTokensFromMessagesMock = vi.fn()
|
||||
const updateContextTokenUsageMock = vi.fn((sid: string, state: any, emit: any, contextTokens: number, usage?: { inputTokens: number; outputTokens: number }) => {
|
||||
state.contextTokens = contextTokens
|
||||
emit('usage.updated', {
|
||||
event: 'usage.updated',
|
||||
session_id: sid,
|
||||
inputTokens: usage?.inputTokens ?? state.inputTokens ?? 0,
|
||||
outputTokens: usage?.outputTokens ?? state.outputTokens ?? 0,
|
||||
contextTokens,
|
||||
})
|
||||
return contextTokens
|
||||
})
|
||||
const getCachedBridgeContextOverheadMock = vi.fn(() => undefined)
|
||||
const contextTokensWithCachedOverheadMock = vi.fn((_state: any, messageTokens: number) => messageTokens)
|
||||
const updateMessageContextTokenUsageMock = vi.fn((sid: string, state: any, emit: any, messageTokens: number, usage?: { inputTokens: number; outputTokens: number }) => updateContextTokenUsageMock(sid, state, emit, messageTokens, usage))
|
||||
const flushBridgePendingToDbMock = vi.fn()
|
||||
const ensureOpenBridgeAssistantMessageMock = vi.fn()
|
||||
const syncBridgeReasoningToMessageMock = vi.fn()
|
||||
@@ -45,6 +60,7 @@ vi.mock('../../packages/server/src/services/logger', () => ({
|
||||
vi.mock('../../packages/server/src/services/hermes/run-chat/compression', () => ({
|
||||
buildCompressedHistory: buildCompressedHistoryMock,
|
||||
buildDbHistory: buildDbHistoryMock,
|
||||
buildSnapshotAwareHistory: buildSnapshotAwareHistoryMock,
|
||||
pushState: pushStateMock,
|
||||
replaceState: replaceStateMock,
|
||||
forceCompressBridgeHistory: forceCompressBridgeHistoryMock,
|
||||
@@ -53,6 +69,10 @@ vi.mock('../../packages/server/src/services/hermes/run-chat/compression', () =>
|
||||
vi.mock('../../packages/server/src/services/hermes/run-chat/usage', () => ({
|
||||
calcAndUpdateUsage: calcAndUpdateUsageMock,
|
||||
estimateUsageTokensFromMessages: estimateUsageTokensFromMessagesMock,
|
||||
getCachedBridgeContextOverhead: getCachedBridgeContextOverheadMock,
|
||||
contextTokensWithCachedOverhead: contextTokensWithCachedOverheadMock,
|
||||
updateContextTokenUsage: updateContextTokenUsageMock,
|
||||
updateMessageContextTokenUsage: updateMessageContextTokenUsageMock,
|
||||
}))
|
||||
|
||||
vi.mock('../../packages/server/src/services/hermes/run-chat/bridge-message', () => ({
|
||||
@@ -103,7 +123,11 @@ describe('bridge run final context usage', () => {
|
||||
{ role: 'user', content: 'hello' },
|
||||
{ role: 'assistant', content: 'done' },
|
||||
])
|
||||
buildSnapshotAwareHistoryMock.mockImplementation(async (_sessionId: string, _profile: string, history: any[]) => history)
|
||||
calcAndUpdateUsageMock.mockResolvedValue({ inputTokens: 11, outputTokens: 7 })
|
||||
estimateUsageTokensFromMessagesMock.mockReturnValue({ inputTokens: 11, outputTokens: 7 })
|
||||
getCachedBridgeContextOverheadMock.mockReturnValue(undefined)
|
||||
contextTokensWithCachedOverheadMock.mockImplementation((_state: any, messageTokens: number) => messageTokens)
|
||||
})
|
||||
|
||||
it('refreshes full context tokens when a bridge run completes', async () => {
|
||||
@@ -161,6 +185,50 @@ describe('bridge run final context usage', () => {
|
||||
}))
|
||||
})
|
||||
|
||||
it('uses cached fixed context instead of bridge estimate when available', async () => {
  // Fixed overhead that the cached bridge context contributes in this scenario.
  const cachedOverhead = 20_000
  // Presumably the 11 input + 7 output tokens returned by the usage mocks in the shared setup.
  const messageTokenCount = 18
  const expectedContextTokens = cachedOverhead + messageTokenCount

  const emit = vi.fn()
  const nsp = makeNamespace(emit)
  const socket = makeSocket()
  const state = makeState()
  state.bridgeContext = { fixedContextTokens: cachedOverhead }
  const sessionMap = new Map([['session-1', state]])

  getCachedBridgeContextOverheadMock.mockReturnValue(cachedOverhead)
  // NOTE(review): mockImplementation (unlike mockImplementationOnce) stays in effect after
  // this test finishes; confirm the shared setup resets this mock before later tests run.
  updateMessageContextTokenUsageMock.mockImplementation((
    sessionId: string,
    targetState: any,
    targetEmit: any,
    tokens: number,
    usageTotals?: { inputTokens: number; outputTokens: number },
  ) => {
    return updateContextTokenUsageMock(sessionId, targetState, targetEmit, cachedOverhead + tokens, usageTotals)
  })

  // Stream that finishes immediately with a single completed chunk.
  async function* completedStream() {
    yield { run_id: 'run-1', done: true, status: 'completed', output: 'done' }
  }
  const bridge = {
    chat: vi.fn().mockResolvedValue({ run_id: 'run-1', status: 'started' }),
    contextEstimate: vi.fn(),
    streamOutput: vi.fn(completedStream),
  } as any

  const bridgeRunModule = await import('../../packages/server/src/services/hermes/run-chat/handle-bridge-run')
  await bridgeRunModule.handleBridgeRun(
    nsp,
    socket,
    { input: 'hello', session_id: 'session-1' },
    'default',
    sessionMap,
    bridge,
    false,
    vi.fn(),
    vi.fn(),
  )

  // The bridge estimate must be skipped entirely when a cached overhead exists.
  expect(bridge.contextEstimate).not.toHaveBeenCalled()
  expect(updateMessageContextTokenUsageMock).toHaveBeenCalledWith(
    'session-1',
    state,
    expect.any(Function),
    messageTokenCount,
    { inputTokens: 11, outputTokens: 7 },
  )
  expect(state.contextTokens).toBe(expectedContextTokens)

  const completedPayload = expect.objectContaining({ contextTokens: expectedContextTokens })
  expect(emit).toHaveBeenCalledWith('run.completed', completedPayload)
})
|
||||
|
||||
it('refreshes full context tokens when a bridge run fails', async () => {
|
||||
const emit = vi.fn()
|
||||
const nsp = makeNamespace(emit)
|
||||
|
||||
@@ -6,6 +6,17 @@ const getCompressionSnapshotMock = vi.fn()
|
||||
const getModelContextLengthMock = vi.fn()
const calcAndUpdateUsageMock = vi.fn()
const estimateUsageTokensFromMessagesMock = vi.fn()

// Stand-in for updateMessageContextTokenUsage: stores the message token count on the
// session state as-is, emits a matching 'usage.updated' event and returns the count.
const updateMessageContextTokenUsageMock = vi.fn((
  sid: string,
  state: any,
  emit: any,
  messageTokens: number,
  usage?: { inputTokens: number; outputTokens: number },
) => {
  state.contextTokens = messageTokens

  // Explicit usage wins; otherwise fall back to the totals on state, then to zero.
  const usagePayload = {
    event: 'usage.updated',
    session_id: sid,
    inputTokens: usage?.inputTokens ?? state.inputTokens ?? 0,
    outputTokens: usage?.outputTokens ?? state.outputTokens ?? 0,
    contextTokens: messageTokens,
  }
  emit('usage.updated', usagePayload)

  return messageTokens
})

const compressorCompressMock = vi.fn()
const readConfigYamlForProfileMock = vi.fn()
const compressorConstructorMock = vi.fn()
|
||||
@@ -55,6 +66,7 @@ vi.mock('../../packages/server/src/services/logger', () => ({
|
||||
// Swap the run-chat usage module for the module-level mocks, one mock per export.
vi.mock('../../packages/server/src/services/hermes/run-chat/usage', () => {
  const usageModuleStub = {
    calcAndUpdateUsage: calcAndUpdateUsageMock,
    estimateUsageTokensFromMessages: estimateUsageTokensFromMessagesMock,
    updateMessageContextTokenUsage: updateMessageContextTokenUsageMock,
  }
  return usageModuleStub
})
|
||||
|
||||
vi.mock('../../packages/server/src/services/hermes/run-chat/message-format', () => ({
|
||||
@@ -69,6 +81,7 @@ describe('run chat compression trigger', () => {
|
||||
getModelContextLengthMock.mockReset()
|
||||
calcAndUpdateUsageMock.mockReset()
|
||||
estimateUsageTokensFromMessagesMock.mockReset()
|
||||
updateMessageContextTokenUsageMock.mockClear()
|
||||
compressorCompressMock.mockReset()
|
||||
compressorConstructorMock.mockReset()
|
||||
readConfigYamlForProfileMock.mockReset()
|
||||
@@ -189,13 +202,14 @@ describe('run chat compression trigger', () => {
|
||||
},
|
||||
})
|
||||
|
||||
const emit = vi.fn()
|
||||
const { buildCompressedHistory } = await import('../../packages/server/src/services/hermes/run-chat/compression')
|
||||
const history = await buildCompressedHistory(
|
||||
'session-1',
|
||||
'default',
|
||||
'http://upstream',
|
||||
undefined,
|
||||
vi.fn(),
|
||||
emit,
|
||||
new Map(),
|
||||
{},
|
||||
vi.fn(async () => 120_000),
|
||||
@@ -203,6 +217,13 @@ describe('run chat compression trigger', () => {
|
||||
|
||||
expect(history).toEqual([{ role: 'user', content: 'compressed by full context estimate' }])
|
||||
expect(compressorCompressMock).toHaveBeenCalledTimes(1)
|
||||
expect(updateMessageContextTokenUsageMock).toHaveBeenCalledWith(
|
||||
'session-1',
|
||||
expect.any(Object),
|
||||
emit,
|
||||
1_000,
|
||||
{ inputTokens: 1_000, outputTokens: 0 },
|
||||
)
|
||||
})
|
||||
|
||||
it('emits full context token usage when the full estimate is under threshold', async () => {
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
import { describe, expect, it } from 'vitest'
|
||||
import { describe, expect, it, vi } from 'vitest'
|
||||
import { countTokens } from '../../packages/server/src/lib/context-compressor'
|
||||
import { estimateUsageTokensFromMessages } from '../../packages/server/src/services/hermes/run-chat/usage'
|
||||
import {
|
||||
contextTokensWithCachedOverhead,
|
||||
estimateUsageTokensFromMessages,
|
||||
updateMessageContextTokenUsage,
|
||||
} from '../../packages/server/src/services/hermes/run-chat/usage'
|
||||
|
||||
describe('run-chat usage token estimates', () => {
|
||||
it('counts message content instead of serialized message payloads', () => {
|
||||
@@ -30,4 +34,58 @@ describe('run-chat usage token estimates', () => {
|
||||
expect(usage.inputTokens).toBe(0)
|
||||
expect(usage.outputTokens).toBe(countTokens('calling tool') + countTokens(String(messages[0].tool_calls || '')))
|
||||
})
|
||||
|
||||
it('adds cached bridge fixed context when updating full context usage', () => {
  const sessionId = 'session-1'
  const fixedContextTokens = 20_000
  const messageTokens = 1_569
  const inputTokens = 1_200
  const outputTokens = 369
  // The reported context size must be the cached fixed context plus the message tokens.
  const expectedTotal = 21_569

  const emit = vi.fn()
  const state = {
    messages: [],
    isWorking: false,
    events: [],
    queue: [],
    bridgeContext: { fixedContextTokens },
  } as any

  const contextTokens = updateMessageContextTokenUsage(
    sessionId,
    state,
    emit,
    messageTokens,
    { inputTokens, outputTokens },
  )

  // The total is returned, stored on the state, and broadcast in the usage event.
  expect(contextTokens).toBe(expectedTotal)
  expect(state.contextTokens).toBe(expectedTotal)

  const usageEvent = expect.objectContaining({
    session_id: sessionId,
    inputTokens,
    outputTokens,
    contextTokens: expectedTotal,
  })
  expect(emit).toHaveBeenCalledWith('usage.updated', usageEvent)
})
|
||||
|
||||
it('falls back to message tokens when bridge fixed context is missing', () => {
  const messageTokens = 1_569
  const emit = vi.fn()
  // No bridgeContext on this state, so there is no fixed overhead to add.
  const state = {
    messages: [],
    isWorking: false,
    events: [],
    queue: [],
  } as any

  const withOverhead = contextTokensWithCachedOverhead(state, messageTokens)
  expect(withOverhead).toBe(messageTokens)

  const contextTokens = updateMessageContextTokenUsage(
    'session-1',
    state,
    emit,
    messageTokens,
    { inputTokens: 1_200, outputTokens: 369 },
  )

  // With no cached overhead, every reported figure equals the plain message token count.
  expect(contextTokens).toBe(messageTokens)
  expect(state.contextTokens).toBe(messageTokens)

  const usageEvent = expect.objectContaining({ contextTokens: messageTokens })
  expect(emit).toHaveBeenCalledWith('usage.updated', usageEvent)
})
|
||||
})
|
||||
|
||||
Reference in New Issue
Block a user