fix context token resume (#1039)

This commit is contained in:
ekko
2026-05-26 16:32:07 +08:00
committed by GitHub
parent e686f0277a
commit ad1cab277a
13 changed files with 959 additions and 203 deletions
+52 -29
View File
@@ -85,6 +85,15 @@ export interface Session {
workspace?: string | null workspace?: string | null
} }
interface CompressionState {
compressing: boolean
messageCount: number
beforeTokens: number
afterTokens: number
compressed: boolean | null
error?: string
}
function uid(): string { function uid(): string {
return Date.now().toString(36) + Math.random().toString(36).slice(2, 8) return Date.now().toString(36) + Math.random().toString(36).slice(2, 8)
} }
@@ -272,6 +281,8 @@ function mapHermesSession(s: SessionSummary): Session {
model: s.model, model: s.model,
provider: s.provider || (s as any).billing_provider || '', provider: s.provider || (s as any).billing_provider || '',
messageCount: s.message_count, messageCount: s.message_count,
inputTokens: s.input_tokens,
outputTokens: s.output_tokens,
endedAt: s.ended_at != null ? Math.round(s.ended_at * 1000) : null, endedAt: s.ended_at != null ? Math.round(s.ended_at * 1000) : null,
lastActiveAt: s.last_active != null ? Math.round(s.last_active * 1000) : undefined, lastActiveAt: s.last_active != null ? Math.round(s.last_active * 1000) : undefined,
workspace: s.workspace || null, workspace: s.workspace || null,
@@ -405,18 +416,20 @@ export const useChatStore = defineStore('chat', () => {
const isLoadingMessages = ref(false) const isLoadingMessages = ref(false)
const isRunActive = computed(() => isStreaming.value) const isRunActive = computed(() => isStreaming.value)
// Compression state // Compression state is scoped per session because sockets can stay joined to
const compressionState = ref<{ // background sessions while another chat is active.
compressing: boolean const compressionStates = ref<Map<string, CompressionState>>(new Map())
messageCount: number const compressionState = computed<CompressionState | null>(() => {
beforeTokens: number const sid = activeSessionId.value
afterTokens: number return sid ? compressionStates.value.get(sid) || null : null
compressed: boolean | null })
error?: string
} | null>(null)
function setCompressionState(state: typeof compressionState.value) { function setCompressionState(sessionId: string | null | undefined, state: CompressionState | null) {
compressionState.value = state if (!sessionId) return
const next = new Map(compressionStates.value)
if (state) next.set(sessionId, state)
else next.delete(sessionId)
compressionStates.value = next
} }
const abortState = ref<{ const abortState = ref<{
@@ -438,11 +451,12 @@ export const useChatStore = defineStore('chat', () => {
} }
function clearActiveSession() { function clearActiveSession() {
const sid = activeSessionId.value
activeSessionId.value = null activeSessionId.value = null
activeSession.value = null activeSession.value = null
focusMessageId.value = null focusMessageId.value = null
setAbortState(null) setAbortState(null)
setCompressionState(null) setCompressionState(sid, null)
removeItem(storageKey()) removeItem(storageKey())
} }
@@ -453,10 +467,14 @@ export const useChatStore = defineStore('chat', () => {
const fresh = list.map(mapHermesSession) const fresh = list.map(mapHermesSession)
// Preserve already-loaded messages for sessions that are still present, // Preserve already-loaded messages for sessions that are still present,
// so we don't blow away the active session's messages on refresh. // so we don't blow away the active session's messages on refresh.
const msgsByIdBefore = new Map(sessions.value.map(s => [s.id, s.messages])) const runtimeByIdBefore = new Map(sessions.value.map(s => [s.id, {
messages: s.messages,
contextTokens: s.contextTokens,
}]))
for (const s of fresh) { for (const s of fresh) {
const prev = msgsByIdBefore.get(s.id) const prev = runtimeByIdBefore.get(s.id)
if (prev && prev.length) s.messages = prev if (prev?.messages?.length) s.messages = prev.messages
if (prev?.contextTokens != null) s.contextTokens = prev.contextTokens
} }
sessions.value = fresh sessions.value = fresh
@@ -594,6 +612,7 @@ export const useChatStore = defineStore('chat', () => {
} else if (!data.isWorking) { } else if (!data.isWorking) {
setAbortState(null) setAbortState(null)
} }
if (!data.isWorking) setCompressionState(sessionId, null)
if (data.inputTokens != null) target.inputTokens = data.inputTokens if (data.inputTokens != null) target.inputTokens = data.inputTokens
if (data.outputTokens != null) target.outputTokens = data.outputTokens if (data.outputTokens != null) target.outputTokens = data.outputTokens
if ((data as any).contextTokens != null) target.contextTokens = (data as any).contextTokens if ((data as any).contextTokens != null) target.contextTokens = (data as any).contextTokens
@@ -613,7 +632,7 @@ export const useChatStore = defineStore('chat', () => {
for (const evt of data.events) { for (const evt of data.events) {
const e = evt.data as any const e = evt.data as any
if (e.event === 'compression.started') { if (e.event === 'compression.started') {
setCompressionState({ setCompressionState(sessionId, {
compressing: true, compressing: true,
messageCount: e.message_count || 0, messageCount: e.message_count || 0,
beforeTokens: e.token_count || 0, beforeTokens: e.token_count || 0,
@@ -622,7 +641,7 @@ export const useChatStore = defineStore('chat', () => {
}) })
} else if (e.event === 'compression.completed') { } else if (e.event === 'compression.completed') {
const afterTokens = e.contextTokens || e.afterTokens || 0 const afterTokens = e.contextTokens || e.afterTokens || 0
setCompressionState({ setCompressionState(sessionId, {
compressing: false, compressing: false,
messageCount: e.totalMessages || 0, messageCount: e.totalMessages || 0,
beforeTokens: e.beforeTokens || 0, beforeTokens: e.beforeTokens || 0,
@@ -1385,6 +1404,7 @@ export const useChatStore = defineStore('chat', () => {
} else if (!data.isWorking) { } else if (!data.isWorking) {
setAbortState(null) setAbortState(null)
} }
if (!data.isWorking) setCompressionState(sid, null)
if (data.inputTokens != null) target.inputTokens = data.inputTokens if (data.inputTokens != null) target.inputTokens = data.inputTokens
if (data.outputTokens != null) target.outputTokens = data.outputTokens if (data.outputTokens != null) target.outputTokens = data.outputTokens
@@ -1407,7 +1427,7 @@ export const useChatStore = defineStore('chat', () => {
const e = evt.data as RunEvent const e = evt.data as RunEvent
switch (e.event) { switch (e.event) {
case 'compression.started': case 'compression.started':
setCompressionState({ setCompressionState(sid, {
compressing: true, compressing: true,
messageCount: (e as any).message_count || 0, messageCount: (e as any).message_count || 0,
beforeTokens: (e as any).token_count || 0, beforeTokens: (e as any).token_count || 0,
@@ -1417,7 +1437,7 @@ export const useChatStore = defineStore('chat', () => {
break break
case 'compression.completed': { case 'compression.completed': {
const afterTokens = (e as any).contextTokens || (e as any).afterTokens || 0 const afterTokens = (e as any).contextTokens || (e as any).afterTokens || 0
setCompressionState({ setCompressionState(sid, {
compressing: false, compressing: false,
messageCount: (e as any).totalMessages || 0, messageCount: (e as any).totalMessages || 0,
beforeTokens: (e as any).beforeTokens || 0, beforeTokens: (e as any).beforeTokens || 0,
@@ -1474,7 +1494,7 @@ export const useChatStore = defineStore('chat', () => {
case 'run.started': case 'run.started':
clearAgentEventMessages(sid) clearAgentEventMessages(sid)
setAbortState(null) setAbortState(null)
setCompressionState(null) setCompressionState(sid, null)
runProducedAssistantText = false runProducedAssistantText = false
runHadToolActivity = false runHadToolActivity = false
closeStreamingAssistant() closeStreamingAssistant()
@@ -1502,7 +1522,7 @@ export const useChatStore = defineStore('chat', () => {
} }
case 'compression.started': { case 'compression.started': {
setCompressionState({ setCompressionState(sid, {
compressing: true, compressing: true,
messageCount: (evt as any).message_count || 0, messageCount: (evt as any).message_count || 0,
beforeTokens: (evt as any).token_count || 0, beforeTokens: (evt as any).token_count || 0,
@@ -1514,7 +1534,7 @@ export const useChatStore = defineStore('chat', () => {
case 'compression.completed': { case 'compression.completed': {
const afterTokens = (evt as any).contextTokens || (evt as any).afterTokens || 0 const afterTokens = (evt as any).contextTokens || (evt as any).afterTokens || 0
setCompressionState({ setCompressionState(sid, {
compressing: false, compressing: false,
messageCount: (evt as any).totalMessages || 0, messageCount: (evt as any).totalMessages || 0,
beforeTokens: (evt as any).beforeTokens || 0, beforeTokens: (evt as any).beforeTokens || 0,
@@ -1528,8 +1548,9 @@ export const useChatStore = defineStore('chat', () => {
} }
// Auto-clear after 5s // Auto-clear after 5s
setTimeout(() => { setTimeout(() => {
if (compressionState.value && !compressionState.value.compressing) { const state = compressionStates.value.get(sid)
setCompressionState(null) if (state && !state.compressing) {
setCompressionState(sid, null)
} }
}, 5000) }, 5000)
break break
@@ -1966,7 +1987,7 @@ export const useChatStore = defineStore('chat', () => {
case 'run.started': case 'run.started':
clearAgentEventMessages(sid) clearAgentEventMessages(sid)
setAbortState(null) setAbortState(null)
setCompressionState(null) setCompressionState(sid, null)
runProducedAssistantText = false runProducedAssistantText = false
runHadToolActivity = false runHadToolActivity = false
closeStreamingAssistant() closeStreamingAssistant()
@@ -1979,7 +2000,7 @@ export const useChatStore = defineStore('chat', () => {
break break
case 'compression.started': { case 'compression.started': {
setCompressionState({ setCompressionState(sid, {
compressing: true, compressing: true,
messageCount: (evt as any).message_count || 0, messageCount: (evt as any).message_count || 0,
beforeTokens: (evt as any).token_count || 0, beforeTokens: (evt as any).token_count || 0,
@@ -1991,7 +2012,7 @@ export const useChatStore = defineStore('chat', () => {
case 'compression.completed': { case 'compression.completed': {
const afterTokens = (evt as any).contextTokens || (evt as any).afterTokens || 0 const afterTokens = (evt as any).contextTokens || (evt as any).afterTokens || 0
setCompressionState({ setCompressionState(sid, {
compressing: false, compressing: false,
messageCount: (evt as any).totalMessages || 0, messageCount: (evt as any).totalMessages || 0,
beforeTokens: (evt as any).beforeTokens || 0, beforeTokens: (evt as any).beforeTokens || 0,
@@ -2004,8 +2025,9 @@ export const useChatStore = defineStore('chat', () => {
if (target) target.contextTokens = (evt as any).contextTokens if (target) target.contextTokens = (evt as any).contextTokens
} }
setTimeout(() => { setTimeout(() => {
if (compressionState.value && !compressionState.value.compressing) { const state = compressionStates.value.get(sid)
setCompressionState(null) if (state && !state.compressing) {
setCompressionState(sid, null)
} }
}, 5000) }, 5000)
break break
@@ -2461,6 +2483,7 @@ export const useChatStore = defineStore('chat', () => {
} else if (!data.isWorking) { } else if (!data.isWorking) {
setAbortState(null) setAbortState(null)
} }
if (!data.isWorking) setCompressionState(sid, null)
if (data.messages?.length && activeSession.value) { if (data.messages?.length && activeSession.value) {
activeSession.value.messages = mapHermesMessages(data.messages as any[]) activeSession.value.messages = mapHermesMessages(data.messages as any[])
} }
@@ -79,6 +79,7 @@ export interface SummarizerOptions {
profile?: string profile?: string
model?: string | null model?: string | null
provider?: string | null provider?: string | null
workerKey?: string
} }
// ─── Token counting ───────────────────────────────────── // ─── Token counting ─────────────────────────────────────
@@ -454,6 +455,7 @@ export async function callSummarizer(
const bridge = new AgentBridgeClient({ timeoutMs: timeoutMs + 15_000 }) const bridge = new AgentBridgeClient({ timeoutMs: timeoutMs + 15_000 })
const sessionId = `compress_${Date.now().toString(36)}_${randomUUID().replace(/-/g, '').slice(0, 12)}` const sessionId = `compress_${Date.now().toString(36)}_${randomUUID().replace(/-/g, '').slice(0, 12)}`
const workerKey = options.workerKey || `${profile}:compression:${sessionId}`
try { try {
const result = await bridge.request<AgentBridgeRunResult>({ const result = await bridge.request<AgentBridgeRunResult>({
@@ -462,6 +464,7 @@ export async function callSummarizer(
message: prompt, message: prompt,
conversation_history: convHistory, conversation_history: convHistory,
profile, profile,
worker_key: workerKey,
source: 'api_server', source: 'api_server',
wait: true, wait: true,
timeout: Math.ceil(timeoutMs / 1000), timeout: Math.ceil(timeoutMs / 1000),
@@ -482,7 +485,7 @@ export async function callSummarizer(
if (!output) throw new Error('Empty summarization response') if (!output) throw new Error('Empty summarization response')
return output return output
} finally { } finally {
await bridge.destroy(sessionId, profile).catch(() => undefined) await bridge.destroy(sessionId, profile, workerKey).catch(() => undefined)
} }
} }
@@ -166,7 +166,7 @@ export class AgentBridgeClient {
private summarizePayload(payload: Record<string, unknown>): Record<string, unknown> { private summarizePayload(payload: Record<string, unknown>): Record<string, unknown> {
const action = String(payload.action || '') const action = String(payload.action || '')
const summary: Record<string, unknown> = { action } const summary: Record<string, unknown> = { action }
for (const key of ['session_id', 'run_id', 'request_id', 'approval_id', 'profile']) { for (const key of ['session_id', 'run_id', 'request_id', 'approval_id', 'profile', 'worker_key']) {
if (payload[key] != null) summary[key] = payload[key] if (payload[key] != null) summary[key] = payload[key]
} }
if (Array.isArray(payload.conversation_history)) summary.conversation_history_count = payload.conversation_history.length if (Array.isArray(payload.conversation_history)) summary.conversation_history_count = payload.conversation_history.length
@@ -569,11 +569,12 @@ export class AgentBridgeClient {
}) })
} }
destroy(sessionId: string, profile?: string): Promise<AgentBridgeResponse> { destroy(sessionId: string, profile?: string, workerKey?: string): Promise<AgentBridgeResponse> {
return this.request({ return this.request({
action: 'destroy', action: 'destroy',
session_id: sessionId, session_id: sessionId,
...(profile ? { profile } : {}), ...(profile ? { profile } : {}),
...(workerKey ? { worker_key: workerKey } : {}),
}) })
} }
@@ -2200,7 +2200,8 @@ class WorkerProcess:
STARTUP_TIMEOUT_SECONDS = 120 STARTUP_TIMEOUT_SECONDS = 120
REQUEST_TIMEOUT_SECONDS = 120 REQUEST_TIMEOUT_SECONDS = 120
def __init__(self, profile: str, endpoint: str, agent_root: str | None, hermes_home: str | None) -> None: def __init__(self, key: str, profile: str, endpoint: str, agent_root: str | None, hermes_home: str | None) -> None:
self.key = key or profile or "default"
self.profile = profile or "default" self.profile = profile or "default"
self.endpoint = endpoint self.endpoint = endpoint
self.agent_root = agent_root self.agent_root = agent_root
@@ -2263,14 +2264,14 @@ class WorkerProcess:
for line in proc.stderr: for line in proc.stderr:
text = line.rstrip() text = line.rstrip()
if text: if text:
print(f"[hermes-bridge-worker:{self.profile}] {text}", file=sys.stderr, flush=True) print(f"[hermes-bridge-worker:{self.key}] {text}", file=sys.stderr, flush=True)
threading.Thread(target=run, daemon=True, name=f"hermes-bridge-worker-stderr-{self.profile}").start() threading.Thread(target=run, daemon=True, name=f"hermes-bridge-worker-stderr-{self.key}").start()
def _wait_ready(self) -> None: def _wait_ready(self) -> None:
proc = self.process proc = self.process
if proc is None or proc.stdout is None: if proc is None or proc.stdout is None:
raise RuntimeError(f"profile worker {self.profile} did not start") raise RuntimeError(f"profile worker {self.key} did not start")
lines: queue.Queue[str | None] = queue.Queue() lines: queue.Queue[str | None] = queue.Queue()
ready_event = threading.Event() ready_event = threading.Event()
@@ -2281,17 +2282,17 @@ class WorkerProcess:
if ready_event.is_set(): if ready_event.is_set():
text = line.rstrip() text = line.rstrip()
if text: if text:
print(f"[hermes-bridge-worker:{self.profile}] {text}", file=sys.stderr, flush=True) print(f"[hermes-bridge-worker:{self.key}] {text}", file=sys.stderr, flush=True)
else: else:
lines.put(line) lines.put(line)
finally: finally:
lines.put(None) lines.put(None)
threading.Thread(target=read_stdout, daemon=True, name=f"hermes-bridge-worker-stdout-{self.profile}").start() threading.Thread(target=read_stdout, daemon=True, name=f"hermes-bridge-worker-stdout-{self.key}").start()
deadline = time.time() + self.STARTUP_TIMEOUT_SECONDS deadline = time.time() + self.STARTUP_TIMEOUT_SECONDS
while time.time() < deadline: while time.time() < deadline:
if proc.poll() is not None: if proc.poll() is not None:
raise RuntimeError(f"profile worker {self.profile} exited before ready") raise RuntimeError(f"profile worker {self.key} exited before ready")
try: try:
line = lines.get(timeout=0.1) line = lines.get(timeout=0.1)
except queue.Empty: except queue.Empty:
@@ -2301,7 +2302,7 @@ class WorkerProcess:
continue continue
text = line.strip() text = line.strip()
if text: if text:
print(f"[hermes-bridge-worker:{self.profile}] {text}", file=sys.stderr, flush=True) print(f"[hermes-bridge-worker:{self.key}] {text}", file=sys.stderr, flush=True)
try: try:
data = json.loads(text) data = json.loads(text)
if data.get("event") == "ready": if data.get("event") == "ready":
@@ -2310,7 +2311,7 @@ class WorkerProcess:
except Exception: except Exception:
pass pass
self.stop() self.stop()
raise RuntimeError(f"profile worker {self.profile} did not become ready within {self.STARTUP_TIMEOUT_SECONDS}s") raise RuntimeError(f"profile worker {self.key} did not become ready within {self.STARTUP_TIMEOUT_SECONDS}s")
def stop(self) -> None: def stop(self) -> None:
with self._lock: with self._lock:
@@ -2337,8 +2338,8 @@ class WorkerProcess:
return _send_bridge_request(self.endpoint, req, self.REQUEST_TIMEOUT_SECONDS) return _send_bridge_request(self.endpoint, req, self.REQUEST_TIMEOUT_SECONDS)
def _worker_endpoint(profile: str) -> str: def _worker_endpoint(key: str) -> str:
safe = hashlib.sha256(profile.encode("utf-8")).hexdigest()[:16] safe = hashlib.sha256(key.encode("utf-8")).hexdigest()[:16]
if os.name == "nt": if os.name == "nt":
port_base = int(os.environ.get("HERMES_AGENT_BRIDGE_WORKER_PORT_BASE", "18780")) port_base = int(os.environ.get("HERMES_AGENT_BRIDGE_WORKER_PORT_BASE", "18780"))
return f"tcp://127.0.0.1:{port_base + int(safe[:4], 16) % 1000}" return f"tcp://127.0.0.1:{port_base + int(safe[:4], 16) % 1000}"
@@ -2533,11 +2534,17 @@ class BridgeBroker:
self.hermes_home = hermes_home self.hermes_home = hermes_home
self._workers: dict[str, WorkerProcess] = {} self._workers: dict[str, WorkerProcess] = {}
self._run_profile: dict[str, str] = {} self._run_profile: dict[str, str] = {}
self._run_worker_key: dict[str, str] = {}
self._running_run_profile: dict[str, str] = {} self._running_run_profile: dict[str, str] = {}
self._running_run_worker_key: dict[str, str] = {}
self._session_profile: dict[str, str] = {} self._session_profile: dict[str, str] = {}
self._session_worker_key: dict[str, str] = {}
self._approval_profile: dict[str, str] = {} self._approval_profile: dict[str, str] = {}
self._approval_worker_key: dict[str, str] = {}
self._clarify_profile: dict[str, str] = {} self._clarify_profile: dict[str, str] = {}
self._clarify_worker_key: dict[str, str] = {}
self._compression_profile: dict[str, str] = {} self._compression_profile: dict[str, str] = {}
self._compression_worker_key: dict[str, str] = {}
self._lock = threading.RLock() self._lock = threading.RLock()
self._stop = threading.Event() self._stop = threading.Event()
self._last_gc = time.time() self._last_gc = time.time()
@@ -2546,58 +2553,73 @@ class BridgeBroker:
profile = str(value or "").strip() profile = str(value or "").strip()
return profile or "default" return profile or "default"
def _worker_for_profile(self, profile: str) -> WorkerProcess: def _normalize_worker_key(self, profile: str, value: Any = None) -> str:
worker_key = str(value or "").strip()
return worker_key or profile
def _worker_for_profile(self, profile: str, worker_key: str | None = None) -> WorkerProcess:
profile = self._normalize_profile(profile) profile = self._normalize_profile(profile)
key = self._normalize_worker_key(profile, worker_key)
with self._lock: with self._lock:
worker = self._workers.get(profile) worker = self._workers.get(key)
if worker is None: if worker is None:
worker = WorkerProcess(profile, _worker_endpoint(profile), self.agent_root, self.hermes_home) worker = WorkerProcess(key, profile, _worker_endpoint(key), self.agent_root, self.hermes_home)
self._workers[profile] = worker self._workers[key] = worker
return worker return worker
def _profile_for_run(self, run_id: str) -> str: def _route_for_run(self, run_id: str) -> tuple[str, str | None]:
with self._lock: with self._lock:
profile = self._run_profile.get(run_id) profile = self._run_profile.get(run_id)
worker_key = self._run_worker_key.get(run_id)
if not profile: if not profile:
raise KeyError(f"unknown run: {run_id}") raise KeyError(f"unknown run: {run_id}")
return profile return profile, worker_key
def _profile_for_session(self, session_id: str, fallback_profile: Any = None) -> str: def _route_for_session(self, session_id: str, fallback_profile: Any = None, worker_key: Any = None) -> tuple[str, str | None]:
with self._lock: with self._lock:
profile = self._session_profile.get(session_id) profile = self._session_profile.get(session_id)
stored_worker_key = self._session_worker_key.get(session_id)
if not profile: if not profile:
fallback = self._normalize_profile(fallback_profile) fallback = self._normalize_profile(fallback_profile)
if fallback_profile is not None and fallback: if fallback_profile is not None and fallback:
return fallback return fallback, self._normalize_worker_key(fallback, worker_key)
raise KeyError(f"unknown session: {session_id}") raise KeyError(f"unknown session: {session_id}")
return profile return profile, self._normalize_worker_key(profile, worker_key) if worker_key is not None else stored_worker_key
def _record_response_routes(self, profile: str, resp: dict[str, Any]) -> None: def _record_response_routes(self, profile: str, worker_key: str, resp: dict[str, Any]) -> None:
run_id = str(resp.get("run_id") or "") run_id = str(resp.get("run_id") or "")
session_id = str(resp.get("session_id") or "") session_id = str(resp.get("session_id") or "")
with self._lock: with self._lock:
if run_id: if run_id:
self._run_profile[run_id] = profile self._run_profile[run_id] = profile
self._run_worker_key[run_id] = worker_key
if resp.get("status") == "running": if resp.get("status") == "running":
self._running_run_profile[run_id] = profile self._running_run_profile[run_id] = profile
self._running_run_worker_key[run_id] = worker_key
else: else:
self._running_run_profile.pop(run_id, None) self._running_run_profile.pop(run_id, None)
self._running_run_worker_key.pop(run_id, None)
if session_id: if session_id:
self._session_profile[session_id] = profile self._session_profile[session_id] = profile
self._session_worker_key[session_id] = worker_key
for event in resp.get("events") or []: for event in resp.get("events") or []:
if not isinstance(event, dict): if not isinstance(event, dict):
continue continue
approval_id = str(event.get("approval_id") or "") approval_id = str(event.get("approval_id") or "")
if approval_id: if approval_id:
self._approval_profile[approval_id] = profile self._approval_profile[approval_id] = profile
self._approval_worker_key[approval_id] = worker_key
clarify_id = str(event.get("clarify_id") or "") clarify_id = str(event.get("clarify_id") or "")
if clarify_id: if clarify_id:
self._clarify_profile[clarify_id] = profile self._clarify_profile[clarify_id] = profile
self._clarify_worker_key[clarify_id] = worker_key
request_id = str(event.get("request_id") or "") request_id = str(event.get("request_id") or "")
if event.get("event") == "bridge.compression.requested" and request_id: if event.get("event") == "bridge.compression.requested" and request_id:
self._compression_profile[request_id] = profile self._compression_profile[request_id] = profile
self._compression_worker_key[request_id] = worker_key
if event.get("event") in {"bridge.compression.completed", "bridge.compression.failed"} and request_id: if event.get("event") in {"bridge.compression.completed", "bridge.compression.failed"} and request_id:
self._compression_profile.pop(request_id, None) self._compression_profile.pop(request_id, None)
self._compression_worker_key.pop(request_id, None)
def stop(self) -> None: def stop(self) -> None:
self._stop.set() self._stop.set()
@@ -2605,20 +2627,29 @@ class BridgeBroker:
workers = list(self._workers.values()) workers = list(self._workers.values())
self._workers.clear() self._workers.clear()
self._run_profile.clear() self._run_profile.clear()
self._run_worker_key.clear()
self._running_run_profile.clear() self._running_run_profile.clear()
self._running_run_worker_key.clear()
self._session_profile.clear() self._session_profile.clear()
self._session_worker_key.clear()
self._approval_profile.clear() self._approval_profile.clear()
self._approval_worker_key.clear()
self._clarify_profile.clear() self._clarify_profile.clear()
self._clarify_worker_key.clear()
self._compression_profile.clear() self._compression_profile.clear()
self._compression_worker_key.clear()
for worker in workers: for worker in workers:
worker.stop() worker.stop()
def _forward(self, profile: str, req: dict[str, Any]) -> dict[str, Any]: def _forward(self, profile: str, req: dict[str, Any], worker_key: str | None = None) -> dict[str, Any]:
worker = self._worker_for_profile(profile) profile = self._normalize_profile(profile)
key = self._normalize_worker_key(profile, worker_key)
worker = self._worker_for_profile(profile, key)
forwarded = dict(req) forwarded = dict(req)
forwarded["profile"] = profile forwarded["profile"] = profile
forwarded.pop("worker_key", None)
resp = worker.request(forwarded) resp = worker.request(forwarded)
self._record_response_routes(profile, resp) self._record_response_routes(profile, key, resp)
return resp return resp
def handle(self, req: dict[str, Any]) -> dict[str, Any]: def handle(self, req: dict[str, Any]) -> dict[str, Any]:
@@ -2629,15 +2660,16 @@ class BridgeBroker:
if action == "ping": if action == "ping":
with self._lock: with self._lock:
worker_details = { worker_details = {
profile: { key: {
"running": worker.running, "running": worker.running,
"pid": worker.pid, "pid": worker.pid,
"endpoint": worker.endpoint, "endpoint": worker.endpoint,
"profile": getattr(worker, "profile", key),
"last_used_at": worker.last_used_at, "last_used_at": worker.last_used_at,
} }
for profile, worker in self._workers.items() for key, worker in self._workers.items()
} }
workers = {profile: details["running"] for profile, details in worker_details.items()} workers = {key: details["running"] for key, details in worker_details.items()}
sessions_by_profile: dict[str, int] = {} sessions_by_profile: dict[str, int] = {}
for profile in self._session_profile.values(): for profile in self._session_profile.values():
sessions_by_profile[profile] = sessions_by_profile.get(profile, 0) + 1 sessions_by_profile[profile] = sessions_by_profile.get(profile, 0) + 1
@@ -2664,29 +2696,32 @@ class BridgeBroker:
if action == "worker_ping": if action == "worker_ping":
profile = self._normalize_profile(req.get("profile")) profile = self._normalize_profile(req.get("profile"))
resp = self._forward(profile, {"action": "ping"}) worker_key = self._normalize_worker_key(profile, req.get("worker_key"))
resp = self._forward(profile, {"action": "ping"}, worker_key)
resp["worker_profile"] = profile resp["worker_profile"] = profile
resp["worker_key"] = worker_key
return resp return resp
if action == "chat": if action == "chat":
profile = self._normalize_profile(req.get("profile")) profile = self._normalize_profile(req.get("profile"))
return self._forward(profile, req) return self._forward(profile, req, self._normalize_worker_key(profile, req.get("worker_key")))
if action == "context_estimate": if action == "context_estimate":
profile = self._normalize_profile(req.get("profile")) profile = self._normalize_profile(req.get("profile"))
return self._forward(profile, req) return self._forward(profile, req, self._normalize_worker_key(profile, req.get("worker_key")))
if action in {"get_result", "get_output"}: if action in {"get_result", "get_output"}:
profile = self._profile_for_run(str(req.get("run_id") or "")) profile, worker_key = self._route_for_run(str(req.get("run_id") or ""))
return self._forward(profile, req) return self._forward(profile, req, worker_key)
if action in {"interrupt", "steer", "command", "goal_evaluate", "goal_pause", "status", "get_history", "destroy"}: if action in {"interrupt", "steer", "command", "goal_evaluate", "goal_pause", "status", "get_history", "destroy"}:
session_id = str(req.get("session_id") or "") session_id = str(req.get("session_id") or "")
profile = self._profile_for_session(session_id, req.get("profile")) profile, worker_key = self._route_for_session(session_id, req.get("profile"), req.get("worker_key") if "worker_key" in req else None)
resp = self._forward(profile, req) resp = self._forward(profile, req, worker_key)
if action == "destroy": if action == "destroy":
with self._lock: with self._lock:
self._session_profile.pop(session_id, None) self._session_profile.pop(session_id, None)
self._session_worker_key.pop(session_id, None)
return resp return resp
if action == "approval_respond": if action == "approval_respond":
@@ -2695,9 +2730,10 @@ class BridgeBroker:
raise ValueError("approval_id is required") raise ValueError("approval_id is required")
with self._lock: with self._lock:
profile = self._approval_profile.get(approval_id) profile = self._approval_profile.get(approval_id)
worker_key = self._approval_worker_key.get(approval_id)
if not profile: if not profile:
raise KeyError(f"unknown approval request: {approval_id}") raise KeyError(f"unknown approval request: {approval_id}")
return self._forward(profile, req) return self._forward(profile, req, worker_key)
if action == "clarify_respond": if action == "clarify_respond":
clarify_id = str(req.get("clarify_id") or "").strip() clarify_id = str(req.get("clarify_id") or "").strip()
@@ -2705,9 +2741,10 @@ class BridgeBroker:
raise ValueError("clarify_id is required") raise ValueError("clarify_id is required")
with self._lock: with self._lock:
profile = self._clarify_profile.get(clarify_id) profile = self._clarify_profile.get(clarify_id)
worker_key = self._clarify_worker_key.get(clarify_id)
if not profile: if not profile:
raise KeyError(f"unknown clarify request: {clarify_id}") raise KeyError(f"unknown clarify request: {clarify_id}")
return self._forward(profile, req) return self._forward(profile, req, worker_key)
if action == "compression_respond": if action == "compression_respond":
request_id = str(req.get("request_id") or "").strip() request_id = str(req.get("request_id") or "").strip()
@@ -2715,20 +2752,27 @@ class BridgeBroker:
raise ValueError("request_id is required") raise ValueError("request_id is required")
with self._lock: with self._lock:
profile = self._compression_profile.get(request_id) profile = self._compression_profile.get(request_id)
worker_key = self._compression_worker_key.get(request_id)
if not profile: if not profile:
raise KeyError(f"unknown compression request: {request_id}") raise KeyError(f"unknown compression request: {request_id}")
return self._forward(profile, req) return self._forward(profile, req, worker_key)
if action == "destroy_all": if action == "destroy_all":
with self._lock: with self._lock:
workers = list(self._workers.values()) workers = list(self._workers.values())
self._workers.clear() self._workers.clear()
self._run_profile.clear() self._run_profile.clear()
self._run_worker_key.clear()
self._running_run_profile.clear() self._running_run_profile.clear()
self._running_run_worker_key.clear()
self._session_profile.clear() self._session_profile.clear()
self._session_worker_key.clear()
self._approval_profile.clear() self._approval_profile.clear()
self._approval_worker_key.clear()
self._clarify_profile.clear() self._clarify_profile.clear()
self._clarify_worker_key.clear()
self._compression_profile.clear() self._compression_profile.clear()
self._compression_worker_key.clear()
destroyed = 0 destroyed = 0
for worker in workers: for worker in workers:
try: try:
@@ -2744,40 +2788,56 @@ class BridgeBroker:
if action == "destroy_profile": if action == "destroy_profile":
profile = self._normalize_profile(req.get("profile")) profile = self._normalize_profile(req.get("profile"))
with self._lock: with self._lock:
worker = self._workers.pop(profile, None) workers = [
worker
for key, worker in list(self._workers.items())
if getattr(worker, "profile", key) == profile
]
for worker in workers:
self._workers.pop(worker.key, None)
self._run_profile = {key: value for key, value in self._run_profile.items() if value != profile} self._run_profile = {key: value for key, value in self._run_profile.items() if value != profile}
self._run_worker_key = {key: value for key, value in self._run_worker_key.items() if key in self._run_profile}
self._running_run_profile = {key: value for key, value in self._running_run_profile.items() if value != profile} self._running_run_profile = {key: value for key, value in self._running_run_profile.items() if value != profile}
self._running_run_worker_key = {key: value for key, value in self._running_run_worker_key.items() if key in self._running_run_profile}
self._session_profile = {key: value for key, value in self._session_profile.items() if value != profile} self._session_profile = {key: value for key, value in self._session_profile.items() if value != profile}
self._session_worker_key = {key: value for key, value in self._session_worker_key.items() if key in self._session_profile}
self._approval_profile = {key: value for key, value in self._approval_profile.items() if value != profile} self._approval_profile = {key: value for key, value in self._approval_profile.items() if value != profile}
self._approval_worker_key = {key: value for key, value in self._approval_worker_key.items() if key in self._approval_profile}
self._clarify_profile = {key: value for key, value in self._clarify_profile.items() if value != profile} self._clarify_profile = {key: value for key, value in self._clarify_profile.items() if value != profile}
self._clarify_worker_key = {key: value for key, value in self._clarify_worker_key.items() if key in self._clarify_profile}
self._compression_profile = {key: value for key, value in self._compression_profile.items() if value != profile} self._compression_profile = {key: value for key, value in self._compression_profile.items() if value != profile}
self._compression_worker_key = {key: value for key, value in self._compression_worker_key.items() if key in self._compression_profile}
if worker is None or not worker.running: if not workers:
if worker is not None:
worker.stop()
return {"profile": profile, "destroyed": 0} return {"profile": profile, "destroyed": 0}
try: destroyed = 0
resp = worker.request({"action": "destroy_all"}) for worker in workers:
destroyed = int(resp.get("destroyed") or 0) if not worker.running:
except Exception: worker.stop()
destroyed = 0 continue
finally: try:
worker.stop() resp = worker.request({"action": "destroy_all"})
destroyed += int(resp.get("destroyed") or 0)
except Exception:
pass
finally:
worker.stop()
return {"profile": profile, "destroyed": destroyed} return {"profile": profile, "destroyed": destroyed}
if action == "list": if action == "list":
sessions: list[Any] = [] sessions: list[Any] = []
with self._lock: with self._lock:
workers = list(self._workers.items()) workers = list(self._workers.items())
for profile, worker in workers: for key, worker in workers:
if not worker.running: if not worker.running:
continue continue
try: try:
resp = worker.request({"action": "list"}) resp = worker.request({"action": "list"})
for session in resp.get("sessions") or []: for session in resp.get("sessions") or []:
if isinstance(session, dict): if isinstance(session, dict):
session.setdefault("profile", profile) session.setdefault("profile", getattr(worker, "profile", key))
session.setdefault("worker_key", key)
sessions.append(session) sessions.append(session)
except Exception: except Exception:
pass pass
@@ -2826,12 +2886,12 @@ class BridgeBroker:
self._last_gc = now self._last_gc = now
with self._lock: with self._lock:
idle = [ idle = [
profile for profile, worker in self._workers.items() key for key, worker in self._workers.items()
if worker.running and now - worker.last_used_at > self.IDLE_TIMEOUT_SECONDS if worker.running and now - worker.last_used_at > self.IDLE_TIMEOUT_SECONDS
] ]
for profile in idle: for key in idle:
with self._lock: with self._lock:
worker = self._workers.pop(profile, None) worker = self._workers.pop(key, None)
if worker: if worker:
worker.stop() worker.stop()
@@ -195,7 +195,8 @@ export async function buildCompressedHistory(
emit: (event: string, payload: any) => void, emit: (event: string, payload: any) => void,
sessionMap: Map<string, SessionState>, sessionMap: Map<string, SessionState>,
modelContext: { model?: string | null; provider?: string | null } = {}, modelContext: { model?: string | null; provider?: string | null } = {},
contextTokenEstimator?: (messages: ChatMessage[]) => Promise<number | null | undefined>, contextTokenEstimator?: (messages: ChatMessage[], messageTokens: number) => Promise<number | null | undefined>,
currentInputTokens = 0,
): Promise<ChatMessage[]> { ): Promise<ChatMessage[]> {
try { try {
let history = await buildDbHistory(sessionId, { excludeLastUser: true }) let history = await buildDbHistory(sessionId, { excludeLastUser: true })
@@ -213,14 +214,18 @@ export async function buildCompressedHistory(
} }
const cState = getOrCreateSession(sessionMap, sessionId) const cState = getOrCreateSession(sessionMap, sessionId)
const assembledTokens = await calcAndUpdateUsage(sessionId, cState, emit) const assembledTokens = await calcAndUpdateUsage(sessionId, cState, emit)
const estimateFullContextTokens = async (messages: ChatMessage[], fallback: number) => { const currentRunInputTokens = typeof currentInputTokens === 'number' && Number.isFinite(currentInputTokens) && currentInputTokens > 0
? Math.floor(currentInputTokens)
: 0
const estimateLocalContextTokens = async (messages: ChatMessage[], messageTokens: number) => {
const localMessageTokens = Math.max(0, Math.floor(messageTokens))
try { try {
const estimate = await contextTokenEstimator?.(messages) const estimate = await contextTokenEstimator?.(messages, localMessageTokens)
if (typeof estimate === 'number' && Number.isFinite(estimate) && estimate > 0) return Math.floor(estimate) if (typeof estimate === 'number' && Number.isFinite(estimate) && estimate > 0) return Math.floor(estimate)
} catch (err) { } catch (err) {
logger.warn(err, '[context-compress] session=%s: full context token estimate failed; using message-only estimate', sessionId) logger.warn(err, '[context-compress] session=%s: fixed context token estimate failed; using message-only estimate', sessionId)
} }
return fallback return localMessageTokens
} }
const emitContextUsage = (contextTokens: number) => { const emitContextUsage = (contextTokens: number) => {
cState.contextTokens = contextTokens cState.contextTokens = contextTokens
@@ -236,10 +241,10 @@ export async function buildCompressedHistory(
let totalTokens = messageOnlyTotalTokens let totalTokens = messageOnlyTotalTokens
if (history.length === 0) { if (history.length === 0) {
totalTokens = await estimateFullContextTokens([], 0) totalTokens = await estimateLocalContextTokens([], Math.max(currentRunInputTokens, messageOnlyTotalTokens))
if (totalTokens > triggerTokens) { if (totalTokens > triggerTokens) {
throw new ContextWindowTooSmallError( throw new ContextWindowTooSmallError(
`Context window is too small: system prompt and tool schemas already use ~${totalTokens} tokens, exceeding compression threshold ${triggerTokens}. Increase model context length, raise compression.threshold, or disable some tools.`, `Context window is too small: fixed prompt/tool overhead plus the current input uses ~${totalTokens} tokens, exceeding compression threshold ${triggerTokens}. Increase model context length, raise compression.threshold, shorten the input, or disable some tools.`,
) )
} }
if (totalTokens > 0) emitContextUsage(totalTokens) if (totalTokens > 0) emitContextUsage(totalTokens)
@@ -254,13 +259,15 @@ export async function buildCompressedHistory(
sessionId, snapshot.lastMessageIndex, history.length) sessionId, snapshot.lastMessageIndex, history.length)
const staleHistory = buildSnapshotHistory(snapshot, history, compressionConfig.compressor) || history const staleHistory = buildSnapshotHistory(snapshot, history, compressionConfig.compressor) || history
const staleUsage = estimateUsageTokensFromMessages(staleHistory) const staleUsage = estimateUsageTokensFromMessages(staleHistory)
totalTokens = await estimateFullContextTokens(staleHistory, staleUsage.inputTokens + staleUsage.outputTokens) const staleMessageTokens = staleUsage.inputTokens + staleUsage.outputTokens
const staleRunMessageTokens = Math.max(staleMessageTokens + currentRunInputTokens, messageOnlyTotalTokens)
totalTokens = await estimateLocalContextTokens(staleHistory, staleRunMessageTokens)
emitContextUsage(totalTokens) emitContextUsage(totalTokens)
logger.info({ logger.info({
sessionId, sessionId,
profile, profile,
messages: staleHistory.length, messages: staleHistory.length,
messageOnlyTokens: staleUsage.inputTokens + staleUsage.outputTokens, messageOnlyTokens: staleRunMessageTokens,
fullContextTokens: totalTokens, fullContextTokens: totalTokens,
triggerTokens, triggerTokens,
decision: totalTokens > triggerTokens ? 'compress' : 'skip', decision: totalTokens > triggerTokens ? 'compress' : 'skip',
@@ -272,13 +279,15 @@ export async function buildCompressedHistory(
const newMessages = history.slice(snapshot.lastMessageIndex + 1) const newMessages = history.slice(snapshot.lastMessageIndex + 1)
const snapshotHistory = buildSnapshotHistory(snapshot, history, compressionConfig.compressor) || history const snapshotHistory = buildSnapshotHistory(snapshot, history, compressionConfig.compressor) || history
const snapshotUsage = estimateUsageTokensFromMessages(snapshotHistory) const snapshotUsage = estimateUsageTokensFromMessages(snapshotHistory)
totalTokens = await estimateFullContextTokens(snapshotHistory, snapshotUsage.inputTokens + snapshotUsage.outputTokens) const snapshotMessageTokens = snapshotUsage.inputTokens + snapshotUsage.outputTokens
const snapshotRunMessageTokens = Math.max(snapshotMessageTokens + currentRunInputTokens, messageOnlyTotalTokens)
totalTokens = await estimateLocalContextTokens(snapshotHistory, snapshotRunMessageTokens)
emitContextUsage(totalTokens) emitContextUsage(totalTokens)
logger.info({ logger.info({
sessionId, sessionId,
profile, profile,
messages: snapshotHistory.length, messages: snapshotHistory.length,
messageOnlyTokens: snapshotUsage.inputTokens + snapshotUsage.outputTokens, messageOnlyTokens: snapshotRunMessageTokens,
fullContextTokens: totalTokens, fullContextTokens: totalTokens,
triggerTokens, triggerTokens,
decision: totalTokens > triggerTokens ? 'compress' : 'skip', decision: totalTokens > triggerTokens ? 'compress' : 'skip',
@@ -289,22 +298,25 @@ export async function buildCompressedHistory(
if (totalTokens <= triggerTokens) { if (totalTokens <= triggerTokens) {
history = snapshotHistory history = snapshotHistory
} else { } else {
history = await compressHistory(history, newMessages, sessionId, upstream, apiKey, cState, totalTokens, emit, sessionMap, modelContext, compressionConfig.compressor) history = await compressHistory(history, newMessages, sessionId, upstream, apiKey, cState, totalTokens, emit, sessionMap, modelContext, compressionConfig.compressor, currentRunInputTokens)
} }
} else if (snapshot && staleSnapshot) { } else if (snapshot && staleSnapshot) {
if (totalTokens <= triggerTokens) { if (totalTokens <= triggerTokens) {
history = buildSnapshotHistory(snapshot, history, compressionConfig.compressor) || history history = buildSnapshotHistory(snapshot, history, compressionConfig.compressor) || history
} else { } else {
history = await compressHistory(history, null, sessionId, upstream, apiKey, cState, totalTokens, emit, sessionMap, modelContext, compressionConfig.compressor) history = await compressHistory(history, null, sessionId, upstream, apiKey, cState, totalTokens, emit, sessionMap, modelContext, compressionConfig.compressor, currentRunInputTokens)
} }
} else { } else {
totalTokens = await estimateFullContextTokens(history, totalTokens) const historyUsage = estimateUsageTokensFromMessages(history)
const historyMessageTokens = historyUsage.inputTokens + historyUsage.outputTokens
const runMessageTokens = Math.max(historyMessageTokens + currentRunInputTokens, messageOnlyTotalTokens)
totalTokens = await estimateLocalContextTokens(history, runMessageTokens)
emitContextUsage(totalTokens) emitContextUsage(totalTokens)
logger.info({ logger.info({
sessionId, sessionId,
profile, profile,
messages: history.length, messages: history.length,
messageOnlyTokens: messageOnlyTotalTokens, messageOnlyTokens: runMessageTokens,
fullContextTokens: totalTokens, fullContextTokens: totalTokens,
triggerTokens, triggerTokens,
decision: totalTokens > triggerTokens ? 'compress' : 'skip', decision: totalTokens > triggerTokens ? 'compress' : 'skip',
@@ -318,7 +330,7 @@ export async function buildCompressedHistory(
if (totalTokens <= triggerTokens) { if (totalTokens <= triggerTokens) {
logger.info('[context-compress] session=%s: %d messages, ~%d tokens — under threshold, skip', sessionId, history.length, totalTokens) logger.info('[context-compress] session=%s: %d messages, ~%d tokens — under threshold, skip', sessionId, history.length, totalTokens)
} else { } else {
history = await compressHistory(history, null, sessionId, upstream, apiKey, cState, totalTokens, emit, sessionMap, modelContext, compressionConfig.compressor) history = await compressHistory(history, null, sessionId, upstream, apiKey, cState, totalTokens, emit, sessionMap, modelContext, compressionConfig.compressor, currentRunInputTokens)
} }
} }
@@ -342,8 +354,12 @@ export async function compressHistory(
sessionMap: Map<string, SessionState>, sessionMap: Map<string, SessionState>,
modelContext: { model?: string | null; provider?: string | null } = {}, modelContext: { model?: string | null; provider?: string | null } = {},
compressionConfig?: Partial<CompressorConfig>, compressionConfig?: Partial<CompressorConfig>,
currentInputTokens = 0,
): Promise<ChatMessage[]> { ): Promise<ChatMessage[]> {
const msgCount = newMessagesOnly ? newMessagesOnly.length : history.length const msgCount = newMessagesOnly ? newMessagesOnly.length : history.length
const currentRunInputTokens = typeof currentInputTokens === 'number' && Number.isFinite(currentInputTokens) && currentInputTokens > 0
? Math.floor(currentInputTokens)
: 0
pushState(sessionMap, sessionId, 'compression.started', { pushState(sessionMap, sessionId, 'compression.started', {
event: 'compression.started', message_count: msgCount, token_count: totalTokens, event: 'compression.started', message_count: msgCount, token_count: totalTokens,
}) })
@@ -353,14 +369,22 @@ export async function compressHistory(
try { try {
const session = getSession(sessionId) const session = getSession(sessionId)
const summarizerProfile = session?.profile || 'default'
const compressor = new ChatContextCompressor({ config: compressionConfig }) const compressor = new ChatContextCompressor({ config: compressionConfig })
const result = await compressor.compress(history, upstream, apiKey, sessionId, { const result = await compressor.compress(history, upstream, apiKey, sessionId, {
profile: session?.profile, profile: summarizerProfile,
model: modelContext.model || session?.model, model: modelContext.model || session?.model,
provider: modelContext.provider || session?.provider, provider: modelContext.provider || session?.provider,
workerKey: `${summarizerProfile}:compression:${sessionId}`,
}) })
const afterTokens = await calcAndUpdateUsage(sessionId, cState, emit) const afterTokens = await calcAndUpdateUsage(sessionId, cState, emit)
const compressedAfterTokens = afterTokens.inputTokens + afterTokens.outputTokens const compressedAfterTokens = afterTokens.inputTokens + afterTokens.outputTokens
const resultUsage = estimateUsageTokensFromMessages(result.messages)
const resultMessageTokens = resultUsage.inputTokens + resultUsage.outputTokens
const compressedRunMessageTokens = Math.max(
compressedAfterTokens,
resultMessageTokens + currentRunInputTokens,
)
const compressedMeta: any = { const compressedMeta: any = {
event: 'compression.completed' as const, event: 'compression.completed' as const,
compressed: result.meta.compressed, compressed: result.meta.compressed,
@@ -368,15 +392,15 @@ export async function compressHistory(
totalMessages: result.meta.totalMessages, totalMessages: result.meta.totalMessages,
resultMessages: result.messages.length, resultMessages: result.messages.length,
beforeTokens: totalTokens, beforeTokens: totalTokens,
afterTokens: compressedAfterTokens, afterTokens: compressedRunMessageTokens,
summaryTokens: result.meta.summaryTokenEstimate, summaryTokens: result.meta.summaryTokenEstimate,
verbatimCount: result.meta.verbatimCount, verbatimCount: result.meta.verbatimCount,
compressedStartIndex: result.meta.compressedStartIndex, compressedStartIndex: result.meta.compressedStartIndex,
} }
replaceState(sessionMap, sessionId, 'compression.completed', compressedMeta) replaceState(sessionMap, sessionId, 'compression.completed', compressedMeta)
logger.info('[context-compress] AFTER session=%s: %d messages, ~%d tokens (was %d)', logger.info('[context-compress] AFTER session=%s: %d messages, ~%d tokens (was %d)',
sessionId, result.messages.length, compressedAfterTokens, totalTokens) sessionId, result.messages.length, compressedRunMessageTokens, totalTokens)
const compressedContextTokens = updateMessageContextTokenUsage(sessionId, cState, emit, compressedAfterTokens, afterTokens) const compressedContextTokens = updateMessageContextTokenUsage(sessionId, cState, emit, compressedRunMessageTokens, afterTokens)
if (compressedContextTokens != null) { if (compressedContextTokens != null) {
compressedMeta.contextTokens = compressedContextTokens compressedMeta.contextTokens = compressedContextTokens
} }
@@ -403,6 +427,7 @@ export async function compressHistory(
resultMessages: msgCount, resultMessages: msgCount,
beforeTokens: totalTokens, beforeTokens: totalTokens,
afterTokens: totalTokens, afterTokens: totalTokens,
contextTokens: totalTokens,
summaryTokens: 0, summaryTokens: 0,
verbatimCount: msgCount, verbatimCount: msgCount,
compressedStartIndex: -1, compressedStartIndex: -1,
@@ -458,10 +483,12 @@ export async function forceCompressBridgeHistory(
}, '[chat-run-socket] bridge forced compression started') }, '[chat-run-socket] bridge forced compression started')
const compressor = new ChatContextCompressor({ config: compressionConfig.compressor }) const compressor = new ChatContextCompressor({ config: compressionConfig.compressor })
const summarizerProfile = session?.profile || profile || 'default'
const result = await compressor.compress(history, upstream, apiKey, sessionId, { const result = await compressor.compress(history, upstream, apiKey, sessionId, {
profile: session?.profile || profile, profile: summarizerProfile,
model: session?.model, model: session?.model,
provider: session?.provider, provider: session?.provider,
workerKey: `${summarizerProfile}:compression:${sessionId}`,
}) })
const compressedMessages = result.messages.map(m => { const compressedMessages = result.messages.map(m => {
const msg: any = { role: m.role, content: m.content } const msg: any = { role: m.role, content: m.content }
@@ -18,7 +18,7 @@ import { convertHistoryFormat } from './message-format'
import { readSseFrames } from './sse-utils' import { readSseFrames } from './sse-utils'
import { extractResponseText } from './response-utils' import { extractResponseText } from './response-utils'
import { applyResponseStreamEvent, flushResponseRunToDb } from './response-stream' import { applyResponseStreamEvent, flushResponseRunToDb } from './response-stream'
import { buildCompressedHistory, getOrCreateSession } from './compression' import { buildCompressedHistory, buildDbHistory, buildSnapshotAwareHistory, getOrCreateSession } from './compression'
import { calcAndUpdateUsage, estimateUsageTokensFromMessages } from './usage' import { calcAndUpdateUsage, estimateUsageTokensFromMessages } from './usage'
import { handleMessage } from './message-format' import { handleMessage } from './message-format'
import { countTokens, SUMMARY_PREFIX } from '../../../lib/context-compressor' import { countTokens, SUMMARY_PREFIX } from '../../../lib/context-compressor'
@@ -37,6 +37,7 @@ export async function loadSessionStateFromDb(sid: string, _sessionMap: Map<strin
let inputTokens: number let inputTokens: number
let outputTokens: number let outputTokens: number
let contextTokens: number | undefined
const snapshot = getCompressionSnapshot(sid) const snapshot = getCompressionSnapshot(sid)
if (snapshot && snapshot.lastMessageIndex >= 0 && snapshot.lastMessageIndex < messages.length) { if (snapshot && snapshot.lastMessageIndex >= 0 && snapshot.lastMessageIndex < messages.length) {
const newMessages = messages.slice(snapshot.lastMessageIndex + 1) const newMessages = messages.slice(snapshot.lastMessageIndex + 1)
@@ -49,6 +50,20 @@ export async function loadSessionStateFromDb(sid: string, _sessionMap: Map<strin
inputTokens = usage.inputTokens inputTokens = usage.inputTokens
outputTokens = usage.outputTokens outputTokens = usage.outputTokens
} }
try {
const session = getSession(sid)
const dbHistory = await buildDbHistory(sid, { excludeLastUser: false })
const snapshotHistory = await buildSnapshotAwareHistory(
sid,
session?.profile || 'default',
dbHistory,
{ model: session?.model, provider: session?.provider },
)
const contextUsage = estimateUsageTokensFromMessages(snapshotHistory)
contextTokens = contextUsage.inputTokens + contextUsage.outputTokens
} catch (err) {
logger.warn(err, '[chat-run-socket] failed to calculate snapshot-aware context tokens for session %s', sid)
}
logger.info('[chat-run-socket] loaded session %s from DB (%d messages)', sid, messages.length) logger.info('[chat-run-socket] loaded session %s from DB (%d messages)', sid, messages.length)
return { return {
@@ -57,6 +72,7 @@ export async function loadSessionStateFromDb(sid: string, _sessionMap: Map<strin
events: [], events: [],
inputTokens, inputTokens,
outputTokens, outputTokens,
contextTokens,
queue: [], queue: [],
} }
} catch (err) { } catch (err) {
@@ -16,7 +16,6 @@ import {
contextTokensWithCachedOverhead, contextTokensWithCachedOverhead,
estimateUsageTokensFromMessages, estimateUsageTokensFromMessages,
getCachedBridgeContextOverhead, getCachedBridgeContextOverhead,
updateContextTokenUsage,
updateMessageContextTokenUsage, updateMessageContextTokenUsage,
} from './usage' } from './usage'
import { import {
@@ -123,6 +122,66 @@ function cacheBridgeContext(state: SessionState, data: Record<string, unknown> |
} }
} }
/**
 * Reports whether the bridge context cached on `state` can be reused for the
 * expected profile/model/provider.
 *
 * Returns false when nothing is cached. Otherwise a field only disqualifies
 * the cache when it actually conflicts:
 * - profile: the cached profile is set and differs from `expected.profile`;
 * - model / provider: both the expected and the cached value are set and differ.
 * Missing (falsy) values on either side of a model/provider comparison are
 * treated as compatible.
 */
function bridgeContextMatches(
  state: SessionState,
  expected: { profile: string; model?: string | null; provider?: string | null },
): boolean {
  const cachedContext = state.bridgeContext
  if (!cachedContext) return false

  const profileConflicts =
    Boolean(cachedContext.profile) && cachedContext.profile !== expected.profile
  const modelConflicts =
    Boolean(expected.model) && Boolean(cachedContext.model) && cachedContext.model !== expected.model
  const providerConflicts =
    Boolean(expected.provider) && Boolean(cachedContext.provider) && cachedContext.provider !== expected.provider

  return !(profileConflicts || modelConflicts || providerConflicts)
}
/**
 * Resolves the cached bridge context overhead for a session, asking the bridge
 * for a fresh estimate when required.
 *
 * The overhead already cached on `args.state` is only considered when
 * `bridgeContextMatches` accepts it for the requested profile/model/provider.
 * Unless `args.refresh` is set, a usable cached value is returned without
 * contacting the bridge. Otherwise `bridge.contextEstimate` is called with an
 * empty message list, the result is stored through `cacheBridgeContext`, and
 * the overhead is re-read from the cache and returned.
 *
 * A failing estimate is logged as a warning and never propagates: the
 * previously cached value (possibly undefined) is returned instead.
 */
async function ensureBridgeFixedContext(args: {
  sessionId: string
  profile: string
  model?: string | null
  provider?: string | null
  instructions: string
  state: SessionState
  bridge: AgentBridgeClient
  refresh?: boolean
}): Promise<number | undefined> {
  const { sessionId, profile, model, provider, state } = args

  // Only trust the cached overhead when it was computed for the same routing.
  let cachedOverhead: number | undefined
  if (bridgeContextMatches(state, args)) {
    cachedOverhead = getCachedBridgeContextOverhead(state)
  }
  if (cachedOverhead != null && !args.refresh) return cachedOverhead

  try {
    // An empty message list is sent; the result is cached and the overhead re-read below.
    const estimate = await args.bridge.contextEstimate(
      sessionId,
      [],
      args.instructions,
      profile,
      { model: model ?? undefined, provider: provider ?? undefined },
    )
    cacheBridgeContext(state, estimate)
    const fixedContextTokens = getCachedBridgeContextOverhead(state)
    bridgeLogger.info({
      sessionId,
      profile,
      model,
      provider,
      toolCount: estimate.tool_count,
      systemPromptChars: estimate.system_prompt_chars,
      fixedContextTokens,
    }, '[chat-run-socket] fixed context estimate')
    return fixedContextTokens
  } catch (err) {
    // Best effort: fall back to whatever was cached before the failed refresh.
    const loggedError = err instanceof Error ? { message: err.message, name: err.name } : err
    bridgeLogger.warn({
      err: loggedError,
      sessionId,
      profile,
      model,
      provider,
      cachedFixedContextTokens: cachedOverhead,
    }, '[chat-run-socket] fixed context estimate failed')
    return cachedOverhead
  }
}
export async function handleBridgeRun( export async function handleBridgeRun(
nsp: ReturnType<Server['of']>, nsp: ReturnType<Server['of']>,
socket: Socket, socket: Socket,
@@ -195,6 +254,9 @@ export async function handleBridgeRun(
const displayInput = data.display_input === undefined ? input : data.display_input const displayInput = data.display_input === undefined ? input : data.display_input
const inputStr = displayInput == null ? '' : contentBlocksToString(displayInput) const inputStr = displayInput == null ? '' : contentBlocksToString(displayInput)
const actualInputStr = contentBlocksToString(input)
const currentInputUsage = estimateUsageTokensFromMessages([{ role: 'user', content: actualInputStr }])
const currentInputTokens = currentInputUsage.inputTokens
const shouldPersistUserMessage = !skipUserMessage && displayInput !== null const shouldPersistUserMessage = !skipUserMessage && displayInput !== null
const displayRole = data.display_role === 'command' ? 'command' : 'user' const displayRole = data.display_role === 'command' ? 'command' : 'user'
let messageId: number | string | undefined let messageId: number | string | undefined
@@ -257,33 +319,32 @@ export async function handleBridgeRun(
emit, emit,
sessionMap, sessionMap,
{ model: resolvedModel, provider: resolvedProvider }, { model: resolvedModel, provider: resolvedProvider },
async (messages) => { async (_messages, localMessageTokens) => {
const cachedOverhead = getCachedBridgeContextOverhead(state) const fixedContextTokens = await ensureBridgeFixedContext({
if (cachedOverhead != null) { sessionId: session_id,
const messageUsage = estimateUsageTokensFromMessages(messages)
return cachedOverhead + messageUsage.inputTokens + messageUsage.outputTokens
}
const estimate = await bridge.contextEstimate(
session_id,
messages,
fullInstructions,
profile, profile,
{ model: resolvedModel, provider: resolvedProvider }, model: resolvedModel,
) provider: resolvedProvider,
cacheBridgeContext(state, estimate) instructions: fullInstructions,
state,
bridge,
refresh: true,
})
const contextTokens = fixedContextTokens == null
? localMessageTokens
: fixedContextTokens + localMessageTokens
bridgeLogger.info({ bridgeLogger.info({
sessionId: session_id, sessionId: session_id,
profile, profile,
model: resolvedModel, model: resolvedModel,
provider: resolvedProvider, provider: resolvedProvider,
messages: estimate.message_count, fixedContextTokens,
toolCount: estimate.tool_count, messageTokens: localMessageTokens,
systemPromptChars: estimate.system_prompt_chars, contextTokens,
fixedContextTokens: estimate.fixed_context_tokens, }, '[chat-run-socket] local context estimate')
fullContextTokens: estimate.token_count, return contextTokens
}, '[chat-run-socket] full context estimate')
return estimate.token_count
}, },
currentInputTokens,
) )
const bridgeHistory = history const bridgeHistory = history
@@ -349,6 +410,8 @@ export async function handleBridgeRun(
dequeueNextQueuedRun, dequeueNextQueuedRun,
fullInstructions, fullInstructions,
{ model: resolvedModel, provider: resolvedProvider }, { model: resolvedModel, provider: resolvedProvider },
currentInputTokens,
shouldPersistUserMessage && displayRole === 'user',
data.model_groups, data.model_groups,
) )
if (chunk.done) break if (chunk.done) break
@@ -417,61 +480,68 @@ async function refreshFinalContextUsage(args: {
) )
const finalMessageUsage = estimateUsageTokensFromMessages(finalHistory) const finalMessageUsage = estimateUsageTokensFromMessages(finalHistory)
const finalMessageTokens = finalMessageUsage.inputTokens + finalMessageUsage.outputTokens const finalMessageTokens = finalMessageUsage.inputTokens + finalMessageUsage.outputTokens
if (getCachedBridgeContextOverhead(args.state) != null) { await ensureBridgeFixedContext({
const contextTokens = updateMessageContextTokenUsage( sessionId: args.sessionId,
args.sessionId, profile: args.profile,
args.state, model: args.model,
args.emit, provider: args.provider,
finalMessageTokens, instructions: args.instructions,
args.usage, state: args.state,
) bridge: args.bridge,
bridgeLogger.info({ })
sessionId: args.sessionId, const contextTokens = updateMessageContextTokenUsage(
profile: args.profile,
model: args.model,
provider: args.provider,
messages: finalHistory.length,
fixedContextTokens: args.state.bridgeContext?.fixedContextTokens,
messageTokens: finalMessageTokens,
fullContextTokens: contextTokens,
}, '[chat-run-socket] final cached context estimate')
return contextTokens
}
const estimate = await args.bridge.contextEstimate(
args.sessionId, args.sessionId,
finalHistory, args.state,
args.instructions, args.emit,
args.profile, finalMessageTokens,
{ model: args.model ?? undefined, provider: args.provider ?? undefined }, args.usage,
) )
cacheBridgeContext(args.state, estimate)
const contextTokens = typeof estimate.token_count === 'number' && Number.isFinite(estimate.token_count) && estimate.token_count > 0
? Math.floor(estimate.token_count)
: undefined
if (contextTokens == null) return args.state.contextTokens
updateContextTokenUsage(args.sessionId, args.state, args.emit, contextTokens, args.usage)
bridgeLogger.info({ bridgeLogger.info({
sessionId: args.sessionId, sessionId: args.sessionId,
profile: args.profile, profile: args.profile,
model: args.model, model: args.model,
provider: args.provider, provider: args.provider,
messages: estimate.message_count, messages: finalHistory.length,
toolCount: estimate.tool_count, fixedContextTokens: args.state.bridgeContext?.fixedContextTokens,
systemPromptChars: estimate.system_prompt_chars, messageTokens: finalMessageTokens,
fullContextTokens: contextTokens, contextTokens,
}, '[chat-run-socket] final full context estimate') }, '[chat-run-socket] final local context estimate')
return contextTokens return contextTokens
} catch (err) { } catch (err) {
bridgeLogger.warn({ bridgeLogger.warn({
err: err instanceof Error ? { message: err.message, name: err.name } : err, err: err instanceof Error ? { message: err.message, name: err.name } : err,
sessionId: args.sessionId, sessionId: args.sessionId,
profile: args.profile, profile: args.profile,
}, '[chat-run-socket] final full context estimate failed') }, '[chat-run-socket] final local context estimate failed')
return args.state.contextTokens return args.state.contextTokens
} }
} }
/**
 * Estimates the message-token total for a session from the history returned by
 * `buildSnapshotAwareHistory` instead of the raw DB history.
 *
 * The full DB history (last user message included) is loaded, handed to
 * `buildSnapshotAwareHistory` together with the model/provider, and the result is
 * measured with `estimateUsageTokensFromMessages`.
 *
 * When `currentInputIncludedInDb` is falsy, `currentInputTokens` (normalised by
 * `finiteToken`, falling back to 0) is added on top — presumably so an input that
 * has not been persisted yet is still counted; confirm against the callers.
 *
 * @returns `messageTokens`: estimated input + output tokens plus any extra input
 *   tokens; `messages`: length of the snapshot-aware history.
 */
async function estimateSnapshotAwareMessageTokens(args: {
  sessionId: string
  profile: string
  model?: string | null
  provider?: string | null
  currentInputTokens?: number
  currentInputIncludedInDb?: boolean
}): Promise<{ messageTokens: number; messages: number }> {
  const fullHistory = await buildDbHistory(args.sessionId, { excludeLastUser: false })
  const modelContext = { model: args.model, provider: args.provider }
  const effectiveHistory = await buildSnapshotAwareHistory(
    args.sessionId,
    args.profile,
    fullHistory,
    modelContext,
  )
  const estimated = estimateUsageTokensFromMessages(effectiveHistory)

  // Only add the caller-supplied input estimate when the DB history does not
  // already contain the current input.
  let pendingInputTokens = 0
  if (!args.currentInputIncludedInDb) {
    pendingInputTokens = finiteToken(args.currentInputTokens) ?? 0
  }

  const messageTokens = estimated.inputTokens + estimated.outputTokens + pendingInputTokens
  return { messageTokens, messages: effectiveHistory.length }
}
async function applyBridgeChunkAsync( async function applyBridgeChunkAsync(
nsp: ReturnType<Server['of']>, nsp: ReturnType<Server['of']>,
socket: Socket, socket: Socket,
@@ -486,6 +556,8 @@ async function applyBridgeChunkAsync(
dequeueNextQueuedRun: (socket: Socket, sessionId: string, fallbackProfile?: string) => void, dequeueNextQueuedRun: (socket: Socket, sessionId: string, fallbackProfile?: string) => void,
instructions: string, instructions: string,
modelContext: { model?: string | null; provider?: string | null }, modelContext: { model?: string | null; provider?: string | null },
currentInputTokens = 0,
currentInputIncludedInDb = true,
modelGroups?: RunModelGroup[], modelGroups?: RunModelGroup[],
): Promise<void> { ): Promise<void> {
if (state.activeRunMarker !== runMarker) { if (state.activeRunMarker !== runMarker) {
@@ -505,11 +577,19 @@ async function applyBridgeChunkAsync(
if (evType === 'bridge.context.ready') { if (evType === 'bridge.context.ready') {
cacheBridgeContext(state, ev) cacheBridgeContext(state, ev)
const usage = await calcAndUpdateUsage(sessionId, state, emit) const usage = await calcAndUpdateUsage(sessionId, state, emit)
const snapshotAware = await estimateSnapshotAwareMessageTokens({
sessionId,
profile,
model: modelContext.model,
provider: modelContext.provider,
currentInputTokens,
currentInputIncludedInDb,
})
updateMessageContextTokenUsage( updateMessageContextTokenUsage(
sessionId, sessionId,
state, state,
emit, emit,
usage.inputTokens + usage.outputTokens, snapshotAware.messageTokens,
usage, usage,
) )
} else if (evType === 'tool.started') { } else if (evType === 'tool.started') {
@@ -646,17 +726,22 @@ async function applyBridgeChunkAsync(
const bridgeHistory = await buildDbHistory(sessionId, { excludeLastUser: true }) const bridgeHistory = await buildDbHistory(sessionId, { excludeLastUser: true })
const bridgeUsage = estimateUsageTokensFromMessages(bridgeHistory) const bridgeUsage = estimateUsageTokensFromMessages(bridgeHistory)
const messageOnlyTokens = bridgeUsage.inputTokens + bridgeUsage.outputTokens const messageOnlyTokens = bridgeUsage.inputTokens + bridgeUsage.outputTokens
const tokenCount = typeof ev.approx_tokens === 'number' && Number.isFinite(ev.approx_tokens) && ev.approx_tokens > 0 const runInputTokens = typeof currentInputTokens === 'number' && Number.isFinite(currentInputTokens) && currentInputTokens > 0
? ev.approx_tokens ? Math.floor(currentInputTokens)
: messageOnlyTokens : 0
const runMessageTokens = messageOnlyTokens + runInputTokens
const tokenCount = contextTokensWithCachedOverhead(state, runMessageTokens)
bridgeLogger.info({ bridgeLogger.info({
sessionId, sessionId,
profile, profile,
bridgeMessages: ev.message_count, bridgeMessages: ev.message_count,
dbMessages: bridgeHistory.length, dbMessages: bridgeHistory.length,
messageOnlyTokens, messageOnlyTokens,
fullContextTokens: tokenCount, currentInputTokens: runInputTokens,
source: typeof ev.approx_tokens === 'number' ? 'bridge' : 'message-only-fallback', fixedContextTokens: state.bridgeContext?.fixedContextTokens,
contextTokens: tokenCount,
bridgeApproxTokens: ev.approx_tokens,
source: 'local',
}, '[chat-run-socket] bridge compression token estimate') }, '[chat-run-socket] bridge compression token estimate')
const payload = { const payload = {
event: 'compression.started', event: 'compression.started',
@@ -674,7 +759,7 @@ async function applyBridgeChunkAsync(
sessionId, sessionId,
profile, profile,
ev.messages as ChatMessage[], ev.messages as ChatMessage[],
typeof ev.approx_tokens === 'number' ? ev.approx_tokens : undefined, tokenCount,
) )
state.bridgeCompressionResults = state.bridgeCompressionResults || {} state.bridgeCompressionResults = state.bridgeCompressionResults || {}
state.bridgeCompressionResults[String(ev.request_id)] = compressed state.bridgeCompressionResults[String(ev.request_id)] = compressed
@@ -689,11 +774,16 @@ async function applyBridgeChunkAsync(
const compressionResult = ev.request_id const compressionResult = ev.request_id
? state.bridgeCompressionResults?.[String(ev.request_id)] ? state.bridgeCompressionResults?.[String(ev.request_id)]
: undefined : undefined
const bridgeAfterContextTokens = finiteToken(ev.result_approx_tokens)
const messageAfterTokens = finiteToken(compressionResult?.afterTokens) const messageAfterTokens = finiteToken(compressionResult?.afterTokens)
const afterContextTokens = messageAfterTokens != null && getCachedBridgeContextOverhead(state) != null const runInputTokens = typeof currentInputTokens === 'number' && Number.isFinite(currentInputTokens) && currentInputTokens > 0
? contextTokensWithCachedOverhead(state, messageAfterTokens) ? Math.floor(currentInputTokens)
: bridgeAfterContextTokens ?? messageAfterTokens : 0
const messageAfterTokensWithInput = messageAfterTokens != null
? messageAfterTokens + runInputTokens
: undefined
const afterContextTokens = messageAfterTokensWithInput != null
? contextTokensWithCachedOverhead(state, messageAfterTokensWithInput)
: undefined
const payload = { const payload = {
event: 'compression.completed', event: 'compression.completed',
run_id: chunk.run_id, run_id: chunk.run_id,
@@ -703,7 +793,7 @@ async function applyBridgeChunkAsync(
totalMessages: compressionResult?.beforeMessages ?? ev.message_count, totalMessages: compressionResult?.beforeMessages ?? ev.message_count,
resultMessages: compressionResult?.resultMessages ?? ev.result_messages, resultMessages: compressionResult?.resultMessages ?? ev.result_messages,
beforeTokens: compressionResult?.beforeTokens ?? ev.approx_tokens, beforeTokens: compressionResult?.beforeTokens ?? ev.approx_tokens,
afterTokens: messageAfterTokens ?? bridgeAfterContextTokens, afterTokens: messageAfterTokensWithInput,
contextTokens: afterContextTokens, contextTokens: afterContextTokens,
summaryTokens: compressionResult?.summaryTokens, summaryTokens: compressionResult?.summaryTokens,
verbatimCount: compressionResult?.verbatimCount, verbatimCount: compressionResult?.verbatimCount,
@@ -716,10 +806,8 @@ async function applyBridgeChunkAsync(
replaceState(sessionMap, sessionId, 'compression.completed', payload) replaceState(sessionMap, sessionId, 'compression.completed', payload)
emit('compression.completed', payload) emit('compression.completed', payload)
const usage = await calcAndUpdateUsage(sessionId, state, emit) const usage = await calcAndUpdateUsage(sessionId, state, emit)
if (messageAfterTokens != null && getCachedBridgeContextOverhead(state) != null) { if (messageAfterTokensWithInput != null) {
updateMessageContextTokenUsage(sessionId, state, emit, messageAfterTokens, usage) updateMessageContextTokenUsage(sessionId, state, emit, messageAfterTokensWithInput, usage)
} else {
updateContextTokenUsage(sessionId, state, emit, afterContextTokens, usage)
} }
} else if (evType === 'bridge.compression.failed') { } else if (evType === 'bridge.compression.failed') {
const payload = { const payload = {
@@ -0,0 +1,96 @@
// @vitest-environment jsdom
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { createPinia, setActivePinia } from 'pinia'
// Mock functions shared between the vi.mock factory below and the test bodies.
// They are created through vi.hoisted so they already exist when the hoisted
// vi.mock factories run.
const chatApi = vi.hoisted(() => ({
resumeSession: vi.fn(),
registerSessionHandlers: vi.fn(),
unregisterSessionHandlers: vi.fn(),
}))
// Chat socket API: the three shared mocks above are wired in so tests can
// inspect them; every other export is a throwaway stub.
vi.mock('@/api/hermes/chat', () => ({
startRunViaSocket: vi.fn(),
resumeSession: chatApi.resumeSession,
registerSessionHandlers: chatApi.registerSessionHandlers,
unregisterSessionHandlers: chatApi.unregisterSessionHandlers,
getChatRunSocket: vi.fn(() => ({ emit: vi.fn() })),
respondToolApproval: vi.fn(),
respondClarify: vi.fn(),
onPeerUserMessage: vi.fn(() => vi.fn()),
onSessionCommand: vi.fn(() => vi.fn()),
}))
// The active profile is pinned to 'default' for every test.
vi.mock('@/api/client', () => ({
getActiveProfileName: () => 'default',
}))
// Session REST helpers are replaced by bare stubs.
vi.mock('@/api/hermes/sessions', () => ({
deleteSession: vi.fn(),
fetchSession: vi.fn(),
fetchSessions: vi.fn(),
setSessionModel: vi.fn(),
}))
// Download URLs are derived from the file name only; the path argument is ignored.
vi.mock('@/api/hermes/download', () => ({
getDownloadUrl: (_path: string, name: string) => `/download/${name}`,
}))
// Completion sound helpers are stubbed out.
vi.mock('@/utils/completion-sound', () => ({
primeCompletionSound: vi.fn(),
playCompletionSound: vi.fn(),
}))
import { useChatStore, type Session } from '@/stores/hermes/chat'
/**
 * Builds a minimal session fixture: no messages, the title mirrors the id, and
 * both timestamps are taken from the current clock.
 */
function makeSession(id: string): Session {
  const createdAt = Date.now()
  const updatedAt = Date.now()
  return { id, title: id, messages: [], createdAt, updatedAt }
}
// Verifies that compression progress reported for one session is not surfaced
// through `compressionState` while a different session is active.
describe('chat store compression state', () => {
  beforeEach(() => {
    vi.resetAllMocks()
    setActivePinia(createPinia())
    // Resuming answers synchronously with an empty transcript; only
    // 'session-1' is reported as still working.
    chatApi.resumeSession.mockImplementation((sessionId: string, onResumed: (data: any) => void) => {
      const resumed = {
        session_id: sessionId,
        messages: [],
        isWorking: sessionId === 'session-1',
        events: [],
      }
      onResumed(resumed)
      return {} as any
    })
  })

  it('does not show a background session compression indicator in the active session', async () => {
    const chatStore = useChatStore()
    chatStore.sessions = [makeSession('session-1'), makeSession('session-2')]

    // Activate session-1 first so its socket handlers get registered.
    await chatStore.switchSession('session-1')
    const registration = chatApi.registerSessionHandlers.mock.calls.find(
      ([registeredId]) => registeredId === 'session-1',
    )
    const backgroundHandlers = registration?.[1]
    expect(backgroundHandlers).toBeTruthy()

    // Move away, then deliver a compression event for the now-background session.
    await chatStore.switchSession('session-2')
    const startedEvent = {
      event: 'compression.started',
      session_id: 'session-1',
      message_count: 6,
      token_count: 1234,
    }
    backgroundHandlers.onCompressionStarted(startedEvent)

    expect(chatStore.activeSessionId).toBe('session-2')
    expect(chatStore.compressionState).toBeNull()

    // Returning to session-1 must expose the state recorded while it was in the background.
    await chatStore.switchSession('session-1')
    const expectedState = expect.objectContaining({
      compressing: true,
      messageCount: 6,
      beforeTokens: 1234,
    })
    expect(chatStore.compressionState).toEqual(expectedState)
  })
})
@@ -292,9 +292,11 @@ except RuntimeError as exc:
assert "already running" in str(exc) assert "already running" in str(exc)
class FakeWorker: class FakeWorker:
def __init__(self, destroyed): def __init__(self, destroyed, profile="default", key="default"):
self.running = True self.running = True
self.destroyed = destroyed self.destroyed = destroyed
self.profile = profile
self.key = key
self.requests = [] self.requests = []
self.stopped = False self.stopped = False
@@ -310,28 +312,41 @@ broker = bridge.BridgeBroker("ipc:///tmp/unused.sock")
profile_worker = FakeWorker(2) profile_worker = FakeWorker(2)
broker._workers["default"] = profile_worker broker._workers["default"] = profile_worker
broker._run_profile["run-session-a"] = "default" broker._run_profile["run-session-a"] = "default"
broker._run_worker_key["run-session-a"] = "default"
broker._running_run_profile["run-session-a"] = "default" broker._running_run_profile["run-session-a"] = "default"
broker._running_run_worker_key["run-session-a"] = "default"
broker._session_profile["session-a"] = "default" broker._session_profile["session-a"] = "default"
broker._session_worker_key["session-a"] = "default"
broker._approval_profile["approval-a"] = "default" broker._approval_profile["approval-a"] = "default"
broker._approval_worker_key["approval-a"] = "default"
broker._compression_profile["compression-a"] = "default" broker._compression_profile["compression-a"] = "default"
broker._compression_worker_key["compression-a"] = "default"
destroy_profile_result = broker.handle({"action": "destroy_profile", "profile": "default"}) destroy_profile_result = broker.handle({"action": "destroy_profile", "profile": "default"})
assert destroy_profile_result == {"profile": "default", "destroyed": 2} assert destroy_profile_result == {"profile": "default", "destroyed": 2}
assert profile_worker.stopped assert profile_worker.stopped
assert "default" not in broker._workers assert "default" not in broker._workers
assert broker._run_profile == {} assert broker._run_profile == {}
assert broker._run_worker_key == {}
assert broker._running_run_profile == {} assert broker._running_run_profile == {}
assert broker._running_run_worker_key == {}
assert broker._session_profile == {} assert broker._session_profile == {}
assert broker._session_worker_key == {}
assert broker._approval_profile == {} assert broker._approval_profile == {}
assert broker._approval_worker_key == {}
assert broker._compression_profile == {} assert broker._compression_profile == {}
assert broker._compression_worker_key == {}
worker_a = FakeWorker(1) worker_a = FakeWorker(1, "default", "a")
worker_b = FakeWorker(3) worker_b = FakeWorker(3, "work", "b")
broker._workers["a"] = worker_a broker._workers["a"] = worker_a
broker._workers["b"] = worker_b broker._workers["b"] = worker_b
broker._run_profile["run-a"] = "a" broker._run_profile["run-a"] = "default"
broker._running_run_profile["run-a"] = "a" broker._run_worker_key["run-a"] = "a"
broker._session_profile["session-b"] = "b" broker._running_run_profile["run-a"] = "default"
broker._running_run_worker_key["run-a"] = "a"
broker._session_profile["session-b"] = "work"
broker._session_worker_key["session-b"] = "b"
destroy_all_result = broker.handle({"action": "destroy_all"}) destroy_all_result = broker.handle({"action": "destroy_all"})
assert destroy_all_result == {"destroyed": 4} assert destroy_all_result == {"destroyed": 4}
@@ -339,8 +354,11 @@ assert worker_a.stopped
assert worker_b.stopped assert worker_b.stopped
assert broker._workers == {} assert broker._workers == {}
assert broker._run_profile == {} assert broker._run_profile == {}
assert broker._run_worker_key == {}
assert broker._running_run_profile == {} assert broker._running_run_profile == {}
assert broker._running_run_worker_key == {}
assert broker._session_profile == {} assert broker._session_profile == {}
assert broker._session_worker_key == {}
`) `)
}) })
@@ -372,6 +390,69 @@ assert resp["running_sessions_by_profile"] == {"default": 1}
`) `)
}) })
// Runs an embedded Python script against the bridge harness. The script registers
// a fake worker under a compression-specific worker key, sends chat / get_output /
// destroy requests through the broker, and asserts that `worker_key` is not
// forwarded to the worker and that destroying the temporary session neither stops
// the worker nor leaves session routing entries behind.
it('routes worker-keyed broker requests without stopping the worker on session destroy', () => {
runPython(String.raw`
${harness}
class RoutedWorker:
running = True
pid = 12345
endpoint = "ipc:///tmp/worker.sock"
last_used_at = 12.5
def __init__(self, profile, key):
self.profile = profile
self.key = key
self.requests = []
self.stopped = False
def request(self, req):
self.requests.append(req)
action = req.get("action")
if action == "chat":
return {"ok": True, "run_id": "run-compress", "session_id": req["session_id"], "status": "running"}
if action == "get_output":
return {"ok": True, "run_id": req["run_id"], "session_id": "compress-temp", "status": "complete", "done": True}
if action == "destroy":
return {"ok": True, "session_id": req["session_id"], "destroyed": True}
raise AssertionError(f"unexpected action: {action}")
def stop(self):
self.stopped = True
broker = bridge.BridgeBroker("ipc:///tmp/unused.sock")
worker = RoutedWorker("default", "default:compression:session-a")
broker._workers[worker.key] = worker
chat_resp = broker.handle({
"action": "chat",
"session_id": "compress-temp",
"profile": "default",
"worker_key": worker.key,
"message": "summarize",
})
assert chat_resp["run_id"] == "run-compress"
assert worker.requests[-1]["profile"] == "default"
assert "worker_key" not in worker.requests[-1]
broker.handle({"action": "get_output", "run_id": "run-compress"})
assert worker.requests[-1]["action"] == "get_output"
destroy_resp = broker.handle({
"action": "destroy",
"session_id": "compress-temp",
"profile": "default",
"worker_key": worker.key,
})
assert destroy_resp["destroyed"] is True
assert worker.requests[-1]["action"] == "destroy"
assert not worker.stopped
assert worker.key in broker._workers
assert "compress-temp" not in broker._session_profile
assert "compress-temp" not in broker._session_worker_key
`)
})
it('restores approval env and clears handlers when a run fails', () => { it('restores approval env and clears handlers when a run fails', () => {
runPython(String.raw` runPython(String.raw`
${harness} ${harness}
@@ -480,7 +561,7 @@ original_getpid = bridge.os.getpid
try: try:
bridge.subprocess.Popen = fake_popen bridge.subprocess.Popen = fake_popen
bridge.os.getpid = lambda: 4242 bridge.os.getpid = lambda: 4242
proc_worker = bridge.WorkerProcess("default", "ipc:///tmp/worker.sock", "/agent", "/home") proc_worker = bridge.WorkerProcess("default:compression:session-a", "default", "ipc:///tmp/worker.sock", "/agent", "/home")
proc_worker._pipe_stderr = lambda: None proc_worker._pipe_stderr = lambda: None
proc_worker._wait_ready = lambda: None proc_worker._wait_ready = lambda: None
proc_worker.start() proc_worker.start()
+36
View File
@@ -153,6 +153,42 @@ describe('ChatContextCompressor', () => {
expect(saveCompressionSnapshotMock).toHaveBeenCalledWith('s1', 'compressed summary', 6, 10) expect(saveCompressionSnapshotMock).toHaveBeenCalledWith('s1', 'compressed summary', 6, 10)
}) })
// Checks that the `workerKey` option given to `compress` is forwarded as
// `worker_key` on the summarization chat request, and that cleanup destroys the
// temporary `compress_*` session using the same profile and worker key.
it('routes summarization through the provided worker key and destroys only the temporary agent session', async () => {
const { ChatContextCompressor } = await import('../../packages/server/src/lib/context-compressor')
const compressor = new ChatContextCompressor({
config: { headMessageCount: 0, tailMessageCount: 1, summaryBudget: 1000 },
})
const messages = [
{ role: 'user', content: 'old context' },
{ role: 'assistant', content: 'old response' },
{ role: 'user', content: 'tail' },
]
// No stored snapshot for this session.
getCompressionSnapshotMock.mockReturnValue(null)
bridgeRequestMock.mockResolvedValue({
status: 'completed',
result: { final_response: 'compressed summary' },
})
await compressor.compress(messages, 'http://upstream', undefined, 's1', {
profile: 'default',
workerKey: 'default:compression:s1',
})
expect(bridgeRequestMock).toHaveBeenCalledWith(expect.objectContaining({
action: 'chat',
profile: 'default',
worker_key: 'default:compression:s1',
wait: true,
}), expect.any(Object))
// The session id used for summarization is generated by the compressor, so it
// is read back from the first bridge request.
const compressSessionId = bridgeRequestMock.mock.calls[0][0].session_id
expect(String(compressSessionId)).toMatch(/^compress_/)
expect(bridgeDestroyMock).toHaveBeenCalledWith(
compressSessionId,
'default',
'default:compression:s1',
)
})
it('does not pre-prune tool results before sending them to the summarizer', async () => { it('does not pre-prune tool results before sending them to the summarizer', async () => {
const { ChatContextCompressor } = await import('../../packages/server/src/lib/context-compressor') const { ChatContextCompressor } = await import('../../packages/server/src/lib/context-compressor')
const compressor = new ChatContextCompressor({ const compressor = new ChatContextCompressor({
@@ -127,8 +127,18 @@ describe('bridge run final context usage', () => {
buildSnapshotAwareHistoryMock.mockImplementation(async (_sessionId: string, _profile: string, history: any[]) => history) buildSnapshotAwareHistoryMock.mockImplementation(async (_sessionId: string, _profile: string, history: any[]) => history)
calcAndUpdateUsageMock.mockResolvedValue({ inputTokens: 11, outputTokens: 7 }) calcAndUpdateUsageMock.mockResolvedValue({ inputTokens: 11, outputTokens: 7 })
estimateUsageTokensFromMessagesMock.mockReturnValue({ inputTokens: 11, outputTokens: 7 }) estimateUsageTokensFromMessagesMock.mockReturnValue({ inputTokens: 11, outputTokens: 7 })
getCachedBridgeContextOverheadMock.mockReturnValue(undefined) getCachedBridgeContextOverheadMock.mockImplementation((state: any) => {
contextTokensWithCachedOverheadMock.mockImplementation((_state: any, messageTokens: number) => messageTokens) const fixed = state?.bridgeContext?.fixedContextTokens
return typeof fixed === 'number' ? fixed : undefined
})
contextTokensWithCachedOverheadMock.mockImplementation((state: any, messageTokens: number) => {
const fixed = state?.bridgeContext?.fixedContextTokens
return typeof fixed === 'number' ? fixed + messageTokens : messageTokens
})
updateMessageContextTokenUsageMock.mockImplementation((sid: string, state: any, emit: any, messageTokens: number, usage?: { inputTokens: number; outputTokens: number }) => {
const contextTokens = contextTokensWithCachedOverheadMock(state, messageTokens)
return updateContextTokenUsageMock(sid, state, emit, contextTokens, usage)
})
}) })
it('refreshes full context tokens when a bridge run completes', async () => { it('refreshes full context tokens when a bridge run completes', async () => {
@@ -141,6 +151,7 @@ describe('bridge run final context usage', () => {
chat: vi.fn().mockResolvedValue({ run_id: 'run-1', status: 'started' }), chat: vi.fn().mockResolvedValue({ run_id: 'run-1', status: 'started' }),
contextEstimate: vi.fn().mockResolvedValue({ contextEstimate: vi.fn().mockResolvedValue({
token_count: 12345, token_count: 12345,
fixed_context_tokens: 12327,
message_count: 2, message_count: 2,
tool_count: 4, tool_count: 4,
system_prompt_chars: 13, system_prompt_chars: 13,
@@ -165,10 +176,7 @@ describe('bridge run final context usage', () => {
expect(bridge.contextEstimate).toHaveBeenCalledWith( expect(bridge.contextEstimate).toHaveBeenCalledWith(
'session-1', 'session-1',
[ [],
{ role: 'user', content: 'hello' },
{ role: 'assistant', content: 'done' },
],
expect.stringContaining('[Current Hermes profile: default]'), expect.stringContaining('[Current Hermes profile: default]'),
'default', 'default',
{ model: 'gpt-test', provider: 'openai' }, { model: 'gpt-test', provider: 'openai' },
@@ -326,14 +334,22 @@ describe('bridge run final context usage', () => {
const nsp = makeNamespace(emit) const nsp = makeNamespace(emit)
const socket = makeSocket() const socket = makeSocket()
const state = makeState() const state = makeState()
state.bridgeContext = { fixedContextTokens: 20_000 }
const sessionMap = new Map([['session-1', state]]) const sessionMap = new Map([['session-1', state]])
getCachedBridgeContextOverheadMock.mockReturnValue(20_000)
updateMessageContextTokenUsageMock.mockImplementation((sid: string, targetState: any, targetEmit: any, messageTokens: number, usage?: { inputTokens: number; outputTokens: number }) => updateContextTokenUsageMock(sid, targetState, targetEmit, 20_000 + messageTokens, usage))
const bridge = { const bridge = {
chat: vi.fn().mockResolvedValue({ run_id: 'run-1', status: 'started' }), chat: vi.fn().mockResolvedValue({ run_id: 'run-1', status: 'started' }),
contextEstimate: vi.fn(), contextEstimate: vi.fn(),
streamOutput: vi.fn(async function* () { streamOutput: vi.fn(async function* () {
yield {
run_id: 'run-1',
done: false,
status: 'running',
events: [{
event: 'bridge.context.ready',
fixed_context_tokens: 20_000,
system_prompt_tokens: 3_000,
tool_tokens: 17_000,
}],
}
yield { run_id: 'run-1', done: true, status: 'completed', output: 'done' } yield { run_id: 'run-1', done: true, status: 'completed', output: 'done' }
}), }),
} as any } as any
@@ -365,6 +381,80 @@ describe('bridge run final context usage', () => {
})) }))
}) })
// Regression test: when `bridge.context.ready` arrives, the message-token figure
// passed to `updateMessageContextTokenUsage` must come from the snapshot-aware
// history (9_000 here) and not from the raw usage totals (28_000 here).
it('keeps bridge context ready updates on the snapshot-aware token baseline', async () => {
const emit = vi.fn()
const nsp = makeNamespace(emit)
const socket = makeSocket()
const state = makeState()
const sessionMap = new Map([['session-1', state]])
// Raw usage reports the large, uncompressed total.
calcAndUpdateUsageMock.mockResolvedValue({ inputTokens: 28_000, outputTokens: 0 })
buildDbHistoryMock.mockResolvedValue([
{ role: 'user', content: 'very large old context' },
{ role: 'assistant', content: 'large old response' },
{ role: 'user', content: 'hello' },
])
// The snapshot-aware history replaces the old turns with a summary message.
buildSnapshotAwareHistoryMock.mockResolvedValue([
{ role: 'user', content: '[Previous context summary]\n\nsmall summary' },
{ role: 'user', content: 'hello' },
])
// Token estimate depends on which history is measured: 9_000 for the
// summarised one, 28_000 otherwise.
estimateUsageTokensFromMessagesMock.mockImplementation((messages: any[]) => {
if (messages?.[0]?.content?.includes('small summary')) {
return { inputTokens: 9_000, outputTokens: 0 }
}
return { inputTokens: 28_000, outputTokens: 0 }
})
const bridge = {
chat: vi.fn().mockResolvedValue({ run_id: 'run-1', status: 'started' }),
contextEstimate: vi.fn(),
// First chunk carries the context-ready event, second chunk completes the run.
streamOutput: vi.fn(async function* () {
yield {
run_id: 'run-1',
done: false,
status: 'running',
events: [{
event: 'bridge.context.ready',
fixed_context_tokens: 10_000,
system_prompt_tokens: 2_000,
tool_tokens: 8_000,
}],
}
yield { run_id: 'run-1', done: true, status: 'completed', output: 'done' }
}),
} as any
const { handleBridgeRun } = await import('../../packages/server/src/services/hermes/run-chat/handle-bridge-run')
await handleBridgeRun(
nsp,
socket,
{ input: 'hello', session_id: 'session-1' },
'default',
sessionMap,
bridge,
false,
vi.fn(),
vi.fn(),
)
expect(updateMessageContextTokenUsageMock).toHaveBeenCalledWith(
'session-1',
state,
expect.any(Function),
9_000,
{ inputTokens: 28_000, outputTokens: 0 },
)
expect(updateMessageContextTokenUsageMock).not.toHaveBeenCalledWith(
'session-1',
state,
expect.any(Function),
28_000,
{ inputTokens: 28_000, outputTokens: 0 },
)
// 19_000 = 10_000 fixed context tokens from the event + 9_000 message tokens.
expect(state.contextTokens).toBe(19_000)
expect(emit).toHaveBeenCalledWith('run.completed', expect.objectContaining({
contextTokens: 19_000,
}))
})
it('persists pending tool marker text before a bridge run completes', async () => { it('persists pending tool marker text before a bridge run completes', async () => {
const emit = vi.fn() const emit = vi.fn()
const nsp = makeNamespace(emit) const nsp = makeNamespace(emit)
@@ -502,6 +592,7 @@ describe('bridge run final context usage', () => {
chat: vi.fn().mockRejectedValue(new Error('bridge timeout')), chat: vi.fn().mockRejectedValue(new Error('bridge timeout')),
contextEstimate: vi.fn().mockResolvedValue({ contextEstimate: vi.fn().mockResolvedValue({
token_count: 54321, token_count: 54321,
fixed_context_tokens: 54303,
message_count: 1, message_count: 1,
tool_count: 4, tool_count: 4,
system_prompt_chars: 13, system_prompt_chars: 13,
+110 -5
View File
@@ -175,7 +175,7 @@ describe('run chat compression trigger', () => {
) )
}) })
it('uses full context estimates for compression threshold decisions', async () => { it('uses local context estimates for compression threshold decisions', async () => {
const messages = Array.from({ length: 10 }, (_, index) => ({ const messages = Array.from({ length: 10 }, (_, index) => ({
id: index + 1, id: index + 1,
session_id: 'session-1', session_id: 'session-1',
@@ -191,7 +191,7 @@ describe('run chat compression trigger', () => {
getSessionDetailMock.mockReturnValue({ messages }) getSessionDetailMock.mockReturnValue({ messages })
calcAndUpdateUsageMock.mockResolvedValue({ inputTokens: 1_000, outputTokens: 0 }) calcAndUpdateUsageMock.mockResolvedValue({ inputTokens: 1_000, outputTokens: 0 })
compressorCompressMock.mockResolvedValue({ compressorCompressMock.mockResolvedValue({
messages: [{ role: 'user', content: 'compressed by full context estimate' }], messages: [{ role: 'user', content: 'compressed by local context estimate' }],
meta: { meta: {
compressed: true, compressed: true,
llmCompressed: true, llmCompressed: true,
@@ -215,7 +215,7 @@ describe('run chat compression trigger', () => {
vi.fn(async () => 120_000), vi.fn(async () => 120_000),
) )
expect(history).toEqual([{ role: 'user', content: 'compressed by full context estimate' }]) expect(history).toEqual([{ role: 'user', content: 'compressed by local context estimate' }])
expect(compressorCompressMock).toHaveBeenCalledTimes(1) expect(compressorCompressMock).toHaveBeenCalledTimes(1)
expect(updateMessageContextTokenUsageMock).toHaveBeenCalledWith( expect(updateMessageContextTokenUsageMock).toHaveBeenCalledWith(
'session-1', 'session-1',
@@ -226,7 +226,7 @@ describe('run chat compression trigger', () => {
) )
}) })
it('emits full context token usage when the full estimate is under threshold', async () => { it('emits local context token usage when the local estimate is under threshold', async () => {
const messages = Array.from({ length: 10 }, (_, index) => ({ const messages = Array.from({ length: 10 }, (_, index) => ({
id: index + 1, id: index + 1,
session_id: 'session-1', session_id: 'session-1',
@@ -257,7 +257,10 @@ describe('run chat compression trigger', () => {
) )
expect(history).toHaveLength(9) expect(history).toHaveLength(9)
expect(contextTokenEstimator).toHaveBeenCalledWith(expect.arrayContaining([{ role: 'user', content: 'message 0' }])) expect(contextTokenEstimator).toHaveBeenCalledWith(
expect.arrayContaining([{ role: 'user', content: 'message 0' }]),
1_900,
)
expect(emit).toHaveBeenCalledWith('usage.updated', expect.objectContaining({ expect(emit).toHaveBeenCalledWith('usage.updated', expect.objectContaining({
event: 'usage.updated', event: 'usage.updated',
session_id: 'session-1', session_id: 'session-1',
@@ -268,6 +271,108 @@ describe('run chat compression trigger', () => {
expect(compressorCompressMock).not.toHaveBeenCalled() expect(compressorCompressMock).not.toHaveBeenCalled()
}) })
// Checks that the trailing numeric argument of `buildCompressedHistory` (700,
// presumably the current input token count — confirm against its signature) is
// added to the message-token estimate handed to the context estimator.
it('includes current input tokens when estimating snapshot-aware context', async () => {
const messages = Array.from({ length: 10 }, (_, index) => ({
id: index + 1,
session_id: 'session-1',
role: index === 9 ? 'user' : index % 2 === 0 ? 'user' : 'assistant',
content: `message ${index}`,
timestamp: index + 1,
tool_call_id: null,
tool_calls: null,
tool_name: null,
finish_reason: null,
reasoning_content: null,
}))
getSessionDetailMock.mockReturnValue({ messages })
// A snapshot already exists for this session.
getCompressionSnapshotMock.mockReturnValue({
summary: 'previous summary',
lastMessageIndex: 4,
messageCountAtTime: 5,
})
calcAndUpdateUsageMock.mockResolvedValue({ inputTokens: 10, outputTokens: 0 })
estimateUsageTokensFromMessagesMock.mockReturnValue({ inputTokens: 1_000, outputTokens: 0 })
const emit = vi.fn()
// Estimator models a fixed 20_000-token overhead on top of the message tokens.
const contextTokenEstimator = vi.fn(async (_messages, messageTokens: number) => 20_000 + messageTokens)
const { buildCompressedHistory } = await import('../../packages/server/src/services/hermes/run-chat/compression')
await buildCompressedHistory(
'session-1',
'default',
'http://upstream',
undefined,
emit,
new Map(),
{},
contextTokenEstimator,
700,
)
// 1_700 = 1_000 estimated message tokens + 700 passed in above.
expect(contextTokenEstimator).toHaveBeenCalledWith(expect.any(Array), 1_700)
// 21_700 = 20_000 estimator overhead + 1_700.
expect(emit).toHaveBeenCalledWith('usage.updated', expect.objectContaining({
contextTokens: 21_700,
}))
expect(compressorCompressMock).not.toHaveBeenCalled()
})
// Checks that after a successful compression the trailing numeric argument of
// `buildCompressedHistory` (700, presumably the current input token count —
// confirm against its signature) is still included in the reported totals.
it('keeps current input tokens in the compression completed context total', async () => {
const messages = Array.from({ length: 10 }, (_, index) => ({
id: index + 1,
session_id: 'session-1',
role: index === 9 ? 'user' : index % 2 === 0 ? 'user' : 'assistant',
content: `message ${index}`,
timestamp: index + 1,
tool_call_id: null,
tool_calls: null,
tool_name: null,
finish_reason: null,
reasoning_content: null,
}))
getSessionDetailMock.mockReturnValue({ messages })
calcAndUpdateUsageMock.mockResolvedValue({ inputTokens: 100, outputTokens: 0 })
// The compressed history is measured at 1_000 tokens; anything else at 100.
estimateUsageTokensFromMessagesMock.mockImplementation((items: any[]) => {
if (items?.[0]?.content === 'compressed result') return { inputTokens: 1_000, outputTokens: 0 }
return { inputTokens: 100, outputTokens: 0 }
})
compressorCompressMock.mockResolvedValue({
messages: [{ role: 'user', content: 'compressed result' }],
meta: {
compressed: true,
llmCompressed: true,
totalMessages: 9,
summaryTokenEstimate: 1,
verbatimCount: 0,
compressedStartIndex: 0,
},
})
const emit = vi.fn()
const { buildCompressedHistory } = await import('../../packages/server/src/services/hermes/run-chat/compression')
await buildCompressedHistory(
'session-1',
'default',
'http://upstream',
undefined,
emit,
new Map(),
{},
vi.fn(async () => 120_000),
700,
)
// 1_700 = 1_000 tokens for the compressed history + 700 passed in above.
expect(updateMessageContextTokenUsageMock).toHaveBeenCalledWith(
'session-1',
expect.any(Object),
emit,
1_700,
{ inputTokens: 100, outputTokens: 0 },
)
expect(emit).toHaveBeenCalledWith('compression.completed', expect.objectContaining({
afterTokens: 1_700,
contextTokens: 1_700,
}))
})
it('throws when fixed prompt and tool schemas exceed threshold before any history exists', async () => { it('throws when fixed prompt and tool schemas exceed threshold before any history exists', async () => {
getSessionDetailMock.mockReturnValue({ messages: [] }) getSessionDetailMock.mockReturnValue({ messages: [] })
const emit = vi.fn() const emit = vi.fn()
+129
View File
@@ -0,0 +1,129 @@
import { beforeEach, describe, expect, it, vi } from 'vitest'
// Mock handles shared between the vi.mock factories below and the test suite,
// so individual tests can program return values and inspect calls.
const getSessionMock = vi.fn()
const getSessionDetailPaginatedMock = vi.fn()
const getCompressionSnapshotMock = vi.fn()
const estimateUsageTokensFromMessagesMock = vi.fn()
const buildDbHistoryMock = vi.fn()
const buildSnapshotAwareHistoryMock = vi.fn()
// NOTE(review): Vitest hoists vi.mock calls above imports; these factories
// reference the handles declared above, which presumably works because the
// module under test is only loaded via a dynamic import inside the test body.
// Keep the declarations and registrations in this order — confirm before moving.

// Session store: only getSession and getSessionDetailPaginated are programmable;
// the remaining exports are inert stubs.
vi.mock('../../packages/server/src/db/hermes/session-store', () => ({
getSession: getSessionMock,
createSession: vi.fn(),
addMessage: vi.fn(),
updateSessionStats: vi.fn(),
getSessionDetailPaginated: getSessionDetailPaginatedMock,
}))
// Usage store: updateUsage is an inert stub.
vi.mock('../../packages/server/src/db/hermes/usage-store', () => ({
updateUsage: vi.fn(),
}))
// Compression snapshot lookup is programmable per test.
vi.mock('../../packages/server/src/db/hermes/compression-snapshot', () => ({
getCompressionSnapshot: getCompressionSnapshotMock,
}))
// Context compressor: exposes a fixed SUMMARY_PREFIX string and a countTokens
// stub that always returns 0.
vi.mock('../../packages/server/src/lib/context-compressor', () => ({
SUMMARY_PREFIX: '[Previous context summary]',
countTokens: vi.fn(() => 0),
}))
// Logger: every level is a no-op spy.
vi.mock('../../packages/server/src/services/logger', () => ({
logger: { info: vi.fn(), warn: vi.fn(), error: vi.fn(), debug: vi.fn() },
}))
// Compression helpers: the two history builders are programmable; the other
// exports are inert stubs.
vi.mock('../../packages/server/src/services/hermes/run-chat/compression', () => ({
buildCompressedHistory: vi.fn(),
buildDbHistory: buildDbHistoryMock,
buildSnapshotAwareHistory: buildSnapshotAwareHistoryMock,
getOrCreateSession: vi.fn(),
}))
// Usage helpers: the token estimator is programmable; calcAndUpdateUsage is inert.
vi.mock('../../packages/server/src/services/hermes/run-chat/usage', () => ({
calcAndUpdateUsage: vi.fn(),
estimateUsageTokensFromMessages: estimateUsageTokensFromMessagesMock,
}))
// Message formatting: both helpers return the messages they were given unchanged.
vi.mock('../../packages/server/src/services/hermes/run-chat/message-format', () => ({
convertHistoryFormat: vi.fn((messages: any[]) => messages),
handleMessage: vi.fn((messages: any[]) => messages),
}))
// Content blocks: the string helpers coerce via String() (falsy values become ''),
// and isContentBlockArray always reports false.
vi.mock('../../packages/server/src/services/hermes/run-chat/content-blocks', () => ({
contentBlocksToString: vi.fn((value: any) => String(value || '')),
extractTextForPreview: vi.fn((value: any) => String(value || '')),
isContentBlockArray: vi.fn(() => false),
convertContentBlocks: vi.fn(),
}))
// System prompt lookup returns a fixed placeholder string.
vi.mock('../../packages/server/src/lib/llm-prompt', () => ({
getSystemPrompt: vi.fn(() => 'system prompt'),
}))
// Streaming / response helpers below are all inert stubs.
vi.mock('../../packages/server/src/services/hermes/run-chat/sse-utils', () => ({
readSseFrames: vi.fn(),
}))
vi.mock('../../packages/server/src/services/hermes/run-chat/response-utils', () => ({
extractResponseText: vi.fn(),
}))
vi.mock('../../packages/server/src/services/hermes/run-chat/response-stream', () => ({
applyResponseStreamEvent: vi.fn(),
flushResponseRunToDb: vi.fn(),
}))
describe('loadSessionStateFromDb', () => {
  // Builds a fresh copy of the three raw conversation rows on every call so the
  // paginated-detail stub and the DB-history stub never share an array.
  const rawTurns = () => [
    { role: 'user', content: 'old large context' },
    { role: 'assistant', content: 'old large answer' },
    { role: 'user', content: 'new tail' },
  ]

  beforeEach(() => {
    vi.clearAllMocks()

    // Stored session metadata and its compression snapshot.
    getSessionMock.mockReturnValue({
      id: 'session-1',
      profile: 'default',
      model: 'gpt-test',
      provider: 'openai',
    })
    getCompressionSnapshotMock.mockReturnValue({
      summary: 'small summary',
      lastMessageIndex: 0,
      messageCountAtTime: 1,
    })

    // Raw (uncompressed) history, as served by both the paginated lookup and
    // the DB history builder.
    getSessionDetailPaginatedMock.mockReturnValue({ messages: rawTurns() })
    buildDbHistoryMock.mockResolvedValue(rawTurns())

    // Snapshot-aware history: a summary message followed by the newest turn.
    buildSnapshotAwareHistoryMock.mockResolvedValue([
      { role: 'user', content: '[Previous context summary]\n\nsmall summary' },
      { role: 'user', content: 'new tail' },
    ])

    // A history whose first message contains the summary is estimated at 9_000
    // input tokens; any other history is estimated at 28_000.
    estimateUsageTokensFromMessagesMock.mockImplementation((messages: any[]) => {
      const leadsWithSummary = messages?.[0]?.content?.includes('small summary')
      return { inputTokens: leadsWithSummary ? 9_000 : 28_000, outputTokens: 0 }
    })
  })

  it('hydrates contextTokens from the same snapshot-aware history used for bridge runs', async () => {
    const { loadSessionStateFromDb: loadState } = await import(
      '../../packages/server/src/services/hermes/run-chat/handle-api-run'
    )

    const state = await loadState('session-1', new Map())

    // The snapshot-aware builder must receive the session's profile and its
    // model/provider pair.
    const expectedModelInfo = { model: 'gpt-test', provider: 'openai' }
    expect(buildSnapshotAwareHistoryMock).toHaveBeenCalledWith(
      'session-1',
      'default',
      expect.any(Array),
      expectedModelInfo,
    )

    // Input/output totals match the 28_000 (non-summary) estimate, while
    // contextTokens matches the 9_000 (summary-led) estimate.
    const { inputTokens, outputTokens, contextTokens } = state
    expect(inputTokens).toBe(28_000)
    expect(outputTokens).toBe(0)
    expect(contextTokens).toBe(9_000)
  })
})