fix context token resume (#1039)

This commit is contained in:
ekko
2026-05-26 16:32:07 +08:00
committed by GitHub
parent e686f0277a
commit ad1cab277a
13 changed files with 959 additions and 203 deletions
@@ -79,6 +79,7 @@ export interface SummarizerOptions {
profile?: string
model?: string | null
provider?: string | null
workerKey?: string
}
// ─── Token counting ─────────────────────────────────────
@@ -454,6 +455,7 @@ export async function callSummarizer(
const bridge = new AgentBridgeClient({ timeoutMs: timeoutMs + 15_000 })
const sessionId = `compress_${Date.now().toString(36)}_${randomUUID().replace(/-/g, '').slice(0, 12)}`
const workerKey = options.workerKey || `${profile}:compression:${sessionId}`
try {
const result = await bridge.request<AgentBridgeRunResult>({
@@ -462,6 +464,7 @@ export async function callSummarizer(
message: prompt,
conversation_history: convHistory,
profile,
worker_key: workerKey,
source: 'api_server',
wait: true,
timeout: Math.ceil(timeoutMs / 1000),
@@ -482,7 +485,7 @@ export async function callSummarizer(
if (!output) throw new Error('Empty summarization response')
return output
} finally {
await bridge.destroy(sessionId, profile).catch(() => undefined)
await bridge.destroy(sessionId, profile, workerKey).catch(() => undefined)
}
}
@@ -166,7 +166,7 @@ export class AgentBridgeClient {
private summarizePayload(payload: Record<string, unknown>): Record<string, unknown> {
const action = String(payload.action || '')
const summary: Record<string, unknown> = { action }
for (const key of ['session_id', 'run_id', 'request_id', 'approval_id', 'profile']) {
for (const key of ['session_id', 'run_id', 'request_id', 'approval_id', 'profile', 'worker_key']) {
if (payload[key] != null) summary[key] = payload[key]
}
if (Array.isArray(payload.conversation_history)) summary.conversation_history_count = payload.conversation_history.length
@@ -569,11 +569,12 @@ export class AgentBridgeClient {
})
}
destroy(sessionId: string, profile?: string): Promise<AgentBridgeResponse> {
destroy(sessionId: string, profile?: string, workerKey?: string): Promise<AgentBridgeResponse> {
return this.request({
action: 'destroy',
session_id: sessionId,
...(profile ? { profile } : {}),
...(workerKey ? { worker_key: workerKey } : {}),
})
}
@@ -2200,7 +2200,8 @@ class WorkerProcess:
STARTUP_TIMEOUT_SECONDS = 120
REQUEST_TIMEOUT_SECONDS = 120
def __init__(self, profile: str, endpoint: str, agent_root: str | None, hermes_home: str | None) -> None:
def __init__(self, key: str, profile: str, endpoint: str, agent_root: str | None, hermes_home: str | None) -> None:
self.key = key or profile or "default"
self.profile = profile or "default"
self.endpoint = endpoint
self.agent_root = agent_root
@@ -2263,14 +2264,14 @@ class WorkerProcess:
for line in proc.stderr:
text = line.rstrip()
if text:
print(f"[hermes-bridge-worker:{self.profile}] {text}", file=sys.stderr, flush=True)
print(f"[hermes-bridge-worker:{self.key}] {text}", file=sys.stderr, flush=True)
threading.Thread(target=run, daemon=True, name=f"hermes-bridge-worker-stderr-{self.profile}").start()
threading.Thread(target=run, daemon=True, name=f"hermes-bridge-worker-stderr-{self.key}").start()
def _wait_ready(self) -> None:
proc = self.process
if proc is None or proc.stdout is None:
raise RuntimeError(f"profile worker {self.profile} did not start")
raise RuntimeError(f"profile worker {self.key} did not start")
lines: queue.Queue[str | None] = queue.Queue()
ready_event = threading.Event()
@@ -2281,17 +2282,17 @@ class WorkerProcess:
if ready_event.is_set():
text = line.rstrip()
if text:
print(f"[hermes-bridge-worker:{self.profile}] {text}", file=sys.stderr, flush=True)
print(f"[hermes-bridge-worker:{self.key}] {text}", file=sys.stderr, flush=True)
else:
lines.put(line)
finally:
lines.put(None)
threading.Thread(target=read_stdout, daemon=True, name=f"hermes-bridge-worker-stdout-{self.profile}").start()
threading.Thread(target=read_stdout, daemon=True, name=f"hermes-bridge-worker-stdout-{self.key}").start()
deadline = time.time() + self.STARTUP_TIMEOUT_SECONDS
while time.time() < deadline:
if proc.poll() is not None:
raise RuntimeError(f"profile worker {self.profile} exited before ready")
raise RuntimeError(f"profile worker {self.key} exited before ready")
try:
line = lines.get(timeout=0.1)
except queue.Empty:
@@ -2301,7 +2302,7 @@ class WorkerProcess:
continue
text = line.strip()
if text:
print(f"[hermes-bridge-worker:{self.profile}] {text}", file=sys.stderr, flush=True)
print(f"[hermes-bridge-worker:{self.key}] {text}", file=sys.stderr, flush=True)
try:
data = json.loads(text)
if data.get("event") == "ready":
@@ -2310,7 +2311,7 @@ class WorkerProcess:
except Exception:
pass
self.stop()
raise RuntimeError(f"profile worker {self.profile} did not become ready within {self.STARTUP_TIMEOUT_SECONDS}s")
raise RuntimeError(f"profile worker {self.key} did not become ready within {self.STARTUP_TIMEOUT_SECONDS}s")
def stop(self) -> None:
with self._lock:
@@ -2337,8 +2338,8 @@ class WorkerProcess:
return _send_bridge_request(self.endpoint, req, self.REQUEST_TIMEOUT_SECONDS)
def _worker_endpoint(profile: str) -> str:
safe = hashlib.sha256(profile.encode("utf-8")).hexdigest()[:16]
def _worker_endpoint(key: str) -> str:
safe = hashlib.sha256(key.encode("utf-8")).hexdigest()[:16]
if os.name == "nt":
port_base = int(os.environ.get("HERMES_AGENT_BRIDGE_WORKER_PORT_BASE", "18780"))
return f"tcp://127.0.0.1:{port_base + int(safe[:4], 16) % 1000}"
@@ -2533,11 +2534,17 @@ class BridgeBroker:
self.hermes_home = hermes_home
self._workers: dict[str, WorkerProcess] = {}
self._run_profile: dict[str, str] = {}
self._run_worker_key: dict[str, str] = {}
self._running_run_profile: dict[str, str] = {}
self._running_run_worker_key: dict[str, str] = {}
self._session_profile: dict[str, str] = {}
self._session_worker_key: dict[str, str] = {}
self._approval_profile: dict[str, str] = {}
self._approval_worker_key: dict[str, str] = {}
self._clarify_profile: dict[str, str] = {}
self._clarify_worker_key: dict[str, str] = {}
self._compression_profile: dict[str, str] = {}
self._compression_worker_key: dict[str, str] = {}
self._lock = threading.RLock()
self._stop = threading.Event()
self._last_gc = time.time()
@@ -2546,58 +2553,73 @@ class BridgeBroker:
profile = str(value or "").strip()
return profile or "default"
def _worker_for_profile(self, profile: str) -> WorkerProcess:
def _normalize_worker_key(self, profile: str, value: Any = None) -> str:
worker_key = str(value or "").strip()
return worker_key or profile
def _worker_for_profile(self, profile: str, worker_key: str | None = None) -> WorkerProcess:
profile = self._normalize_profile(profile)
key = self._normalize_worker_key(profile, worker_key)
with self._lock:
worker = self._workers.get(profile)
worker = self._workers.get(key)
if worker is None:
worker = WorkerProcess(profile, _worker_endpoint(profile), self.agent_root, self.hermes_home)
self._workers[profile] = worker
worker = WorkerProcess(key, profile, _worker_endpoint(key), self.agent_root, self.hermes_home)
self._workers[key] = worker
return worker
def _profile_for_run(self, run_id: str) -> str:
def _route_for_run(self, run_id: str) -> tuple[str, str | None]:
with self._lock:
profile = self._run_profile.get(run_id)
worker_key = self._run_worker_key.get(run_id)
if not profile:
raise KeyError(f"unknown run: {run_id}")
return profile
return profile, worker_key
def _profile_for_session(self, session_id: str, fallback_profile: Any = None) -> str:
def _route_for_session(self, session_id: str, fallback_profile: Any = None, worker_key: Any = None) -> tuple[str, str | None]:
with self._lock:
profile = self._session_profile.get(session_id)
stored_worker_key = self._session_worker_key.get(session_id)
if not profile:
fallback = self._normalize_profile(fallback_profile)
if fallback_profile is not None and fallback:
return fallback
return fallback, self._normalize_worker_key(fallback, worker_key)
raise KeyError(f"unknown session: {session_id}")
return profile
return profile, self._normalize_worker_key(profile, worker_key) if worker_key is not None else stored_worker_key
def _record_response_routes(self, profile: str, resp: dict[str, Any]) -> None:
def _record_response_routes(self, profile: str, worker_key: str, resp: dict[str, Any]) -> None:
run_id = str(resp.get("run_id") or "")
session_id = str(resp.get("session_id") or "")
with self._lock:
if run_id:
self._run_profile[run_id] = profile
self._run_worker_key[run_id] = worker_key
if resp.get("status") == "running":
self._running_run_profile[run_id] = profile
self._running_run_worker_key[run_id] = worker_key
else:
self._running_run_profile.pop(run_id, None)
self._running_run_worker_key.pop(run_id, None)
if session_id:
self._session_profile[session_id] = profile
self._session_worker_key[session_id] = worker_key
for event in resp.get("events") or []:
if not isinstance(event, dict):
continue
approval_id = str(event.get("approval_id") or "")
if approval_id:
self._approval_profile[approval_id] = profile
self._approval_worker_key[approval_id] = worker_key
clarify_id = str(event.get("clarify_id") or "")
if clarify_id:
self._clarify_profile[clarify_id] = profile
self._clarify_worker_key[clarify_id] = worker_key
request_id = str(event.get("request_id") or "")
if event.get("event") == "bridge.compression.requested" and request_id:
self._compression_profile[request_id] = profile
self._compression_worker_key[request_id] = worker_key
if event.get("event") in {"bridge.compression.completed", "bridge.compression.failed"} and request_id:
self._compression_profile.pop(request_id, None)
self._compression_worker_key.pop(request_id, None)
def stop(self) -> None:
self._stop.set()
@@ -2605,20 +2627,29 @@ class BridgeBroker:
workers = list(self._workers.values())
self._workers.clear()
self._run_profile.clear()
self._run_worker_key.clear()
self._running_run_profile.clear()
self._running_run_worker_key.clear()
self._session_profile.clear()
self._session_worker_key.clear()
self._approval_profile.clear()
self._approval_worker_key.clear()
self._clarify_profile.clear()
self._clarify_worker_key.clear()
self._compression_profile.clear()
self._compression_worker_key.clear()
for worker in workers:
worker.stop()
def _forward(self, profile: str, req: dict[str, Any]) -> dict[str, Any]:
worker = self._worker_for_profile(profile)
def _forward(self, profile: str, req: dict[str, Any], worker_key: str | None = None) -> dict[str, Any]:
profile = self._normalize_profile(profile)
key = self._normalize_worker_key(profile, worker_key)
worker = self._worker_for_profile(profile, key)
forwarded = dict(req)
forwarded["profile"] = profile
forwarded.pop("worker_key", None)
resp = worker.request(forwarded)
self._record_response_routes(profile, resp)
self._record_response_routes(profile, key, resp)
return resp
def handle(self, req: dict[str, Any]) -> dict[str, Any]:
@@ -2629,15 +2660,16 @@ class BridgeBroker:
if action == "ping":
with self._lock:
worker_details = {
profile: {
key: {
"running": worker.running,
"pid": worker.pid,
"endpoint": worker.endpoint,
"profile": getattr(worker, "profile", key),
"last_used_at": worker.last_used_at,
}
for profile, worker in self._workers.items()
for key, worker in self._workers.items()
}
workers = {profile: details["running"] for profile, details in worker_details.items()}
workers = {key: details["running"] for key, details in worker_details.items()}
sessions_by_profile: dict[str, int] = {}
for profile in self._session_profile.values():
sessions_by_profile[profile] = sessions_by_profile.get(profile, 0) + 1
@@ -2664,29 +2696,32 @@ class BridgeBroker:
if action == "worker_ping":
profile = self._normalize_profile(req.get("profile"))
resp = self._forward(profile, {"action": "ping"})
worker_key = self._normalize_worker_key(profile, req.get("worker_key"))
resp = self._forward(profile, {"action": "ping"}, worker_key)
resp["worker_profile"] = profile
resp["worker_key"] = worker_key
return resp
if action == "chat":
profile = self._normalize_profile(req.get("profile"))
return self._forward(profile, req)
return self._forward(profile, req, self._normalize_worker_key(profile, req.get("worker_key")))
if action == "context_estimate":
profile = self._normalize_profile(req.get("profile"))
return self._forward(profile, req)
return self._forward(profile, req, self._normalize_worker_key(profile, req.get("worker_key")))
if action in {"get_result", "get_output"}:
profile = self._profile_for_run(str(req.get("run_id") or ""))
return self._forward(profile, req)
profile, worker_key = self._route_for_run(str(req.get("run_id") or ""))
return self._forward(profile, req, worker_key)
if action in {"interrupt", "steer", "command", "goal_evaluate", "goal_pause", "status", "get_history", "destroy"}:
session_id = str(req.get("session_id") or "")
profile = self._profile_for_session(session_id, req.get("profile"))
resp = self._forward(profile, req)
profile, worker_key = self._route_for_session(session_id, req.get("profile"), req.get("worker_key") if "worker_key" in req else None)
resp = self._forward(profile, req, worker_key)
if action == "destroy":
with self._lock:
self._session_profile.pop(session_id, None)
self._session_worker_key.pop(session_id, None)
return resp
if action == "approval_respond":
@@ -2695,9 +2730,10 @@ class BridgeBroker:
raise ValueError("approval_id is required")
with self._lock:
profile = self._approval_profile.get(approval_id)
worker_key = self._approval_worker_key.get(approval_id)
if not profile:
raise KeyError(f"unknown approval request: {approval_id}")
return self._forward(profile, req)
return self._forward(profile, req, worker_key)
if action == "clarify_respond":
clarify_id = str(req.get("clarify_id") or "").strip()
@@ -2705,9 +2741,10 @@ class BridgeBroker:
raise ValueError("clarify_id is required")
with self._lock:
profile = self._clarify_profile.get(clarify_id)
worker_key = self._clarify_worker_key.get(clarify_id)
if not profile:
raise KeyError(f"unknown clarify request: {clarify_id}")
return self._forward(profile, req)
return self._forward(profile, req, worker_key)
if action == "compression_respond":
request_id = str(req.get("request_id") or "").strip()
@@ -2715,20 +2752,27 @@ class BridgeBroker:
raise ValueError("request_id is required")
with self._lock:
profile = self._compression_profile.get(request_id)
worker_key = self._compression_worker_key.get(request_id)
if not profile:
raise KeyError(f"unknown compression request: {request_id}")
return self._forward(profile, req)
return self._forward(profile, req, worker_key)
if action == "destroy_all":
with self._lock:
workers = list(self._workers.values())
self._workers.clear()
self._run_profile.clear()
self._run_worker_key.clear()
self._running_run_profile.clear()
self._running_run_worker_key.clear()
self._session_profile.clear()
self._session_worker_key.clear()
self._approval_profile.clear()
self._approval_worker_key.clear()
self._clarify_profile.clear()
self._clarify_worker_key.clear()
self._compression_profile.clear()
self._compression_worker_key.clear()
destroyed = 0
for worker in workers:
try:
@@ -2744,40 +2788,56 @@ class BridgeBroker:
if action == "destroy_profile":
profile = self._normalize_profile(req.get("profile"))
with self._lock:
worker = self._workers.pop(profile, None)
workers = [
worker
for key, worker in list(self._workers.items())
if getattr(worker, "profile", key) == profile
]
for worker in workers:
self._workers.pop(worker.key, None)
self._run_profile = {key: value for key, value in self._run_profile.items() if value != profile}
self._run_worker_key = {key: value for key, value in self._run_worker_key.items() if key in self._run_profile}
self._running_run_profile = {key: value for key, value in self._running_run_profile.items() if value != profile}
self._running_run_worker_key = {key: value for key, value in self._running_run_worker_key.items() if key in self._running_run_profile}
self._session_profile = {key: value for key, value in self._session_profile.items() if value != profile}
self._session_worker_key = {key: value for key, value in self._session_worker_key.items() if key in self._session_profile}
self._approval_profile = {key: value for key, value in self._approval_profile.items() if value != profile}
self._approval_worker_key = {key: value for key, value in self._approval_worker_key.items() if key in self._approval_profile}
self._clarify_profile = {key: value for key, value in self._clarify_profile.items() if value != profile}
self._clarify_worker_key = {key: value for key, value in self._clarify_worker_key.items() if key in self._clarify_profile}
self._compression_profile = {key: value for key, value in self._compression_profile.items() if value != profile}
self._compression_worker_key = {key: value for key, value in self._compression_worker_key.items() if key in self._compression_profile}
if worker is None or not worker.running:
if worker is not None:
worker.stop()
if not workers:
return {"profile": profile, "destroyed": 0}
try:
resp = worker.request({"action": "destroy_all"})
destroyed = int(resp.get("destroyed") or 0)
except Exception:
destroyed = 0
finally:
worker.stop()
destroyed = 0
for worker in workers:
if not worker.running:
worker.stop()
continue
try:
resp = worker.request({"action": "destroy_all"})
destroyed += int(resp.get("destroyed") or 0)
except Exception:
pass
finally:
worker.stop()
return {"profile": profile, "destroyed": destroyed}
if action == "list":
sessions: list[Any] = []
with self._lock:
workers = list(self._workers.items())
for profile, worker in workers:
for key, worker in workers:
if not worker.running:
continue
try:
resp = worker.request({"action": "list"})
for session in resp.get("sessions") or []:
if isinstance(session, dict):
session.setdefault("profile", profile)
session.setdefault("profile", getattr(worker, "profile", key))
session.setdefault("worker_key", key)
sessions.append(session)
except Exception:
pass
@@ -2826,12 +2886,12 @@ class BridgeBroker:
self._last_gc = now
with self._lock:
idle = [
profile for profile, worker in self._workers.items()
key for key, worker in self._workers.items()
if worker.running and now - worker.last_used_at > self.IDLE_TIMEOUT_SECONDS
]
for profile in idle:
for key in idle:
with self._lock:
worker = self._workers.pop(profile, None)
worker = self._workers.pop(key, None)
if worker:
worker.stop()
@@ -195,7 +195,8 @@ export async function buildCompressedHistory(
emit: (event: string, payload: any) => void,
sessionMap: Map<string, SessionState>,
modelContext: { model?: string | null; provider?: string | null } = {},
contextTokenEstimator?: (messages: ChatMessage[]) => Promise<number | null | undefined>,
contextTokenEstimator?: (messages: ChatMessage[], messageTokens: number) => Promise<number | null | undefined>,
currentInputTokens = 0,
): Promise<ChatMessage[]> {
try {
let history = await buildDbHistory(sessionId, { excludeLastUser: true })
@@ -213,14 +214,18 @@ export async function buildCompressedHistory(
}
const cState = getOrCreateSession(sessionMap, sessionId)
const assembledTokens = await calcAndUpdateUsage(sessionId, cState, emit)
const estimateFullContextTokens = async (messages: ChatMessage[], fallback: number) => {
const currentRunInputTokens = typeof currentInputTokens === 'number' && Number.isFinite(currentInputTokens) && currentInputTokens > 0
? Math.floor(currentInputTokens)
: 0
const estimateLocalContextTokens = async (messages: ChatMessage[], messageTokens: number) => {
const localMessageTokens = Math.max(0, Math.floor(messageTokens))
try {
const estimate = await contextTokenEstimator?.(messages)
const estimate = await contextTokenEstimator?.(messages, localMessageTokens)
if (typeof estimate === 'number' && Number.isFinite(estimate) && estimate > 0) return Math.floor(estimate)
} catch (err) {
logger.warn(err, '[context-compress] session=%s: full context token estimate failed; using message-only estimate', sessionId)
logger.warn(err, '[context-compress] session=%s: fixed context token estimate failed; using message-only estimate', sessionId)
}
return fallback
return localMessageTokens
}
const emitContextUsage = (contextTokens: number) => {
cState.contextTokens = contextTokens
@@ -236,10 +241,10 @@ export async function buildCompressedHistory(
let totalTokens = messageOnlyTotalTokens
if (history.length === 0) {
totalTokens = await estimateFullContextTokens([], 0)
totalTokens = await estimateLocalContextTokens([], Math.max(currentRunInputTokens, messageOnlyTotalTokens))
if (totalTokens > triggerTokens) {
throw new ContextWindowTooSmallError(
`Context window is too small: system prompt and tool schemas already use ~${totalTokens} tokens, exceeding compression threshold ${triggerTokens}. Increase model context length, raise compression.threshold, or disable some tools.`,
`Context window is too small: fixed prompt/tool overhead plus the current input uses ~${totalTokens} tokens, exceeding compression threshold ${triggerTokens}. Increase model context length, raise compression.threshold, shorten the input, or disable some tools.`,
)
}
if (totalTokens > 0) emitContextUsage(totalTokens)
@@ -254,13 +259,15 @@ export async function buildCompressedHistory(
sessionId, snapshot.lastMessageIndex, history.length)
const staleHistory = buildSnapshotHistory(snapshot, history, compressionConfig.compressor) || history
const staleUsage = estimateUsageTokensFromMessages(staleHistory)
totalTokens = await estimateFullContextTokens(staleHistory, staleUsage.inputTokens + staleUsage.outputTokens)
const staleMessageTokens = staleUsage.inputTokens + staleUsage.outputTokens
const staleRunMessageTokens = Math.max(staleMessageTokens + currentRunInputTokens, messageOnlyTotalTokens)
totalTokens = await estimateLocalContextTokens(staleHistory, staleRunMessageTokens)
emitContextUsage(totalTokens)
logger.info({
sessionId,
profile,
messages: staleHistory.length,
messageOnlyTokens: staleUsage.inputTokens + staleUsage.outputTokens,
messageOnlyTokens: staleRunMessageTokens,
fullContextTokens: totalTokens,
triggerTokens,
decision: totalTokens > triggerTokens ? 'compress' : 'skip',
@@ -272,13 +279,15 @@ export async function buildCompressedHistory(
const newMessages = history.slice(snapshot.lastMessageIndex + 1)
const snapshotHistory = buildSnapshotHistory(snapshot, history, compressionConfig.compressor) || history
const snapshotUsage = estimateUsageTokensFromMessages(snapshotHistory)
totalTokens = await estimateFullContextTokens(snapshotHistory, snapshotUsage.inputTokens + snapshotUsage.outputTokens)
const snapshotMessageTokens = snapshotUsage.inputTokens + snapshotUsage.outputTokens
const snapshotRunMessageTokens = Math.max(snapshotMessageTokens + currentRunInputTokens, messageOnlyTotalTokens)
totalTokens = await estimateLocalContextTokens(snapshotHistory, snapshotRunMessageTokens)
emitContextUsage(totalTokens)
logger.info({
sessionId,
profile,
messages: snapshotHistory.length,
messageOnlyTokens: snapshotUsage.inputTokens + snapshotUsage.outputTokens,
messageOnlyTokens: snapshotRunMessageTokens,
fullContextTokens: totalTokens,
triggerTokens,
decision: totalTokens > triggerTokens ? 'compress' : 'skip',
@@ -289,22 +298,25 @@ export async function buildCompressedHistory(
if (totalTokens <= triggerTokens) {
history = snapshotHistory
} else {
history = await compressHistory(history, newMessages, sessionId, upstream, apiKey, cState, totalTokens, emit, sessionMap, modelContext, compressionConfig.compressor)
history = await compressHistory(history, newMessages, sessionId, upstream, apiKey, cState, totalTokens, emit, sessionMap, modelContext, compressionConfig.compressor, currentRunInputTokens)
}
} else if (snapshot && staleSnapshot) {
if (totalTokens <= triggerTokens) {
history = buildSnapshotHistory(snapshot, history, compressionConfig.compressor) || history
} else {
history = await compressHistory(history, null, sessionId, upstream, apiKey, cState, totalTokens, emit, sessionMap, modelContext, compressionConfig.compressor)
history = await compressHistory(history, null, sessionId, upstream, apiKey, cState, totalTokens, emit, sessionMap, modelContext, compressionConfig.compressor, currentRunInputTokens)
}
} else {
totalTokens = await estimateFullContextTokens(history, totalTokens)
const historyUsage = estimateUsageTokensFromMessages(history)
const historyMessageTokens = historyUsage.inputTokens + historyUsage.outputTokens
const runMessageTokens = Math.max(historyMessageTokens + currentRunInputTokens, messageOnlyTotalTokens)
totalTokens = await estimateLocalContextTokens(history, runMessageTokens)
emitContextUsage(totalTokens)
logger.info({
sessionId,
profile,
messages: history.length,
messageOnlyTokens: messageOnlyTotalTokens,
messageOnlyTokens: runMessageTokens,
fullContextTokens: totalTokens,
triggerTokens,
decision: totalTokens > triggerTokens ? 'compress' : 'skip',
@@ -318,7 +330,7 @@ export async function buildCompressedHistory(
if (totalTokens <= triggerTokens) {
logger.info('[context-compress] session=%s: %d messages, ~%d tokens — under threshold, skip', sessionId, history.length, totalTokens)
} else {
history = await compressHistory(history, null, sessionId, upstream, apiKey, cState, totalTokens, emit, sessionMap, modelContext, compressionConfig.compressor)
history = await compressHistory(history, null, sessionId, upstream, apiKey, cState, totalTokens, emit, sessionMap, modelContext, compressionConfig.compressor, currentRunInputTokens)
}
}
@@ -342,8 +354,12 @@ export async function compressHistory(
sessionMap: Map<string, SessionState>,
modelContext: { model?: string | null; provider?: string | null } = {},
compressionConfig?: Partial<CompressorConfig>,
currentInputTokens = 0,
): Promise<ChatMessage[]> {
const msgCount = newMessagesOnly ? newMessagesOnly.length : history.length
const currentRunInputTokens = typeof currentInputTokens === 'number' && Number.isFinite(currentInputTokens) && currentInputTokens > 0
? Math.floor(currentInputTokens)
: 0
pushState(sessionMap, sessionId, 'compression.started', {
event: 'compression.started', message_count: msgCount, token_count: totalTokens,
})
@@ -353,14 +369,22 @@ export async function compressHistory(
try {
const session = getSession(sessionId)
const summarizerProfile = session?.profile || 'default'
const compressor = new ChatContextCompressor({ config: compressionConfig })
const result = await compressor.compress(history, upstream, apiKey, sessionId, {
profile: session?.profile,
profile: summarizerProfile,
model: modelContext.model || session?.model,
provider: modelContext.provider || session?.provider,
workerKey: `${summarizerProfile}:compression:${sessionId}`,
})
const afterTokens = await calcAndUpdateUsage(sessionId, cState, emit)
const compressedAfterTokens = afterTokens.inputTokens + afterTokens.outputTokens
const resultUsage = estimateUsageTokensFromMessages(result.messages)
const resultMessageTokens = resultUsage.inputTokens + resultUsage.outputTokens
const compressedRunMessageTokens = Math.max(
compressedAfterTokens,
resultMessageTokens + currentRunInputTokens,
)
const compressedMeta: any = {
event: 'compression.completed' as const,
compressed: result.meta.compressed,
@@ -368,15 +392,15 @@ export async function compressHistory(
totalMessages: result.meta.totalMessages,
resultMessages: result.messages.length,
beforeTokens: totalTokens,
afterTokens: compressedAfterTokens,
afterTokens: compressedRunMessageTokens,
summaryTokens: result.meta.summaryTokenEstimate,
verbatimCount: result.meta.verbatimCount,
compressedStartIndex: result.meta.compressedStartIndex,
}
replaceState(sessionMap, sessionId, 'compression.completed', compressedMeta)
logger.info('[context-compress] AFTER session=%s: %d messages, ~%d tokens (was %d)',
sessionId, result.messages.length, compressedAfterTokens, totalTokens)
const compressedContextTokens = updateMessageContextTokenUsage(sessionId, cState, emit, compressedAfterTokens, afterTokens)
sessionId, result.messages.length, compressedRunMessageTokens, totalTokens)
const compressedContextTokens = updateMessageContextTokenUsage(sessionId, cState, emit, compressedRunMessageTokens, afterTokens)
if (compressedContextTokens != null) {
compressedMeta.contextTokens = compressedContextTokens
}
@@ -403,6 +427,7 @@ export async function compressHistory(
resultMessages: msgCount,
beforeTokens: totalTokens,
afterTokens: totalTokens,
contextTokens: totalTokens,
summaryTokens: 0,
verbatimCount: msgCount,
compressedStartIndex: -1,
@@ -458,10 +483,12 @@ export async function forceCompressBridgeHistory(
}, '[chat-run-socket] bridge forced compression started')
const compressor = new ChatContextCompressor({ config: compressionConfig.compressor })
const summarizerProfile = session?.profile || profile || 'default'
const result = await compressor.compress(history, upstream, apiKey, sessionId, {
profile: session?.profile || profile,
profile: summarizerProfile,
model: session?.model,
provider: session?.provider,
workerKey: `${summarizerProfile}:compression:${sessionId}`,
})
const compressedMessages = result.messages.map(m => {
const msg: any = { role: m.role, content: m.content }
@@ -18,7 +18,7 @@ import { convertHistoryFormat } from './message-format'
import { readSseFrames } from './sse-utils'
import { extractResponseText } from './response-utils'
import { applyResponseStreamEvent, flushResponseRunToDb } from './response-stream'
import { buildCompressedHistory, getOrCreateSession } from './compression'
import { buildCompressedHistory, buildDbHistory, buildSnapshotAwareHistory, getOrCreateSession } from './compression'
import { calcAndUpdateUsage, estimateUsageTokensFromMessages } from './usage'
import { handleMessage } from './message-format'
import { countTokens, SUMMARY_PREFIX } from '../../../lib/context-compressor'
@@ -37,6 +37,7 @@ export async function loadSessionStateFromDb(sid: string, _sessionMap: Map<strin
let inputTokens: number
let outputTokens: number
let contextTokens: number | undefined
const snapshot = getCompressionSnapshot(sid)
if (snapshot && snapshot.lastMessageIndex >= 0 && snapshot.lastMessageIndex < messages.length) {
const newMessages = messages.slice(snapshot.lastMessageIndex + 1)
@@ -49,6 +50,20 @@ export async function loadSessionStateFromDb(sid: string, _sessionMap: Map<strin
inputTokens = usage.inputTokens
outputTokens = usage.outputTokens
}
try {
const session = getSession(sid)
const dbHistory = await buildDbHistory(sid, { excludeLastUser: false })
const snapshotHistory = await buildSnapshotAwareHistory(
sid,
session?.profile || 'default',
dbHistory,
{ model: session?.model, provider: session?.provider },
)
const contextUsage = estimateUsageTokensFromMessages(snapshotHistory)
contextTokens = contextUsage.inputTokens + contextUsage.outputTokens
} catch (err) {
logger.warn(err, '[chat-run-socket] failed to calculate snapshot-aware context tokens for session %s', sid)
}
logger.info('[chat-run-socket] loaded session %s from DB (%d messages)', sid, messages.length)
return {
@@ -57,6 +72,7 @@ export async function loadSessionStateFromDb(sid: string, _sessionMap: Map<strin
events: [],
inputTokens,
outputTokens,
contextTokens,
queue: [],
}
} catch (err) {
@@ -16,7 +16,6 @@ import {
contextTokensWithCachedOverhead,
estimateUsageTokensFromMessages,
getCachedBridgeContextOverhead,
updateContextTokenUsage,
updateMessageContextTokenUsage,
} from './usage'
import {
@@ -123,6 +122,66 @@ function cacheBridgeContext(state: SessionState, data: Record<string, unknown> |
}
}
function bridgeContextMatches(
state: SessionState,
expected: { profile: string; model?: string | null; provider?: string | null },
): boolean {
const context = state.bridgeContext
if (!context) return false
if (context.profile && context.profile !== expected.profile) return false
if (expected.model && context.model && context.model !== expected.model) return false
if (expected.provider && context.provider && context.provider !== expected.provider) return false
return true
}
async function ensureBridgeFixedContext(args: {
sessionId: string
profile: string
model?: string | null
provider?: string | null
instructions: string
state: SessionState
bridge: AgentBridgeClient
refresh?: boolean
}): Promise<number | undefined> {
const cached = bridgeContextMatches(args.state, args)
? getCachedBridgeContextOverhead(args.state)
: undefined
if (!args.refresh && cached != null) return cached
try {
const estimate = await args.bridge.contextEstimate(
args.sessionId,
[],
args.instructions,
args.profile,
{ model: args.model ?? undefined, provider: args.provider ?? undefined },
)
cacheBridgeContext(args.state, estimate)
const fixedContextTokens = getCachedBridgeContextOverhead(args.state)
bridgeLogger.info({
sessionId: args.sessionId,
profile: args.profile,
model: args.model,
provider: args.provider,
toolCount: estimate.tool_count,
systemPromptChars: estimate.system_prompt_chars,
fixedContextTokens,
}, '[chat-run-socket] fixed context estimate')
return fixedContextTokens
} catch (err) {
bridgeLogger.warn({
err: err instanceof Error ? { message: err.message, name: err.name } : err,
sessionId: args.sessionId,
profile: args.profile,
model: args.model,
provider: args.provider,
cachedFixedContextTokens: cached,
}, '[chat-run-socket] fixed context estimate failed')
return cached
}
}
export async function handleBridgeRun(
nsp: ReturnType<Server['of']>,
socket: Socket,
@@ -195,6 +254,9 @@ export async function handleBridgeRun(
const displayInput = data.display_input === undefined ? input : data.display_input
const inputStr = displayInput == null ? '' : contentBlocksToString(displayInput)
const actualInputStr = contentBlocksToString(input)
const currentInputUsage = estimateUsageTokensFromMessages([{ role: 'user', content: actualInputStr }])
const currentInputTokens = currentInputUsage.inputTokens
const shouldPersistUserMessage = !skipUserMessage && displayInput !== null
const displayRole = data.display_role === 'command' ? 'command' : 'user'
let messageId: number | string | undefined
@@ -257,33 +319,32 @@ export async function handleBridgeRun(
emit,
sessionMap,
{ model: resolvedModel, provider: resolvedProvider },
async (messages) => {
const cachedOverhead = getCachedBridgeContextOverhead(state)
if (cachedOverhead != null) {
const messageUsage = estimateUsageTokensFromMessages(messages)
return cachedOverhead + messageUsage.inputTokens + messageUsage.outputTokens
}
const estimate = await bridge.contextEstimate(
session_id,
messages,
fullInstructions,
async (_messages, localMessageTokens) => {
const fixedContextTokens = await ensureBridgeFixedContext({
sessionId: session_id,
profile,
{ model: resolvedModel, provider: resolvedProvider },
)
cacheBridgeContext(state, estimate)
model: resolvedModel,
provider: resolvedProvider,
instructions: fullInstructions,
state,
bridge,
refresh: true,
})
const contextTokens = fixedContextTokens == null
? localMessageTokens
: fixedContextTokens + localMessageTokens
bridgeLogger.info({
sessionId: session_id,
profile,
model: resolvedModel,
provider: resolvedProvider,
messages: estimate.message_count,
toolCount: estimate.tool_count,
systemPromptChars: estimate.system_prompt_chars,
fixedContextTokens: estimate.fixed_context_tokens,
fullContextTokens: estimate.token_count,
}, '[chat-run-socket] full context estimate')
return estimate.token_count
fixedContextTokens,
messageTokens: localMessageTokens,
contextTokens,
}, '[chat-run-socket] local context estimate')
return contextTokens
},
currentInputTokens,
)
const bridgeHistory = history
@@ -349,6 +410,8 @@ export async function handleBridgeRun(
dequeueNextQueuedRun,
fullInstructions,
{ model: resolvedModel, provider: resolvedProvider },
currentInputTokens,
shouldPersistUserMessage && displayRole === 'user',
data.model_groups,
)
if (chunk.done) break
@@ -417,61 +480,68 @@ async function refreshFinalContextUsage(args: {
)
const finalMessageUsage = estimateUsageTokensFromMessages(finalHistory)
const finalMessageTokens = finalMessageUsage.inputTokens + finalMessageUsage.outputTokens
if (getCachedBridgeContextOverhead(args.state) != null) {
const contextTokens = updateMessageContextTokenUsage(
args.sessionId,
args.state,
args.emit,
finalMessageTokens,
args.usage,
)
bridgeLogger.info({
sessionId: args.sessionId,
profile: args.profile,
model: args.model,
provider: args.provider,
messages: finalHistory.length,
fixedContextTokens: args.state.bridgeContext?.fixedContextTokens,
messageTokens: finalMessageTokens,
fullContextTokens: contextTokens,
}, '[chat-run-socket] final cached context estimate')
return contextTokens
}
const estimate = await args.bridge.contextEstimate(
await ensureBridgeFixedContext({
sessionId: args.sessionId,
profile: args.profile,
model: args.model,
provider: args.provider,
instructions: args.instructions,
state: args.state,
bridge: args.bridge,
})
const contextTokens = updateMessageContextTokenUsage(
args.sessionId,
finalHistory,
args.instructions,
args.profile,
{ model: args.model ?? undefined, provider: args.provider ?? undefined },
args.state,
args.emit,
finalMessageTokens,
args.usage,
)
cacheBridgeContext(args.state, estimate)
const contextTokens = typeof estimate.token_count === 'number' && Number.isFinite(estimate.token_count) && estimate.token_count > 0
? Math.floor(estimate.token_count)
: undefined
if (contextTokens == null) return args.state.contextTokens
updateContextTokenUsage(args.sessionId, args.state, args.emit, contextTokens, args.usage)
bridgeLogger.info({
sessionId: args.sessionId,
profile: args.profile,
model: args.model,
provider: args.provider,
messages: estimate.message_count,
toolCount: estimate.tool_count,
systemPromptChars: estimate.system_prompt_chars,
fullContextTokens: contextTokens,
}, '[chat-run-socket] final full context estimate')
messages: finalHistory.length,
fixedContextTokens: args.state.bridgeContext?.fixedContextTokens,
messageTokens: finalMessageTokens,
contextTokens,
}, '[chat-run-socket] final local context estimate')
return contextTokens
} catch (err) {
bridgeLogger.warn({
err: err instanceof Error ? { message: err.message, name: err.name } : err,
sessionId: args.sessionId,
profile: args.profile,
}, '[chat-run-socket] final full context estimate failed')
}, '[chat-run-socket] final local context estimate failed')
return args.state.contextTokens
}
}
async function estimateSnapshotAwareMessageTokens(args: {
sessionId: string
profile: string
model?: string | null
provider?: string | null
currentInputTokens?: number
currentInputIncludedInDb?: boolean
}): Promise<{ messageTokens: number; messages: number }> {
const dbHistory = await buildDbHistory(args.sessionId, { excludeLastUser: false })
const snapshotHistory = await buildSnapshotAwareHistory(
args.sessionId,
args.profile,
dbHistory,
{ model: args.model, provider: args.provider },
)
const usage = estimateUsageTokensFromMessages(snapshotHistory)
const extraInputTokens = args.currentInputIncludedInDb
? 0
: finiteToken(args.currentInputTokens) ?? 0
return {
messageTokens: usage.inputTokens + usage.outputTokens + extraInputTokens,
messages: snapshotHistory.length,
}
}
async function applyBridgeChunkAsync(
nsp: ReturnType<Server['of']>,
socket: Socket,
@@ -486,6 +556,8 @@ async function applyBridgeChunkAsync(
dequeueNextQueuedRun: (socket: Socket, sessionId: string, fallbackProfile?: string) => void,
instructions: string,
modelContext: { model?: string | null; provider?: string | null },
currentInputTokens = 0,
currentInputIncludedInDb = true,
modelGroups?: RunModelGroup[],
): Promise<void> {
if (state.activeRunMarker !== runMarker) {
@@ -505,11 +577,19 @@ async function applyBridgeChunkAsync(
if (evType === 'bridge.context.ready') {
cacheBridgeContext(state, ev)
const usage = await calcAndUpdateUsage(sessionId, state, emit)
const snapshotAware = await estimateSnapshotAwareMessageTokens({
sessionId,
profile,
model: modelContext.model,
provider: modelContext.provider,
currentInputTokens,
currentInputIncludedInDb,
})
updateMessageContextTokenUsage(
sessionId,
state,
emit,
usage.inputTokens + usage.outputTokens,
snapshotAware.messageTokens,
usage,
)
} else if (evType === 'tool.started') {
@@ -646,17 +726,22 @@ async function applyBridgeChunkAsync(
const bridgeHistory = await buildDbHistory(sessionId, { excludeLastUser: true })
const bridgeUsage = estimateUsageTokensFromMessages(bridgeHistory)
const messageOnlyTokens = bridgeUsage.inputTokens + bridgeUsage.outputTokens
const tokenCount = typeof ev.approx_tokens === 'number' && Number.isFinite(ev.approx_tokens) && ev.approx_tokens > 0
? ev.approx_tokens
: messageOnlyTokens
const runInputTokens = typeof currentInputTokens === 'number' && Number.isFinite(currentInputTokens) && currentInputTokens > 0
? Math.floor(currentInputTokens)
: 0
const runMessageTokens = messageOnlyTokens + runInputTokens
const tokenCount = contextTokensWithCachedOverhead(state, runMessageTokens)
bridgeLogger.info({
sessionId,
profile,
bridgeMessages: ev.message_count,
dbMessages: bridgeHistory.length,
messageOnlyTokens,
fullContextTokens: tokenCount,
source: typeof ev.approx_tokens === 'number' ? 'bridge' : 'message-only-fallback',
currentInputTokens: runInputTokens,
fixedContextTokens: state.bridgeContext?.fixedContextTokens,
contextTokens: tokenCount,
bridgeApproxTokens: ev.approx_tokens,
source: 'local',
}, '[chat-run-socket] bridge compression token estimate')
const payload = {
event: 'compression.started',
@@ -674,7 +759,7 @@ async function applyBridgeChunkAsync(
sessionId,
profile,
ev.messages as ChatMessage[],
typeof ev.approx_tokens === 'number' ? ev.approx_tokens : undefined,
tokenCount,
)
state.bridgeCompressionResults = state.bridgeCompressionResults || {}
state.bridgeCompressionResults[String(ev.request_id)] = compressed
@@ -689,11 +774,16 @@ async function applyBridgeChunkAsync(
const compressionResult = ev.request_id
? state.bridgeCompressionResults?.[String(ev.request_id)]
: undefined
const bridgeAfterContextTokens = finiteToken(ev.result_approx_tokens)
const messageAfterTokens = finiteToken(compressionResult?.afterTokens)
const afterContextTokens = messageAfterTokens != null && getCachedBridgeContextOverhead(state) != null
? contextTokensWithCachedOverhead(state, messageAfterTokens)
: bridgeAfterContextTokens ?? messageAfterTokens
const runInputTokens = typeof currentInputTokens === 'number' && Number.isFinite(currentInputTokens) && currentInputTokens > 0
? Math.floor(currentInputTokens)
: 0
const messageAfterTokensWithInput = messageAfterTokens != null
? messageAfterTokens + runInputTokens
: undefined
const afterContextTokens = messageAfterTokensWithInput != null
? contextTokensWithCachedOverhead(state, messageAfterTokensWithInput)
: undefined
const payload = {
event: 'compression.completed',
run_id: chunk.run_id,
@@ -703,7 +793,7 @@ async function applyBridgeChunkAsync(
totalMessages: compressionResult?.beforeMessages ?? ev.message_count,
resultMessages: compressionResult?.resultMessages ?? ev.result_messages,
beforeTokens: compressionResult?.beforeTokens ?? ev.approx_tokens,
afterTokens: messageAfterTokens ?? bridgeAfterContextTokens,
afterTokens: messageAfterTokensWithInput,
contextTokens: afterContextTokens,
summaryTokens: compressionResult?.summaryTokens,
verbatimCount: compressionResult?.verbatimCount,
@@ -716,10 +806,8 @@ async function applyBridgeChunkAsync(
replaceState(sessionMap, sessionId, 'compression.completed', payload)
emit('compression.completed', payload)
const usage = await calcAndUpdateUsage(sessionId, state, emit)
if (messageAfterTokens != null && getCachedBridgeContextOverhead(state) != null) {
updateMessageContextTokenUsage(sessionId, state, emit, messageAfterTokens, usage)
} else {
updateContextTokenUsage(sessionId, state, emit, afterContextTokens, usage)
if (messageAfterTokensWithInput != null) {
updateMessageContextTokenUsage(sessionId, state, emit, messageAfterTokensWithInput, usage)
}
} else if (evType === 'bridge.compression.failed') {
const payload = {