Account for full context tokens in compression (#908)

* Account for full context tokens in compression

* Fix group chat final context updates

---------

Co-authored-by: Codex <codex@openai.com>
This commit is contained in:
ekko
2026-05-21 19:40:52 +08:00
committed by GitHub
parent b2ec321990
commit 39ead94352
16 changed files with 730 additions and 35 deletions
@@ -24,6 +24,17 @@ interface RunChatCompressionConfig {
compressor: Partial<CompressorConfig>
}
/**
 * Error signalling that the estimated context usage already exceeds the
 * compression threshold and cannot be reduced by compressing history.
 *
 * The `name` property is set explicitly so the error can also be recognised
 * by name, not only through `instanceof`.
 */
export class ContextWindowTooSmallError extends Error {
  /**
   * @param message Human-readable explanation, forwarded unchanged to `Error`.
   */
  constructor(message: string) {
    super(message)
    // Replace the default 'Error' name with a stable identifier.
    this.name = 'ContextWindowTooSmallError'
  }
}
/**
 * Type guard for {@link ContextWindowTooSmallError}.
 *
 * Matches real instances of the class and, additionally, any `Error` whose
 * `name` equals 'ContextWindowTooSmallError'.
 * NOTE(review): the name-based fallback presumably covers errors whose
 * prototype chain does not match this module's class — confirm.
 *
 * @param err Arbitrary caught value.
 * @returns `true` when `err` should be treated as a ContextWindowTooSmallError.
 */
function isContextWindowTooSmallError(err: unknown): err is ContextWindowTooSmallError {
  // A genuine instance always qualifies, whatever its current name is.
  if (err instanceof ContextWindowTooSmallError) return true
  // Non-Error values (plain objects, strings, null, ...) never qualify.
  if (!(err instanceof Error)) return false
  return err.name === 'ContextWindowTooSmallError'
}
function isSnapshotUsable(
snapshot: { lastMessageIndex: number } | null,
history: ChatMessage[],
@@ -167,10 +178,10 @@ export async function buildCompressedHistory(
emit: (event: string, payload: any) => void,
sessionMap: Map<string, SessionState>,
modelContext: { model?: string | null; provider?: string | null } = {},
contextTokenEstimator?: (messages: ChatMessage[]) => Promise<number | null | undefined>,
): Promise<ChatMessage[]> {
try {
let history = await buildDbHistory(sessionId, { excludeLastUser: true })
if (history.length === 0) return []
const contextLength = getModelContextLength({
profile,
@@ -185,7 +196,40 @@ export async function buildCompressedHistory(
}
const cState = getOrCreateSession(sessionMap, sessionId)
const assembledTokens = await calcAndUpdateUsage(sessionId, cState, emit)
let totalTokens = assembledTokens.inputTokens + assembledTokens.outputTokens
const estimateFullContextTokens = async (messages: ChatMessage[], fallback: number) => {
try {
const estimate = await contextTokenEstimator?.(messages)
if (typeof estimate === 'number' && Number.isFinite(estimate) && estimate > 0) return Math.floor(estimate)
} catch (err) {
logger.warn(err, '[context-compress] session=%s: full context token estimate failed; using message-only estimate', sessionId)
}
return fallback
}
const emitContextUsage = (contextTokens: number) => {
cState.contextTokens = contextTokens
emit('usage.updated', {
event: 'usage.updated',
session_id: sessionId,
inputTokens: cState.inputTokens ?? assembledTokens.inputTokens,
outputTokens: cState.outputTokens ?? assembledTokens.outputTokens,
contextTokens,
})
}
const messageOnlyTotalTokens = assembledTokens.inputTokens + assembledTokens.outputTokens
let totalTokens = messageOnlyTotalTokens
if (history.length === 0) {
totalTokens = await estimateFullContextTokens([], 0)
if (totalTokens > triggerTokens) {
throw new ContextWindowTooSmallError(
`Context window is too small: system prompt and tool schemas already use ~${totalTokens} tokens, exceeding compression threshold ${triggerTokens}. Increase model context length, raise compression.threshold, or disable some tools.`,
)
}
if (totalTokens > 0) emitContextUsage(totalTokens)
return []
}
const canCompressHistory = history.length > 4
const snapshot = getCompressionSnapshot(sessionId)
const staleSnapshot = snapshot && !isSnapshotUsable(snapshot, history)
if (staleSnapshot) {
@@ -193,15 +237,40 @@ export async function buildCompressedHistory(
sessionId, snapshot.lastMessageIndex, history.length)
const staleHistory = buildSnapshotHistory(snapshot, history, compressionConfig.compressor) || history
const staleUsage = estimateUsageTokensFromMessages(staleHistory)
totalTokens = staleUsage.inputTokens + staleUsage.outputTokens
totalTokens = await estimateFullContextTokens(staleHistory, staleUsage.inputTokens + staleUsage.outputTokens)
emitContextUsage(totalTokens)
logger.info({
sessionId,
profile,
messages: staleHistory.length,
messageOnlyTokens: staleUsage.inputTokens + staleUsage.outputTokens,
fullContextTokens: totalTokens,
triggerTokens,
decision: totalTokens > triggerTokens ? 'compress' : 'skip',
snapshot: 'stale',
}, '[context-compress] threshold check')
}
if (snapshot && !staleSnapshot) {
const newMessages = history.slice(snapshot.lastMessageIndex + 1)
const snapshotHistory = buildSnapshotHistory(snapshot, history, compressionConfig.compressor) || history
const snapshotUsage = estimateUsageTokensFromMessages(snapshotHistory)
totalTokens = await estimateFullContextTokens(snapshotHistory, snapshotUsage.inputTokens + snapshotUsage.outputTokens)
emitContextUsage(totalTokens)
logger.info({
sessionId,
profile,
messages: snapshotHistory.length,
messageOnlyTokens: snapshotUsage.inputTokens + snapshotUsage.outputTokens,
fullContextTokens: totalTokens,
triggerTokens,
decision: totalTokens > triggerTokens ? 'compress' : 'skip',
snapshot: 'usable',
}, '[context-compress] threshold check')
logger.info('[context-compress] session=%s: snapshot at %d, %d new messages, assembled ~%d tokens (threshold %d)',
sessionId, snapshot.lastMessageIndex, newMessages.length, totalTokens, triggerTokens)
if (totalTokens <= triggerTokens) {
history = buildSnapshotHistory(snapshot, history, compressionConfig.compressor) || history
history = snapshotHistory
} else {
history = await compressHistory(history, newMessages, sessionId, upstream, apiKey, cState, totalTokens, emit, sessionMap, modelContext, compressionConfig.compressor)
}
@@ -211,7 +280,24 @@ export async function buildCompressedHistory(
} else {
history = await compressHistory(history, null, sessionId, upstream, apiKey, cState, totalTokens, emit, sessionMap, modelContext, compressionConfig.compressor)
}
} else if (history.length > 4) {
} else {
totalTokens = await estimateFullContextTokens(history, totalTokens)
emitContextUsage(totalTokens)
logger.info({
sessionId,
profile,
messages: history.length,
messageOnlyTokens: messageOnlyTotalTokens,
fullContextTokens: totalTokens,
triggerTokens,
decision: totalTokens > triggerTokens ? 'compress' : 'skip',
snapshot: 'none',
}, '[context-compress] threshold check')
if (!canCompressHistory && totalTokens > triggerTokens) {
throw new ContextWindowTooSmallError(
`Context window is too small: fixed prompt/tool overhead plus ${history.length} history messages uses ~${totalTokens} tokens, exceeding compression threshold ${triggerTokens}, and there is not enough history to compress. Increase model context length, raise compression.threshold, or disable some tools.`,
)
}
if (totalTokens <= triggerTokens) {
logger.info('[context-compress] session=%s: %d messages, ~%d tokens — under threshold, skip', sessionId, history.length, totalTokens)
} else {
@@ -221,6 +307,7 @@ export async function buildCompressedHistory(
return history
} catch (err) {
if (isContextWindowTooSmallError(err)) throw err
logger.warn(err, '[chat-run-socket] failed to build compressed history for session %s', sessionId)
return []
}
@@ -310,6 +397,7 @@ export async function forceCompressBridgeHistory(
sessionId: string,
profile: string,
_messages: ChatMessage[],
beforeTokenOverride?: number | null,
): Promise<BridgeCompressionResult> {
const history = await buildDbHistory(sessionId, { excludeLastUser: true })
@@ -334,7 +422,9 @@ export async function forceCompressBridgeHistory(
const contextLength = getModelContextLength({ profile, model: session?.model, provider: session?.provider })
const compressionConfig = await getRunChatCompressionConfig(session?.profile || profile, contextLength)
const beforeUsage = estimateSnapshotAwareHistoryUsage(sessionId, history)
const totalTokens = beforeUsage.tokenCount
const totalTokens = typeof beforeTokenOverride === 'number' && Number.isFinite(beforeTokenOverride) && beforeTokenOverride > 0
? Math.floor(beforeTokenOverride)
: beforeUsage.tokenCount
bridgeLogger.info({
sessionId,
profile,
@@ -135,6 +135,26 @@ export async function handleBridgeRun(
emit,
sessionMap,
{ model: resolvedModel, provider: resolvedProvider },
async (messages) => {
const estimate = await bridge.contextEstimate(
session_id,
messages,
fullInstructions,
profile,
{ model: resolvedModel, provider: resolvedProvider },
)
bridgeLogger.info({
sessionId: session_id,
profile,
model: resolvedModel,
provider: resolvedProvider,
messages: estimate.message_count,
toolCount: estimate.tool_count,
systemPromptChars: estimate.system_prompt_chars,
fullContextTokens: estimate.token_count,
}, '[chat-run-socket] full context estimate')
return estimate.token_count
},
)
const bridgeHistory = history
@@ -315,9 +335,19 @@ async function applyBridgeChunkAsync(
} else if (evType === 'bridge.compression.requested') {
const bridgeHistory = await buildDbHistory(sessionId, { excludeLastUser: true })
const bridgeUsage = estimateUsageTokensFromMessages(bridgeHistory)
const tokenCount = bridgeHistory.length > 0
? bridgeUsage.inputTokens + bridgeUsage.outputTokens
: ev.approx_tokens
const messageOnlyTokens = bridgeUsage.inputTokens + bridgeUsage.outputTokens
const tokenCount = typeof ev.approx_tokens === 'number' && Number.isFinite(ev.approx_tokens) && ev.approx_tokens > 0
? ev.approx_tokens
: messageOnlyTokens
bridgeLogger.info({
sessionId,
profile,
bridgeMessages: ev.message_count,
dbMessages: bridgeHistory.length,
messageOnlyTokens,
fullContextTokens: tokenCount,
source: typeof ev.approx_tokens === 'number' ? 'bridge' : 'message-only-fallback',
}, '[chat-run-socket] bridge compression token estimate')
const payload = {
event: 'compression.started',
run_id: chunk.run_id,
@@ -334,6 +364,7 @@ async function applyBridgeChunkAsync(
sessionId,
profile,
ev.messages as ChatMessage[],
typeof ev.approx_tokens === 'number' ? ev.approx_tokens : undefined,
)
state.bridgeCompressionResults = state.bridgeCompressionResults || {}
state.bridgeCompressionResults[String(ev.request_id)] = compressed
@@ -357,7 +388,9 @@ async function applyBridgeChunkAsync(
totalMessages: compressionResult?.beforeMessages ?? ev.message_count,
resultMessages: compressionResult?.resultMessages ?? ev.result_messages,
beforeTokens: compressionResult?.beforeTokens ?? ev.approx_tokens,
afterTokens: compressionResult?.afterTokens,
afterTokens: typeof ev.result_approx_tokens === 'number' && Number.isFinite(ev.result_approx_tokens) && ev.result_approx_tokens > 0
? ev.result_approx_tokens
: compressionResult?.afterTokens,
summaryTokens: compressionResult?.summaryTokens,
verbatimCount: compressionResult?.verbatimCount,
compressedStartIndex: compressionResult?.compressedStartIndex,
@@ -266,6 +266,7 @@ export class ChatRunSocket {
events: state.isWorking ? state.events : [],
inputTokens: state.inputTokens,
outputTokens: state.outputTokens,
contextTokens: state.contextTokens,
queueLength: state.queue?.length || 0,
})
@@ -46,6 +46,7 @@ export interface SessionState {
profile?: string
inputTokens?: number
outputTokens?: number
contextTokens?: number
isAborting?: boolean
queue: QueuedRun[]
responseRun?: ResponseRunState