Account for full context tokens in compression (#908)
* Account for full context tokens in compression * Fix group chat final context updates --------- Co-authored-by: Codex <codex@openai.com>
This commit is contained in:
@@ -24,6 +24,17 @@ interface RunChatCompressionConfig {
|
||||
compressor: Partial<CompressorConfig>
|
||||
}
|
||||
|
||||
export class ContextWindowTooSmallError extends Error {
|
||||
constructor(message: string) {
|
||||
super(message)
|
||||
this.name = 'ContextWindowTooSmallError'
|
||||
}
|
||||
}
|
||||
|
||||
function isContextWindowTooSmallError(err: unknown): err is ContextWindowTooSmallError {
|
||||
return err instanceof ContextWindowTooSmallError || (err instanceof Error && err.name === 'ContextWindowTooSmallError')
|
||||
}
|
||||
|
||||
function isSnapshotUsable(
|
||||
snapshot: { lastMessageIndex: number } | null,
|
||||
history: ChatMessage[],
|
||||
@@ -167,10 +178,10 @@ export async function buildCompressedHistory(
|
||||
emit: (event: string, payload: any) => void,
|
||||
sessionMap: Map<string, SessionState>,
|
||||
modelContext: { model?: string | null; provider?: string | null } = {},
|
||||
contextTokenEstimator?: (messages: ChatMessage[]) => Promise<number | null | undefined>,
|
||||
): Promise<ChatMessage[]> {
|
||||
try {
|
||||
let history = await buildDbHistory(sessionId, { excludeLastUser: true })
|
||||
if (history.length === 0) return []
|
||||
|
||||
const contextLength = getModelContextLength({
|
||||
profile,
|
||||
@@ -185,7 +196,40 @@ export async function buildCompressedHistory(
|
||||
}
|
||||
const cState = getOrCreateSession(sessionMap, sessionId)
|
||||
const assembledTokens = await calcAndUpdateUsage(sessionId, cState, emit)
|
||||
let totalTokens = assembledTokens.inputTokens + assembledTokens.outputTokens
|
||||
const estimateFullContextTokens = async (messages: ChatMessage[], fallback: number) => {
|
||||
try {
|
||||
const estimate = await contextTokenEstimator?.(messages)
|
||||
if (typeof estimate === 'number' && Number.isFinite(estimate) && estimate > 0) return Math.floor(estimate)
|
||||
} catch (err) {
|
||||
logger.warn(err, '[context-compress] session=%s: full context token estimate failed; using message-only estimate', sessionId)
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
const emitContextUsage = (contextTokens: number) => {
|
||||
cState.contextTokens = contextTokens
|
||||
emit('usage.updated', {
|
||||
event: 'usage.updated',
|
||||
session_id: sessionId,
|
||||
inputTokens: cState.inputTokens ?? assembledTokens.inputTokens,
|
||||
outputTokens: cState.outputTokens ?? assembledTokens.outputTokens,
|
||||
contextTokens,
|
||||
})
|
||||
}
|
||||
const messageOnlyTotalTokens = assembledTokens.inputTokens + assembledTokens.outputTokens
|
||||
let totalTokens = messageOnlyTotalTokens
|
||||
|
||||
if (history.length === 0) {
|
||||
totalTokens = await estimateFullContextTokens([], 0)
|
||||
if (totalTokens > triggerTokens) {
|
||||
throw new ContextWindowTooSmallError(
|
||||
`Context window is too small: system prompt and tool schemas already use ~${totalTokens} tokens, exceeding compression threshold ${triggerTokens}. Increase model context length, raise compression.threshold, or disable some tools.`,
|
||||
)
|
||||
}
|
||||
if (totalTokens > 0) emitContextUsage(totalTokens)
|
||||
return []
|
||||
}
|
||||
|
||||
const canCompressHistory = history.length > 4
|
||||
const snapshot = getCompressionSnapshot(sessionId)
|
||||
const staleSnapshot = snapshot && !isSnapshotUsable(snapshot, history)
|
||||
if (staleSnapshot) {
|
||||
@@ -193,15 +237,40 @@ export async function buildCompressedHistory(
|
||||
sessionId, snapshot.lastMessageIndex, history.length)
|
||||
const staleHistory = buildSnapshotHistory(snapshot, history, compressionConfig.compressor) || history
|
||||
const staleUsage = estimateUsageTokensFromMessages(staleHistory)
|
||||
totalTokens = staleUsage.inputTokens + staleUsage.outputTokens
|
||||
totalTokens = await estimateFullContextTokens(staleHistory, staleUsage.inputTokens + staleUsage.outputTokens)
|
||||
emitContextUsage(totalTokens)
|
||||
logger.info({
|
||||
sessionId,
|
||||
profile,
|
||||
messages: staleHistory.length,
|
||||
messageOnlyTokens: staleUsage.inputTokens + staleUsage.outputTokens,
|
||||
fullContextTokens: totalTokens,
|
||||
triggerTokens,
|
||||
decision: totalTokens > triggerTokens ? 'compress' : 'skip',
|
||||
snapshot: 'stale',
|
||||
}, '[context-compress] threshold check')
|
||||
}
|
||||
|
||||
if (snapshot && !staleSnapshot) {
|
||||
const newMessages = history.slice(snapshot.lastMessageIndex + 1)
|
||||
const snapshotHistory = buildSnapshotHistory(snapshot, history, compressionConfig.compressor) || history
|
||||
const snapshotUsage = estimateUsageTokensFromMessages(snapshotHistory)
|
||||
totalTokens = await estimateFullContextTokens(snapshotHistory, snapshotUsage.inputTokens + snapshotUsage.outputTokens)
|
||||
emitContextUsage(totalTokens)
|
||||
logger.info({
|
||||
sessionId,
|
||||
profile,
|
||||
messages: snapshotHistory.length,
|
||||
messageOnlyTokens: snapshotUsage.inputTokens + snapshotUsage.outputTokens,
|
||||
fullContextTokens: totalTokens,
|
||||
triggerTokens,
|
||||
decision: totalTokens > triggerTokens ? 'compress' : 'skip',
|
||||
snapshot: 'usable',
|
||||
}, '[context-compress] threshold check')
|
||||
logger.info('[context-compress] session=%s: snapshot at %d, %d new messages, assembled ~%d tokens (threshold %d)',
|
||||
sessionId, snapshot.lastMessageIndex, newMessages.length, totalTokens, triggerTokens)
|
||||
if (totalTokens <= triggerTokens) {
|
||||
history = buildSnapshotHistory(snapshot, history, compressionConfig.compressor) || history
|
||||
history = snapshotHistory
|
||||
} else {
|
||||
history = await compressHistory(history, newMessages, sessionId, upstream, apiKey, cState, totalTokens, emit, sessionMap, modelContext, compressionConfig.compressor)
|
||||
}
|
||||
@@ -211,7 +280,24 @@ export async function buildCompressedHistory(
|
||||
} else {
|
||||
history = await compressHistory(history, null, sessionId, upstream, apiKey, cState, totalTokens, emit, sessionMap, modelContext, compressionConfig.compressor)
|
||||
}
|
||||
} else if (history.length > 4) {
|
||||
} else {
|
||||
totalTokens = await estimateFullContextTokens(history, totalTokens)
|
||||
emitContextUsage(totalTokens)
|
||||
logger.info({
|
||||
sessionId,
|
||||
profile,
|
||||
messages: history.length,
|
||||
messageOnlyTokens: messageOnlyTotalTokens,
|
||||
fullContextTokens: totalTokens,
|
||||
triggerTokens,
|
||||
decision: totalTokens > triggerTokens ? 'compress' : 'skip',
|
||||
snapshot: 'none',
|
||||
}, '[context-compress] threshold check')
|
||||
if (!canCompressHistory && totalTokens > triggerTokens) {
|
||||
throw new ContextWindowTooSmallError(
|
||||
`Context window is too small: fixed prompt/tool overhead plus ${history.length} history messages uses ~${totalTokens} tokens, exceeding compression threshold ${triggerTokens}, and there is not enough history to compress. Increase model context length, raise compression.threshold, or disable some tools.`,
|
||||
)
|
||||
}
|
||||
if (totalTokens <= triggerTokens) {
|
||||
logger.info('[context-compress] session=%s: %d messages, ~%d tokens — under threshold, skip', sessionId, history.length, totalTokens)
|
||||
} else {
|
||||
@@ -221,6 +307,7 @@ export async function buildCompressedHistory(
|
||||
|
||||
return history
|
||||
} catch (err) {
|
||||
if (isContextWindowTooSmallError(err)) throw err
|
||||
logger.warn(err, '[chat-run-socket] failed to build compressed history for session %s', sessionId)
|
||||
return []
|
||||
}
|
||||
@@ -310,6 +397,7 @@ export async function forceCompressBridgeHistory(
|
||||
sessionId: string,
|
||||
profile: string,
|
||||
_messages: ChatMessage[],
|
||||
beforeTokenOverride?: number | null,
|
||||
): Promise<BridgeCompressionResult> {
|
||||
const history = await buildDbHistory(sessionId, { excludeLastUser: true })
|
||||
|
||||
@@ -334,7 +422,9 @@ export async function forceCompressBridgeHistory(
|
||||
const contextLength = getModelContextLength({ profile, model: session?.model, provider: session?.provider })
|
||||
const compressionConfig = await getRunChatCompressionConfig(session?.profile || profile, contextLength)
|
||||
const beforeUsage = estimateSnapshotAwareHistoryUsage(sessionId, history)
|
||||
const totalTokens = beforeUsage.tokenCount
|
||||
const totalTokens = typeof beforeTokenOverride === 'number' && Number.isFinite(beforeTokenOverride) && beforeTokenOverride > 0
|
||||
? Math.floor(beforeTokenOverride)
|
||||
: beforeUsage.tokenCount
|
||||
bridgeLogger.info({
|
||||
sessionId,
|
||||
profile,
|
||||
|
||||
@@ -135,6 +135,26 @@ export async function handleBridgeRun(
|
||||
emit,
|
||||
sessionMap,
|
||||
{ model: resolvedModel, provider: resolvedProvider },
|
||||
async (messages) => {
|
||||
const estimate = await bridge.contextEstimate(
|
||||
session_id,
|
||||
messages,
|
||||
fullInstructions,
|
||||
profile,
|
||||
{ model: resolvedModel, provider: resolvedProvider },
|
||||
)
|
||||
bridgeLogger.info({
|
||||
sessionId: session_id,
|
||||
profile,
|
||||
model: resolvedModel,
|
||||
provider: resolvedProvider,
|
||||
messages: estimate.message_count,
|
||||
toolCount: estimate.tool_count,
|
||||
systemPromptChars: estimate.system_prompt_chars,
|
||||
fullContextTokens: estimate.token_count,
|
||||
}, '[chat-run-socket] full context estimate')
|
||||
return estimate.token_count
|
||||
},
|
||||
)
|
||||
const bridgeHistory = history
|
||||
|
||||
@@ -315,9 +335,19 @@ async function applyBridgeChunkAsync(
|
||||
} else if (evType === 'bridge.compression.requested') {
|
||||
const bridgeHistory = await buildDbHistory(sessionId, { excludeLastUser: true })
|
||||
const bridgeUsage = estimateUsageTokensFromMessages(bridgeHistory)
|
||||
const tokenCount = bridgeHistory.length > 0
|
||||
? bridgeUsage.inputTokens + bridgeUsage.outputTokens
|
||||
: ev.approx_tokens
|
||||
const messageOnlyTokens = bridgeUsage.inputTokens + bridgeUsage.outputTokens
|
||||
const tokenCount = typeof ev.approx_tokens === 'number' && Number.isFinite(ev.approx_tokens) && ev.approx_tokens > 0
|
||||
? ev.approx_tokens
|
||||
: messageOnlyTokens
|
||||
bridgeLogger.info({
|
||||
sessionId,
|
||||
profile,
|
||||
bridgeMessages: ev.message_count,
|
||||
dbMessages: bridgeHistory.length,
|
||||
messageOnlyTokens,
|
||||
fullContextTokens: tokenCount,
|
||||
source: typeof ev.approx_tokens === 'number' ? 'bridge' : 'message-only-fallback',
|
||||
}, '[chat-run-socket] bridge compression token estimate')
|
||||
const payload = {
|
||||
event: 'compression.started',
|
||||
run_id: chunk.run_id,
|
||||
@@ -334,6 +364,7 @@ async function applyBridgeChunkAsync(
|
||||
sessionId,
|
||||
profile,
|
||||
ev.messages as ChatMessage[],
|
||||
typeof ev.approx_tokens === 'number' ? ev.approx_tokens : undefined,
|
||||
)
|
||||
state.bridgeCompressionResults = state.bridgeCompressionResults || {}
|
||||
state.bridgeCompressionResults[String(ev.request_id)] = compressed
|
||||
@@ -357,7 +388,9 @@ async function applyBridgeChunkAsync(
|
||||
totalMessages: compressionResult?.beforeMessages ?? ev.message_count,
|
||||
resultMessages: compressionResult?.resultMessages ?? ev.result_messages,
|
||||
beforeTokens: compressionResult?.beforeTokens ?? ev.approx_tokens,
|
||||
afterTokens: compressionResult?.afterTokens,
|
||||
afterTokens: typeof ev.result_approx_tokens === 'number' && Number.isFinite(ev.result_approx_tokens) && ev.result_approx_tokens > 0
|
||||
? ev.result_approx_tokens
|
||||
: compressionResult?.afterTokens,
|
||||
summaryTokens: compressionResult?.summaryTokens,
|
||||
verbatimCount: compressionResult?.verbatimCount,
|
||||
compressedStartIndex: compressionResult?.compressedStartIndex,
|
||||
|
||||
@@ -266,6 +266,7 @@ export class ChatRunSocket {
|
||||
events: state.isWorking ? state.events : [],
|
||||
inputTokens: state.inputTokens,
|
||||
outputTokens: state.outputTokens,
|
||||
contextTokens: state.contextTokens,
|
||||
queueLength: state.queue?.length || 0,
|
||||
})
|
||||
|
||||
|
||||
@@ -46,6 +46,7 @@ export interface SessionState {
|
||||
profile?: string
|
||||
inputTokens?: number
|
||||
outputTokens?: number
|
||||
contextTokens?: number
|
||||
isAborting?: boolean
|
||||
queue: QueuedRun[]
|
||||
responseRun?: ResponseRunState
|
||||
|
||||
Reference in New Issue
Block a user