diff --git a/packages/client/src/stores/hermes/chat.ts b/packages/client/src/stores/hermes/chat.ts index a0c1840..3722d77 100644 --- a/packages/client/src/stores/hermes/chat.ts +++ b/packages/client/src/stores/hermes/chat.ts @@ -85,6 +85,15 @@ export interface Session { workspace?: string | null } +interface CompressionState { + compressing: boolean + messageCount: number + beforeTokens: number + afterTokens: number + compressed: boolean | null + error?: string +} + function uid(): string { return Date.now().toString(36) + Math.random().toString(36).slice(2, 8) } @@ -272,6 +281,8 @@ function mapHermesSession(s: SessionSummary): Session { model: s.model, provider: s.provider || (s as any).billing_provider || '', messageCount: s.message_count, + inputTokens: s.input_tokens, + outputTokens: s.output_tokens, endedAt: s.ended_at != null ? Math.round(s.ended_at * 1000) : null, lastActiveAt: s.last_active != null ? Math.round(s.last_active * 1000) : undefined, workspace: s.workspace || null, @@ -405,18 +416,20 @@ export const useChatStore = defineStore('chat', () => { const isLoadingMessages = ref(false) const isRunActive = computed(() => isStreaming.value) - // Compression state - const compressionState = ref<{ - compressing: boolean - messageCount: number - beforeTokens: number - afterTokens: number - compressed: boolean | null - error?: string - } | null>(null) + // Compression state is scoped per session because sockets can stay joined to + // background sessions while another chat is active. + const compressionStates = ref>(new Map()) + const compressionState = computed(() => { + const sid = activeSessionId.value + return sid ? compressionStates.value.get(sid) || null : null + }) - function setCompressionState(state: typeof compressionState.value) { - compressionState.value = state + function setCompressionState(sessionId: string | null | undefined, state: CompressionState | null) { + if (!sessionId) return + const next = new Map(compressionStates.value) + if (state) next.set(sessionId, state) + else next.delete(sessionId) + compressionStates.value = next } const abortState = ref<{ @@ -438,11 +451,12 @@ export const useChatStore = defineStore('chat', () => { } function clearActiveSession() { + const sid = activeSessionId.value activeSessionId.value = null activeSession.value = null focusMessageId.value = null setAbortState(null) - setCompressionState(null) + setCompressionState(sid, null) removeItem(storageKey()) } @@ -453,10 +467,14 @@ export const useChatStore = defineStore('chat', () => { const fresh = list.map(mapHermesSession) // Preserve already-loaded messages for sessions that are still present, // so we don't blow away the active session's messages on refresh. - const msgsByIdBefore = new Map(sessions.value.map(s => [s.id, s.messages])) + const runtimeByIdBefore = new Map(sessions.value.map(s => [s.id, { + messages: s.messages, + contextTokens: s.contextTokens, + }])) for (const s of fresh) { - const prev = msgsByIdBefore.get(s.id) - if (prev && prev.length) s.messages = prev + const prev = runtimeByIdBefore.get(s.id) + if (prev?.messages?.length) s.messages = prev.messages + if (prev?.contextTokens != null) s.contextTokens = prev.contextTokens } sessions.value = fresh @@ -594,6 +612,7 @@ export const useChatStore = defineStore('chat', () => { } else if (!data.isWorking) { setAbortState(null) } + if (!data.isWorking) setCompressionState(sessionId, null) if (data.inputTokens != null) target.inputTokens = data.inputTokens if (data.outputTokens != null) target.outputTokens = data.outputTokens if ((data as any).contextTokens != null) target.contextTokens = (data as any).contextTokens @@ -613,7 +632,7 @@ export const useChatStore = defineStore('chat', () => { for (const evt of data.events) { const e = evt.data as any if (e.event === 'compression.started') { - setCompressionState({ + setCompressionState(sessionId, { compressing: true, messageCount: e.message_count || 0, beforeTokens: e.token_count || 0, @@ -622,7 +641,7 @@ export const useChatStore = defineStore('chat', () => { }) } else if (e.event === 'compression.completed') { const afterTokens = e.contextTokens || e.afterTokens || 0 - setCompressionState({ + setCompressionState(sessionId, { compressing: false, messageCount: e.totalMessages || 0, beforeTokens: e.beforeTokens || 0, @@ -1385,6 +1404,7 @@ export const useChatStore = defineStore('chat', () => { } else if (!data.isWorking) { setAbortState(null) } + if (!data.isWorking) setCompressionState(sid, null) if (data.inputTokens != null) target.inputTokens = data.inputTokens if (data.outputTokens != null) target.outputTokens = data.outputTokens @@ -1407,7 +1427,7 @@ export const useChatStore = defineStore('chat', () => { const e = evt.data as RunEvent switch (e.event) { case 'compression.started': - setCompressionState({ + setCompressionState(sid, { compressing: true, messageCount: (e as any).message_count || 0, beforeTokens: (e as any).token_count || 0, @@ -1417,7 +1437,7 @@ export const useChatStore = defineStore('chat', () => { break case 'compression.completed': { const afterTokens = (e as any).contextTokens || (e as any).afterTokens || 0 - setCompressionState({ + setCompressionState(sid, { compressing: false, messageCount: (e as any).totalMessages || 0, beforeTokens: (e as any).beforeTokens || 0, @@ -1474,7 +1494,7 @@ export const useChatStore = defineStore('chat', () => { case 'run.started': clearAgentEventMessages(sid) setAbortState(null) - setCompressionState(null) + setCompressionState(sid, null) runProducedAssistantText = false runHadToolActivity = false closeStreamingAssistant() @@ -1502,7 +1522,7 @@ export const useChatStore = defineStore('chat', () => { } case 'compression.started': { - setCompressionState({ + setCompressionState(sid, { compressing: true, messageCount: (evt as any).message_count || 0, beforeTokens: (evt as any).token_count || 0, @@ -1514,7 +1534,7 @@ export const useChatStore = defineStore('chat', () => { case 'compression.completed': { const afterTokens = (evt as any).contextTokens || (evt as any).afterTokens || 0 - setCompressionState({ + setCompressionState(sid, { compressing: false, messageCount: (evt as any).totalMessages || 0, beforeTokens: (evt as any).beforeTokens || 0, @@ -1528,8 +1548,9 @@ export const useChatStore = defineStore('chat', () => { } // Auto-clear after 5s setTimeout(() => { - if (compressionState.value && !compressionState.value.compressing) { - setCompressionState(null) + const state = compressionStates.value.get(sid) + if (state && !state.compressing) { + setCompressionState(sid, null) } }, 5000) break @@ -1966,7 +1987,7 @@ export const useChatStore = defineStore('chat', () => { case 'run.started': clearAgentEventMessages(sid) setAbortState(null) - setCompressionState(null) + setCompressionState(sid, null) runProducedAssistantText = false runHadToolActivity = false closeStreamingAssistant() @@ -1979,7 +2000,7 @@ export const useChatStore = defineStore('chat', () => { break case 'compression.started': { - setCompressionState({ + setCompressionState(sid, { compressing: true, messageCount: (evt as any).message_count || 0, beforeTokens: (evt as any).token_count || 0, @@ -1991,7 +2012,7 @@ export const useChatStore = defineStore('chat', () => { case 'compression.completed': { const afterTokens = (evt as any).contextTokens || (evt as any).afterTokens || 0 - setCompressionState({ + setCompressionState(sid, { compressing: false, messageCount: (evt as any).totalMessages || 0, beforeTokens: (evt as any).beforeTokens || 0, @@ -2004,8 +2025,9 @@ export const useChatStore = defineStore('chat', () => { if (target) target.contextTokens = (evt as any).contextTokens } setTimeout(() => { - if (compressionState.value && !compressionState.value.compressing) { - setCompressionState(null) + const state = compressionStates.value.get(sid) + if (state && !state.compressing) { + setCompressionState(sid, null) } }, 5000) break @@ -2461,6 +2483,7 @@ export const useChatStore = defineStore('chat', () => { } else if (!data.isWorking) { setAbortState(null) } + if (!data.isWorking) setCompressionState(sid, null) if (data.messages?.length && activeSession.value) { activeSession.value.messages = mapHermesMessages(data.messages as any[]) } diff --git a/packages/server/src/lib/context-compressor/index.ts b/packages/server/src/lib/context-compressor/index.ts index a535f52..4f4caa3 100644 --- a/packages/server/src/lib/context-compressor/index.ts +++ b/packages/server/src/lib/context-compressor/index.ts @@ -79,6 +79,7 @@ export interface SummarizerOptions { profile?: string model?: string | null provider?: string | null + workerKey?: string } // ─── Token counting ───────────────────────────────────── @@ -454,6 +455,7 @@ export async function callSummarizer( const bridge = new AgentBridgeClient({ timeoutMs: timeoutMs + 15_000 }) const sessionId = `compress_${Date.now().toString(36)}_${randomUUID().replace(/-/g, '').slice(0, 12)}` + const workerKey = options.workerKey || `${profile}:compression:${sessionId}` try { const result = await bridge.request({ @@ -462,6 +464,7 @@ export async function callSummarizer( message: prompt, conversation_history: convHistory, profile, + worker_key: workerKey, source: 'api_server', wait: true, timeout: Math.ceil(timeoutMs / 1000), @@ -482,7 +485,7 @@ export async function callSummarizer( if (!output) throw new Error('Empty summarization response') return output } finally { - await bridge.destroy(sessionId, profile).catch(() => undefined) + await bridge.destroy(sessionId, profile, workerKey).catch(() => undefined) } } diff --git a/packages/server/src/services/hermes/agent-bridge/client.ts b/packages/server/src/services/hermes/agent-bridge/client.ts index 58d4010..1f6cd26 100644 --- a/packages/server/src/services/hermes/agent-bridge/client.ts +++ b/packages/server/src/services/hermes/agent-bridge/client.ts @@ -166,7 +166,7 @@ export class AgentBridgeClient { private summarizePayload(payload: Record): Record { const action = String(payload.action || '') const summary: Record = { action } - for (const key of ['session_id', 'run_id', 'request_id', 'approval_id', 'profile']) { + for (const key of ['session_id', 'run_id', 'request_id', 'approval_id', 'profile', 'worker_key']) { if (payload[key] != null) summary[key] = payload[key] } if (Array.isArray(payload.conversation_history)) summary.conversation_history_count = payload.conversation_history.length @@ -569,11 +569,12 @@ export class AgentBridgeClient { }) } - destroy(sessionId: string, profile?: string): Promise { + destroy(sessionId: string, profile?: string, workerKey?: string): Promise { return this.request({ action: 'destroy', session_id: sessionId, ...(profile ? { profile } : {}), + ...(workerKey ? { worker_key: workerKey } : {}), }) } diff --git a/packages/server/src/services/hermes/agent-bridge/hermes_bridge.py b/packages/server/src/services/hermes/agent-bridge/hermes_bridge.py index 856b63a..116d7c2 100755 --- a/packages/server/src/services/hermes/agent-bridge/hermes_bridge.py +++ b/packages/server/src/services/hermes/agent-bridge/hermes_bridge.py @@ -2200,7 +2200,8 @@ class WorkerProcess: STARTUP_TIMEOUT_SECONDS = 120 REQUEST_TIMEOUT_SECONDS = 120 - def __init__(self, profile: str, endpoint: str, agent_root: str | None, hermes_home: str | None) -> None: + def __init__(self, key: str, profile: str, endpoint: str, agent_root: str | None, hermes_home: str | None) -> None: + self.key = key or profile or "default" self.profile = profile or "default" self.endpoint = endpoint self.agent_root = agent_root @@ -2263,14 +2264,14 @@ class WorkerProcess: for line in proc.stderr: text = line.rstrip() if text: - print(f"[hermes-bridge-worker:{self.profile}] {text}", file=sys.stderr, flush=True) + print(f"[hermes-bridge-worker:{self.key}] {text}", file=sys.stderr, flush=True) - threading.Thread(target=run, daemon=True, name=f"hermes-bridge-worker-stderr-{self.profile}").start() + threading.Thread(target=run, daemon=True, name=f"hermes-bridge-worker-stderr-{self.key}").start() def _wait_ready(self) -> None: proc = self.process if proc is None or proc.stdout is None: - raise RuntimeError(f"profile worker {self.profile} did not start") + raise RuntimeError(f"profile worker {self.key} did not start") lines: queue.Queue[str | None] = queue.Queue() ready_event = threading.Event() @@ -2281,17 +2282,17 @@ class WorkerProcess: if ready_event.is_set(): text = line.rstrip() if text: - print(f"[hermes-bridge-worker:{self.profile}] {text}", file=sys.stderr, flush=True) + print(f"[hermes-bridge-worker:{self.key}] {text}", file=sys.stderr, flush=True) else: lines.put(line) finally: lines.put(None) - threading.Thread(target=read_stdout, daemon=True, name=f"hermes-bridge-worker-stdout-{self.profile}").start() + threading.Thread(target=read_stdout, daemon=True, name=f"hermes-bridge-worker-stdout-{self.key}").start() deadline = time.time() + self.STARTUP_TIMEOUT_SECONDS while time.time() < deadline: if proc.poll() is not None: - raise RuntimeError(f"profile worker {self.profile} exited before ready") + raise RuntimeError(f"profile worker {self.key} exited before ready") try: line = lines.get(timeout=0.1) except queue.Empty: @@ -2301,7 +2302,7 @@ class WorkerProcess: continue text = line.strip() if text: - print(f"[hermes-bridge-worker:{self.profile}] {text}", file=sys.stderr, flush=True) + print(f"[hermes-bridge-worker:{self.key}] {text}", file=sys.stderr, flush=True) try: data = json.loads(text) if data.get("event") == "ready": @@ -2310,7 +2311,7 @@ class WorkerProcess: except Exception: pass self.stop() - raise RuntimeError(f"profile worker {self.profile} did not become ready within {self.STARTUP_TIMEOUT_SECONDS}s") + raise RuntimeError(f"profile worker {self.key} did not become ready within {self.STARTUP_TIMEOUT_SECONDS}s") def stop(self) -> None: with self._lock: @@ -2337,8 +2338,8 @@ class WorkerProcess: return _send_bridge_request(self.endpoint, req, self.REQUEST_TIMEOUT_SECONDS) -def _worker_endpoint(profile: str) -> str: - safe = hashlib.sha256(profile.encode("utf-8")).hexdigest()[:16] +def _worker_endpoint(key: str) -> str: + safe = hashlib.sha256(key.encode("utf-8")).hexdigest()[:16] if os.name == "nt": port_base = int(os.environ.get("HERMES_AGENT_BRIDGE_WORKER_PORT_BASE", "18780")) return f"tcp://127.0.0.1:{port_base + int(safe[:4], 16) % 1000}" @@ -2533,11 +2534,17 @@ class BridgeBroker: self.hermes_home = hermes_home self._workers: dict[str, WorkerProcess] = {} self._run_profile: dict[str, str] = {} + self._run_worker_key: dict[str, str] = {} self._running_run_profile: dict[str, str] = {} + self._running_run_worker_key: dict[str, str] = {} self._session_profile: dict[str, str] = {} + self._session_worker_key: dict[str, str] = {} self._approval_profile: dict[str, str] = {} + self._approval_worker_key: dict[str, str] = {} self._clarify_profile: dict[str, str] = {} + self._clarify_worker_key: dict[str, str] = {} self._compression_profile: dict[str, str] = {} + self._compression_worker_key: dict[str, str] = {} self._lock = threading.RLock() self._stop = threading.Event() self._last_gc = time.time() @@ -2546,58 +2553,73 @@ class BridgeBroker: profile = str(value or "").strip() return profile or "default" - def _worker_for_profile(self, profile: str) -> WorkerProcess: + def _normalize_worker_key(self, profile: str, value: Any = None) -> str: + worker_key = str(value or "").strip() + return worker_key or profile + + def _worker_for_profile(self, profile: str, worker_key: str | None = None) -> WorkerProcess: profile = self._normalize_profile(profile) + key = self._normalize_worker_key(profile, worker_key) with self._lock: - worker = self._workers.get(profile) + worker = self._workers.get(key) if worker is None: - worker = WorkerProcess(profile, _worker_endpoint(profile), self.agent_root, self.hermes_home) - self._workers[profile] = worker + worker = WorkerProcess(key, profile, _worker_endpoint(key), self.agent_root, self.hermes_home) + self._workers[key] = worker return worker - def _profile_for_run(self, run_id: str) -> str: + def _route_for_run(self, run_id: str) -> tuple[str, str | None]: with self._lock: profile = self._run_profile.get(run_id) + worker_key = self._run_worker_key.get(run_id) if not profile: raise KeyError(f"unknown run: {run_id}") - return profile + return profile, worker_key - def _profile_for_session(self, session_id: str, fallback_profile: Any = None) -> str: + def _route_for_session(self, session_id: str, fallback_profile: Any = None, worker_key: Any = None) -> tuple[str, str | None]: with self._lock: profile = self._session_profile.get(session_id) + stored_worker_key = self._session_worker_key.get(session_id) if not profile: fallback = self._normalize_profile(fallback_profile) if fallback_profile is not None and fallback: - return fallback + return fallback, self._normalize_worker_key(fallback, worker_key) raise KeyError(f"unknown session: {session_id}") - return profile + return profile, self._normalize_worker_key(profile, worker_key) if worker_key is not None else stored_worker_key - def _record_response_routes(self, profile: str, resp: dict[str, Any]) -> None: + def _record_response_routes(self, profile: str, worker_key: str, resp: dict[str, Any]) -> None: run_id = str(resp.get("run_id") or "") session_id = str(resp.get("session_id") or "") with self._lock: if run_id: self._run_profile[run_id] = profile + self._run_worker_key[run_id] = worker_key if resp.get("status") == "running": self._running_run_profile[run_id] = profile + self._running_run_worker_key[run_id] = worker_key else: self._running_run_profile.pop(run_id, None) + self._running_run_worker_key.pop(run_id, None) if session_id: self._session_profile[session_id] = profile + self._session_worker_key[session_id] = worker_key for event in resp.get("events") or []: if not isinstance(event, dict): continue approval_id = str(event.get("approval_id") or "") if approval_id: self._approval_profile[approval_id] = profile + self._approval_worker_key[approval_id] = worker_key clarify_id = str(event.get("clarify_id") or "") if clarify_id: self._clarify_profile[clarify_id] = profile + self._clarify_worker_key[clarify_id] = worker_key request_id = str(event.get("request_id") or "") if event.get("event") == "bridge.compression.requested" and request_id: self._compression_profile[request_id] = profile + self._compression_worker_key[request_id] = worker_key if event.get("event") in {"bridge.compression.completed", "bridge.compression.failed"} and request_id: self._compression_profile.pop(request_id, None) + self._compression_worker_key.pop(request_id, None) def stop(self) -> None: self._stop.set() @@ -2605,20 +2627,29 @@ class BridgeBroker: workers = list(self._workers.values()) self._workers.clear() self._run_profile.clear() + self._run_worker_key.clear() self._running_run_profile.clear() + self._running_run_worker_key.clear() self._session_profile.clear() + self._session_worker_key.clear() self._approval_profile.clear() + self._approval_worker_key.clear() self._clarify_profile.clear() + self._clarify_worker_key.clear() self._compression_profile.clear() + self._compression_worker_key.clear() for worker in workers: worker.stop() - def _forward(self, profile: str, req: dict[str, Any]) -> dict[str, Any]: - worker = self._worker_for_profile(profile) + def _forward(self, profile: str, req: dict[str, Any], worker_key: str | None = None) -> dict[str, Any]: + profile = self._normalize_profile(profile) + key = self._normalize_worker_key(profile, worker_key) + worker = self._worker_for_profile(profile, key) forwarded = dict(req) forwarded["profile"] = profile + forwarded.pop("worker_key", None) resp = worker.request(forwarded) - self._record_response_routes(profile, resp) + self._record_response_routes(profile, key, resp) return resp def handle(self, req: dict[str, Any]) -> dict[str, Any]: @@ -2629,15 +2660,16 @@ class BridgeBroker: if action == "ping": with self._lock: worker_details = { - profile: { + key: { "running": worker.running, "pid": worker.pid, "endpoint": worker.endpoint, + "profile": getattr(worker, "profile", key), "last_used_at": worker.last_used_at, } - for profile, worker in self._workers.items() + for key, worker in self._workers.items() } - workers = {profile: details["running"] for profile, details in worker_details.items()} + workers = {key: details["running"] for key, details in worker_details.items()} sessions_by_profile: dict[str, int] = {} for profile in self._session_profile.values(): sessions_by_profile[profile] = sessions_by_profile.get(profile, 0) + 1 @@ -2664,29 +2696,32 @@ class BridgeBroker: if action == "worker_ping": profile = self._normalize_profile(req.get("profile")) - resp = self._forward(profile, {"action": "ping"}) + worker_key = self._normalize_worker_key(profile, req.get("worker_key")) + resp = self._forward(profile, {"action": "ping"}, worker_key) resp["worker_profile"] = profile + resp["worker_key"] = worker_key return resp if action == "chat": profile = self._normalize_profile(req.get("profile")) - return self._forward(profile, req) + return self._forward(profile, req, self._normalize_worker_key(profile, req.get("worker_key"))) if action == "context_estimate": profile = self._normalize_profile(req.get("profile")) - return self._forward(profile, req) + return self._forward(profile, req, self._normalize_worker_key(profile, req.get("worker_key"))) if action in {"get_result", "get_output"}: - profile = self._profile_for_run(str(req.get("run_id") or "")) - return self._forward(profile, req) + profile, worker_key = self._route_for_run(str(req.get("run_id") or "")) + return self._forward(profile, req, worker_key) if action in {"interrupt", "steer", "command", "goal_evaluate", "goal_pause", "status", "get_history", "destroy"}: session_id = str(req.get("session_id") or "") - profile = self._profile_for_session(session_id, req.get("profile")) - resp = self._forward(profile, req) + profile, worker_key = self._route_for_session(session_id, req.get("profile"), req.get("worker_key") if "worker_key" in req else None) + resp = self._forward(profile, req, worker_key) if action == "destroy": with self._lock: self._session_profile.pop(session_id, None) + self._session_worker_key.pop(session_id, None) return resp if action == "approval_respond": @@ -2695,9 +2730,10 @@ class BridgeBroker: raise ValueError("approval_id is required") with self._lock: profile = self._approval_profile.get(approval_id) + worker_key = self._approval_worker_key.get(approval_id) if not profile: raise KeyError(f"unknown approval request: {approval_id}") - return self._forward(profile, req) + return self._forward(profile, req, worker_key) if action == "clarify_respond": clarify_id = str(req.get("clarify_id") or "").strip() @@ -2705,9 +2741,10 @@ class BridgeBroker: raise ValueError("clarify_id is required") with self._lock: profile = self._clarify_profile.get(clarify_id) + worker_key = self._clarify_worker_key.get(clarify_id) if not profile: raise KeyError(f"unknown clarify request: {clarify_id}") - return self._forward(profile, req) + return self._forward(profile, req, worker_key) if action == "compression_respond": request_id = str(req.get("request_id") or "").strip() @@ -2715,20 +2752,27 @@ class BridgeBroker: raise ValueError("request_id is required") with self._lock: profile = self._compression_profile.get(request_id) + worker_key = self._compression_worker_key.get(request_id) if not profile: raise KeyError(f"unknown compression request: {request_id}") - return self._forward(profile, req) + return self._forward(profile, req, worker_key) if action == "destroy_all": with self._lock: workers = list(self._workers.values()) self._workers.clear() self._run_profile.clear() + self._run_worker_key.clear() self._running_run_profile.clear() + self._running_run_worker_key.clear() self._session_profile.clear() + self._session_worker_key.clear() self._approval_profile.clear() + self._approval_worker_key.clear() self._clarify_profile.clear() + self._clarify_worker_key.clear() self._compression_profile.clear() + self._compression_worker_key.clear() destroyed = 0 for worker in workers: try: @@ -2744,40 +2788,56 @@ class BridgeBroker: if action == "destroy_profile": profile = self._normalize_profile(req.get("profile")) with self._lock: - worker = self._workers.pop(profile, None) + workers = [ + worker + for key, worker in list(self._workers.items()) + if getattr(worker, "profile", key) == profile + ] + for worker in workers: + self._workers.pop(worker.key, None) self._run_profile = {key: value for key, value in self._run_profile.items() if value != profile} + self._run_worker_key = {key: value for key, value in self._run_worker_key.items() if key in self._run_profile} self._running_run_profile = {key: value for key, value in self._running_run_profile.items() if value != profile} + self._running_run_worker_key = {key: value for key, value in self._running_run_worker_key.items() if key in self._running_run_profile} self._session_profile = {key: value for key, value in self._session_profile.items() if value != profile} + self._session_worker_key = {key: value for key, value in self._session_worker_key.items() if key in self._session_profile} self._approval_profile = {key: value for key, value in self._approval_profile.items() if value != profile} + self._approval_worker_key = {key: value for key, value in self._approval_worker_key.items() if key in self._approval_profile} self._clarify_profile = {key: value for key, value in self._clarify_profile.items() if value != profile} + self._clarify_worker_key = {key: value for key, value in self._clarify_worker_key.items() if key in self._clarify_profile} self._compression_profile = {key: value for key, value in self._compression_profile.items() if value != profile} + self._compression_worker_key = {key: value for key, value in self._compression_worker_key.items() if key in self._compression_profile} - if worker is None or not worker.running: - if worker is not None: - worker.stop() + if not workers: return {"profile": profile, "destroyed": 0} - try: - resp = worker.request({"action": "destroy_all"}) - destroyed = int(resp.get("destroyed") or 0) - except Exception: - destroyed = 0 - finally: - worker.stop() + destroyed = 0 + for worker in workers: + if not worker.running: + worker.stop() + continue + try: + resp = worker.request({"action": "destroy_all"}) + destroyed += int(resp.get("destroyed") or 0) + except Exception: + pass + finally: + worker.stop() return {"profile": profile, "destroyed": destroyed} if action == "list": sessions: list[Any] = [] with self._lock: workers = list(self._workers.items()) - for profile, worker in workers: + for key, worker in workers: if not worker.running: continue try: resp = worker.request({"action": "list"}) for session in resp.get("sessions") or []: if isinstance(session, dict): - session.setdefault("profile", profile) + session.setdefault("profile", getattr(worker, "profile", key)) + session.setdefault("worker_key", key) sessions.append(session) except Exception: pass @@ -2826,12 +2886,12 @@ class BridgeBroker: self._last_gc = now with self._lock: idle = [ - profile for profile, worker in self._workers.items() + key for key, worker in self._workers.items() if worker.running and now - worker.last_used_at > self.IDLE_TIMEOUT_SECONDS ] - for profile in idle: + for key in idle: with self._lock: - worker = self._workers.pop(profile, None) + worker = self._workers.pop(key, None) if worker: worker.stop() diff --git a/packages/server/src/services/hermes/run-chat/compression.ts b/packages/server/src/services/hermes/run-chat/compression.ts index ec7d262..89dc625 100644 --- a/packages/server/src/services/hermes/run-chat/compression.ts +++ b/packages/server/src/services/hermes/run-chat/compression.ts @@ -195,7 +195,8 @@ export async function buildCompressedHistory( emit: (event: string, payload: any) => void, sessionMap: Map, modelContext: { model?: string | null; provider?: string | null } = {}, - contextTokenEstimator?: (messages: ChatMessage[]) => Promise, + contextTokenEstimator?: (messages: ChatMessage[], messageTokens: number) => Promise, + currentInputTokens = 0, ): Promise { try { let history = await buildDbHistory(sessionId, { excludeLastUser: true }) @@ -213,14 +214,18 @@ export async function buildCompressedHistory( } const cState = getOrCreateSession(sessionMap, sessionId) const assembledTokens = await calcAndUpdateUsage(sessionId, cState, emit) - const estimateFullContextTokens = async (messages: ChatMessage[], fallback: number) => { + const currentRunInputTokens = typeof currentInputTokens === 'number' && Number.isFinite(currentInputTokens) && currentInputTokens > 0 + ? Math.floor(currentInputTokens) + : 0 + const estimateLocalContextTokens = async (messages: ChatMessage[], messageTokens: number) => { + const localMessageTokens = Math.max(0, Math.floor(messageTokens)) try { - const estimate = await contextTokenEstimator?.(messages) + const estimate = await contextTokenEstimator?.(messages, localMessageTokens) if (typeof estimate === 'number' && Number.isFinite(estimate) && estimate > 0) return Math.floor(estimate) } catch (err) { - logger.warn(err, '[context-compress] session=%s: full context token estimate failed; using message-only estimate', sessionId) + logger.warn(err, '[context-compress] session=%s: fixed context token estimate failed; using message-only estimate', sessionId) } - return fallback + return localMessageTokens } const emitContextUsage = (contextTokens: number) => { cState.contextTokens = contextTokens @@ -236,10 +241,10 @@ export async function buildCompressedHistory( let totalTokens = messageOnlyTotalTokens if (history.length === 0) { - totalTokens = await estimateFullContextTokens([], 0) + totalTokens = await estimateLocalContextTokens([], Math.max(currentRunInputTokens, messageOnlyTotalTokens)) if (totalTokens > triggerTokens) { throw new ContextWindowTooSmallError( - `Context window is too small: system prompt and tool schemas already use ~${totalTokens} tokens, exceeding compression threshold ${triggerTokens}. Increase model context length, raise compression.threshold, or disable some tools.`, + `Context window is too small: fixed prompt/tool overhead plus the current input uses ~${totalTokens} tokens, exceeding compression threshold ${triggerTokens}. Increase model context length, raise compression.threshold, shorten the input, or disable some tools.`, ) } if (totalTokens > 0) emitContextUsage(totalTokens) @@ -254,13 +259,15 @@ export async function buildCompressedHistory( sessionId, snapshot.lastMessageIndex, history.length) const staleHistory = buildSnapshotHistory(snapshot, history, compressionConfig.compressor) || history const staleUsage = estimateUsageTokensFromMessages(staleHistory) - totalTokens = await estimateFullContextTokens(staleHistory, staleUsage.inputTokens + staleUsage.outputTokens) + const staleMessageTokens = staleUsage.inputTokens + staleUsage.outputTokens + const staleRunMessageTokens = Math.max(staleMessageTokens + currentRunInputTokens, messageOnlyTotalTokens) + totalTokens = await estimateLocalContextTokens(staleHistory, staleRunMessageTokens) emitContextUsage(totalTokens) logger.info({ sessionId, profile, messages: staleHistory.length, - messageOnlyTokens: staleUsage.inputTokens + staleUsage.outputTokens, + messageOnlyTokens: staleRunMessageTokens, fullContextTokens: totalTokens, triggerTokens, decision: totalTokens > triggerTokens ? 'compress' : 'skip', @@ -272,13 +279,15 @@ export async function buildCompressedHistory( const newMessages = history.slice(snapshot.lastMessageIndex + 1) const snapshotHistory = buildSnapshotHistory(snapshot, history, compressionConfig.compressor) || history const snapshotUsage = estimateUsageTokensFromMessages(snapshotHistory) - totalTokens = await estimateFullContextTokens(snapshotHistory, snapshotUsage.inputTokens + snapshotUsage.outputTokens) + const snapshotMessageTokens = snapshotUsage.inputTokens + snapshotUsage.outputTokens + const snapshotRunMessageTokens = Math.max(snapshotMessageTokens + currentRunInputTokens, messageOnlyTotalTokens) + totalTokens = await estimateLocalContextTokens(snapshotHistory, snapshotRunMessageTokens) emitContextUsage(totalTokens) logger.info({ sessionId, profile, messages: snapshotHistory.length, - messageOnlyTokens: snapshotUsage.inputTokens + snapshotUsage.outputTokens, + messageOnlyTokens: snapshotRunMessageTokens, fullContextTokens: totalTokens, triggerTokens, decision: totalTokens > triggerTokens ? 'compress' : 'skip', @@ -289,22 +298,25 @@ export async function buildCompressedHistory( if (totalTokens <= triggerTokens) { history = snapshotHistory } else { - history = await compressHistory(history, newMessages, sessionId, upstream, apiKey, cState, totalTokens, emit, sessionMap, modelContext, compressionConfig.compressor) + history = await compressHistory(history, newMessages, sessionId, upstream, apiKey, cState, totalTokens, emit, sessionMap, modelContext, compressionConfig.compressor, currentRunInputTokens) } } else if (snapshot && staleSnapshot) { if (totalTokens <= triggerTokens) { history = buildSnapshotHistory(snapshot, history, compressionConfig.compressor) || history } else { - history = await compressHistory(history, null, sessionId, upstream, apiKey, cState, totalTokens, emit, sessionMap, modelContext, compressionConfig.compressor) + history = await compressHistory(history, null, sessionId, upstream, apiKey, cState, totalTokens, emit, sessionMap, modelContext, compressionConfig.compressor, currentRunInputTokens) } } else { - totalTokens = await estimateFullContextTokens(history, totalTokens) + const historyUsage = estimateUsageTokensFromMessages(history) + const historyMessageTokens = historyUsage.inputTokens + historyUsage.outputTokens + const runMessageTokens = Math.max(historyMessageTokens + currentRunInputTokens, messageOnlyTotalTokens) + totalTokens = await estimateLocalContextTokens(history, runMessageTokens) emitContextUsage(totalTokens) logger.info({ sessionId, profile, messages: history.length, - messageOnlyTokens: messageOnlyTotalTokens, + messageOnlyTokens: runMessageTokens, fullContextTokens: totalTokens, triggerTokens, decision: totalTokens > triggerTokens ? 'compress' : 'skip', @@ -318,7 +330,7 @@ export async function buildCompressedHistory( if (totalTokens <= triggerTokens) { logger.info('[context-compress] session=%s: %d messages, ~%d tokens — under threshold, skip', sessionId, history.length, totalTokens) } else { - history = await compressHistory(history, null, sessionId, upstream, apiKey, cState, totalTokens, emit, sessionMap, modelContext, compressionConfig.compressor) + history = await compressHistory(history, null, sessionId, upstream, apiKey, cState, totalTokens, emit, sessionMap, modelContext, compressionConfig.compressor, currentRunInputTokens) } } @@ -342,8 +354,12 @@ export async function compressHistory( sessionMap: Map, modelContext: { model?: string | null; provider?: string | null } = {}, compressionConfig?: Partial, + currentInputTokens = 0, ): Promise { const msgCount = newMessagesOnly ? newMessagesOnly.length : history.length + const currentRunInputTokens = typeof currentInputTokens === 'number' && Number.isFinite(currentInputTokens) && currentInputTokens > 0 + ? Math.floor(currentInputTokens) + : 0 pushState(sessionMap, sessionId, 'compression.started', { event: 'compression.started', message_count: msgCount, token_count: totalTokens, }) @@ -353,14 +369,22 @@ export async function compressHistory( try { const session = getSession(sessionId) + const summarizerProfile = session?.profile || 'default' const compressor = new ChatContextCompressor({ config: compressionConfig }) const result = await compressor.compress(history, upstream, apiKey, sessionId, { - profile: session?.profile, + profile: summarizerProfile, model: modelContext.model || session?.model, provider: modelContext.provider || session?.provider, + workerKey: `${summarizerProfile}:compression:${sessionId}`, }) const afterTokens = await calcAndUpdateUsage(sessionId, cState, emit) const compressedAfterTokens = afterTokens.inputTokens + afterTokens.outputTokens + const resultUsage = estimateUsageTokensFromMessages(result.messages) + const resultMessageTokens = resultUsage.inputTokens + resultUsage.outputTokens + const compressedRunMessageTokens = Math.max( + compressedAfterTokens, + resultMessageTokens + currentRunInputTokens, + ) const compressedMeta: any = { event: 'compression.completed' as const, compressed: result.meta.compressed, @@ -368,15 +392,15 @@ export async function compressHistory( totalMessages: result.meta.totalMessages, resultMessages: result.messages.length, beforeTokens: totalTokens, - afterTokens: compressedAfterTokens, + afterTokens: compressedRunMessageTokens, summaryTokens: result.meta.summaryTokenEstimate, verbatimCount: result.meta.verbatimCount, compressedStartIndex: result.meta.compressedStartIndex, } replaceState(sessionMap, sessionId, 'compression.completed', compressedMeta) logger.info('[context-compress] AFTER session=%s: %d messages, ~%d tokens (was %d)', - sessionId, result.messages.length, compressedAfterTokens, totalTokens) - const compressedContextTokens = updateMessageContextTokenUsage(sessionId, cState, emit, compressedAfterTokens, afterTokens) + sessionId, result.messages.length, compressedRunMessageTokens, totalTokens) + const compressedContextTokens = updateMessageContextTokenUsage(sessionId, cState, emit, compressedRunMessageTokens, afterTokens) if (compressedContextTokens != null) { compressedMeta.contextTokens = compressedContextTokens } @@ -403,6 +427,7 @@ export async function compressHistory( resultMessages: msgCount, beforeTokens: totalTokens, afterTokens: totalTokens, + contextTokens: totalTokens, summaryTokens: 0, verbatimCount: msgCount, compressedStartIndex: -1, @@ -458,10 +483,12 @@ export async function forceCompressBridgeHistory( }, '[chat-run-socket] bridge forced compression started') const compressor = new ChatContextCompressor({ config: compressionConfig.compressor }) + const summarizerProfile = session?.profile || profile || 'default' const result = await compressor.compress(history, upstream, apiKey, sessionId, { - profile: session?.profile || profile, + profile: summarizerProfile, model: session?.model, provider: session?.provider, + workerKey: `${summarizerProfile}:compression:${sessionId}`, }) const compressedMessages = result.messages.map(m => { const msg: any = { role: m.role, content: m.content } diff --git a/packages/server/src/services/hermes/run-chat/handle-api-run.ts b/packages/server/src/services/hermes/run-chat/handle-api-run.ts index 3334400..b04154d 100644 --- a/packages/server/src/services/hermes/run-chat/handle-api-run.ts +++ b/packages/server/src/services/hermes/run-chat/handle-api-run.ts @@ -18,7 +18,7 @@ import { convertHistoryFormat } from './message-format' import { readSseFrames } from './sse-utils' import { extractResponseText } from './response-utils' import { applyResponseStreamEvent, flushResponseRunToDb } from './response-stream' -import { buildCompressedHistory, getOrCreateSession } from './compression' +import { buildCompressedHistory, buildDbHistory, buildSnapshotAwareHistory, getOrCreateSession } from './compression' import { calcAndUpdateUsage, estimateUsageTokensFromMessages } from './usage' import { handleMessage } from './message-format' import { countTokens, SUMMARY_PREFIX } from '../../../lib/context-compressor' @@ -37,6 +37,7 @@ export async function loadSessionStateFromDb(sid: string, _sessionMap: Map= 0 && snapshot.lastMessageIndex < messages.length) { const newMessages = messages.slice(snapshot.lastMessageIndex + 1) @@ -49,6 +50,20 @@ export async function loadSessionStateFromDb(sid: string, _sessionMap: Map | } } +function bridgeContextMatches( + state: SessionState, + expected: { profile: string; model?: string | null; provider?: string | null }, +): boolean { + const context = state.bridgeContext + if (!context) return false + if (context.profile && context.profile !== expected.profile) return false + if (expected.model && context.model && context.model !== expected.model) return false + if (expected.provider && context.provider && context.provider !== expected.provider) return false + return true +} + +async function ensureBridgeFixedContext(args: { + sessionId: string + profile: string + model?: string | null + provider?: string | null + instructions: string + state: SessionState + bridge: AgentBridgeClient + refresh?: boolean +}): Promise { + const cached = bridgeContextMatches(args.state, args) + ? getCachedBridgeContextOverhead(args.state) + : undefined + if (!args.refresh && cached != null) return cached + + try { + const estimate = await args.bridge.contextEstimate( + args.sessionId, + [], + args.instructions, + args.profile, + { model: args.model ?? undefined, provider: args.provider ?? undefined }, + ) + cacheBridgeContext(args.state, estimate) + const fixedContextTokens = getCachedBridgeContextOverhead(args.state) + bridgeLogger.info({ + sessionId: args.sessionId, + profile: args.profile, + model: args.model, + provider: args.provider, + toolCount: estimate.tool_count, + systemPromptChars: estimate.system_prompt_chars, + fixedContextTokens, + }, '[chat-run-socket] fixed context estimate') + return fixedContextTokens + } catch (err) { + bridgeLogger.warn({ + err: err instanceof Error ? { message: err.message, name: err.name } : err, + sessionId: args.sessionId, + profile: args.profile, + model: args.model, + provider: args.provider, + cachedFixedContextTokens: cached, + }, '[chat-run-socket] fixed context estimate failed') + return cached + } +} + export async function handleBridgeRun( nsp: ReturnType, socket: Socket, @@ -195,6 +254,9 @@ export async function handleBridgeRun( const displayInput = data.display_input === undefined ? input : data.display_input const inputStr = displayInput == null ? '' : contentBlocksToString(displayInput) + const actualInputStr = contentBlocksToString(input) + const currentInputUsage = estimateUsageTokensFromMessages([{ role: 'user', content: actualInputStr }]) + const currentInputTokens = currentInputUsage.inputTokens const shouldPersistUserMessage = !skipUserMessage && displayInput !== null const displayRole = data.display_role === 'command' ? 'command' : 'user' let messageId: number | string | undefined @@ -257,33 +319,32 @@ export async function handleBridgeRun( emit, sessionMap, { model: resolvedModel, provider: resolvedProvider }, - async (messages) => { - const cachedOverhead = getCachedBridgeContextOverhead(state) - if (cachedOverhead != null) { - const messageUsage = estimateUsageTokensFromMessages(messages) - return cachedOverhead + messageUsage.inputTokens + messageUsage.outputTokens - } - const estimate = await bridge.contextEstimate( - session_id, - messages, - fullInstructions, + async (_messages, localMessageTokens) => { + const fixedContextTokens = await ensureBridgeFixedContext({ + sessionId: session_id, profile, - { model: resolvedModel, provider: resolvedProvider }, - ) - cacheBridgeContext(state, estimate) + model: resolvedModel, + provider: resolvedProvider, + instructions: fullInstructions, + state, + bridge, + refresh: true, + }) + const contextTokens = fixedContextTokens == null + ? localMessageTokens + : fixedContextTokens + localMessageTokens bridgeLogger.info({ sessionId: session_id, profile, model: resolvedModel, provider: resolvedProvider, - messages: estimate.message_count, - toolCount: estimate.tool_count, - systemPromptChars: estimate.system_prompt_chars, - fixedContextTokens: estimate.fixed_context_tokens, - fullContextTokens: estimate.token_count, - }, '[chat-run-socket] full context estimate') - return estimate.token_count + fixedContextTokens, + messageTokens: localMessageTokens, + contextTokens, + }, '[chat-run-socket] local context estimate') + return contextTokens }, + currentInputTokens, ) const bridgeHistory = history @@ -349,6 +410,8 @@ export async function handleBridgeRun( dequeueNextQueuedRun, fullInstructions, { model: resolvedModel, provider: resolvedProvider }, + currentInputTokens, + shouldPersistUserMessage && displayRole === 'user', data.model_groups, ) if (chunk.done) break @@ -417,61 +480,68 @@ async function refreshFinalContextUsage(args: { ) const finalMessageUsage = estimateUsageTokensFromMessages(finalHistory) const finalMessageTokens = finalMessageUsage.inputTokens + finalMessageUsage.outputTokens - if (getCachedBridgeContextOverhead(args.state) != null) { - const contextTokens = updateMessageContextTokenUsage( - args.sessionId, - args.state, - args.emit, - finalMessageTokens, - args.usage, - ) - bridgeLogger.info({ - sessionId: args.sessionId, - profile: args.profile, - model: args.model, - provider: args.provider, - messages: finalHistory.length, - fixedContextTokens: args.state.bridgeContext?.fixedContextTokens, - messageTokens: finalMessageTokens, - fullContextTokens: contextTokens, - }, '[chat-run-socket] final cached context estimate') - return contextTokens - } - const estimate = await args.bridge.contextEstimate( + await ensureBridgeFixedContext({ + sessionId: args.sessionId, + profile: args.profile, + model: args.model, + provider: args.provider, + instructions: args.instructions, + state: args.state, + bridge: args.bridge, + }) + const contextTokens = updateMessageContextTokenUsage( args.sessionId, - finalHistory, - args.instructions, - args.profile, - { model: args.model ?? undefined, provider: args.provider ?? undefined }, + args.state, + args.emit, + finalMessageTokens, + args.usage, ) - cacheBridgeContext(args.state, estimate) - const contextTokens = typeof estimate.token_count === 'number' && Number.isFinite(estimate.token_count) && estimate.token_count > 0 - ? Math.floor(estimate.token_count) - : undefined - if (contextTokens == null) return args.state.contextTokens - - updateContextTokenUsage(args.sessionId, args.state, args.emit, contextTokens, args.usage) bridgeLogger.info({ sessionId: args.sessionId, profile: args.profile, model: args.model, provider: args.provider, - messages: estimate.message_count, - toolCount: estimate.tool_count, - systemPromptChars: estimate.system_prompt_chars, - fullContextTokens: contextTokens, - }, '[chat-run-socket] final full context estimate') + messages: finalHistory.length, + fixedContextTokens: args.state.bridgeContext?.fixedContextTokens, + messageTokens: finalMessageTokens, + contextTokens, + }, '[chat-run-socket] final local context estimate') return contextTokens } catch (err) { bridgeLogger.warn({ err: err instanceof Error ? { message: err.message, name: err.name } : err, sessionId: args.sessionId, profile: args.profile, - }, '[chat-run-socket] final full context estimate failed') + }, '[chat-run-socket] final local context estimate failed') return args.state.contextTokens } } +async function estimateSnapshotAwareMessageTokens(args: { + sessionId: string + profile: string + model?: string | null + provider?: string | null + currentInputTokens?: number + currentInputIncludedInDb?: boolean +}): Promise<{ messageTokens: number; messages: number }> { + const dbHistory = await buildDbHistory(args.sessionId, { excludeLastUser: false }) + const snapshotHistory = await buildSnapshotAwareHistory( + args.sessionId, + args.profile, + dbHistory, + { model: args.model, provider: args.provider }, + ) + const usage = estimateUsageTokensFromMessages(snapshotHistory) + const extraInputTokens = args.currentInputIncludedInDb + ? 0 + : finiteToken(args.currentInputTokens) ?? 0 + return { + messageTokens: usage.inputTokens + usage.outputTokens + extraInputTokens, + messages: snapshotHistory.length, + } +} + async function applyBridgeChunkAsync( nsp: ReturnType, socket: Socket, @@ -486,6 +556,8 @@ async function applyBridgeChunkAsync( dequeueNextQueuedRun: (socket: Socket, sessionId: string, fallbackProfile?: string) => void, instructions: string, modelContext: { model?: string | null; provider?: string | null }, + currentInputTokens = 0, + currentInputIncludedInDb = true, modelGroups?: RunModelGroup[], ): Promise { if (state.activeRunMarker !== runMarker) { @@ -505,11 +577,19 @@ async function applyBridgeChunkAsync( if (evType === 'bridge.context.ready') { cacheBridgeContext(state, ev) const usage = await calcAndUpdateUsage(sessionId, state, emit) + const snapshotAware = await estimateSnapshotAwareMessageTokens({ + sessionId, + profile, + model: modelContext.model, + provider: modelContext.provider, + currentInputTokens, + currentInputIncludedInDb, + }) updateMessageContextTokenUsage( sessionId, state, emit, - usage.inputTokens + usage.outputTokens, + snapshotAware.messageTokens, usage, ) } else if (evType === 'tool.started') { @@ -646,17 +726,22 @@ async function applyBridgeChunkAsync( const bridgeHistory = await buildDbHistory(sessionId, { excludeLastUser: true }) const bridgeUsage = estimateUsageTokensFromMessages(bridgeHistory) const messageOnlyTokens = bridgeUsage.inputTokens + bridgeUsage.outputTokens - const tokenCount = typeof ev.approx_tokens === 'number' && Number.isFinite(ev.approx_tokens) && ev.approx_tokens > 0 - ? ev.approx_tokens - : messageOnlyTokens + const runInputTokens = typeof currentInputTokens === 'number' && Number.isFinite(currentInputTokens) && currentInputTokens > 0 + ? Math.floor(currentInputTokens) + : 0 + const runMessageTokens = messageOnlyTokens + runInputTokens + const tokenCount = contextTokensWithCachedOverhead(state, runMessageTokens) bridgeLogger.info({ sessionId, profile, bridgeMessages: ev.message_count, dbMessages: bridgeHistory.length, messageOnlyTokens, - fullContextTokens: tokenCount, - source: typeof ev.approx_tokens === 'number' ? 'bridge' : 'message-only-fallback', + currentInputTokens: runInputTokens, + fixedContextTokens: state.bridgeContext?.fixedContextTokens, + contextTokens: tokenCount, + bridgeApproxTokens: ev.approx_tokens, + source: 'local', }, '[chat-run-socket] bridge compression token estimate') const payload = { event: 'compression.started', @@ -674,7 +759,7 @@ async function applyBridgeChunkAsync( sessionId, profile, ev.messages as ChatMessage[], - typeof ev.approx_tokens === 'number' ? ev.approx_tokens : undefined, + tokenCount, ) state.bridgeCompressionResults = state.bridgeCompressionResults || {} state.bridgeCompressionResults[String(ev.request_id)] = compressed @@ -689,11 +774,16 @@ async function applyBridgeChunkAsync( const compressionResult = ev.request_id ? state.bridgeCompressionResults?.[String(ev.request_id)] : undefined - const bridgeAfterContextTokens = finiteToken(ev.result_approx_tokens) const messageAfterTokens = finiteToken(compressionResult?.afterTokens) - const afterContextTokens = messageAfterTokens != null && getCachedBridgeContextOverhead(state) != null - ? contextTokensWithCachedOverhead(state, messageAfterTokens) - : bridgeAfterContextTokens ?? messageAfterTokens + const runInputTokens = typeof currentInputTokens === 'number' && Number.isFinite(currentInputTokens) && currentInputTokens > 0 + ? Math.floor(currentInputTokens) + : 0 + const messageAfterTokensWithInput = messageAfterTokens != null + ? messageAfterTokens + runInputTokens + : undefined + const afterContextTokens = messageAfterTokensWithInput != null + ? contextTokensWithCachedOverhead(state, messageAfterTokensWithInput) + : undefined const payload = { event: 'compression.completed', run_id: chunk.run_id, @@ -703,7 +793,7 @@ async function applyBridgeChunkAsync( totalMessages: compressionResult?.beforeMessages ?? ev.message_count, resultMessages: compressionResult?.resultMessages ?? ev.result_messages, beforeTokens: compressionResult?.beforeTokens ?? ev.approx_tokens, - afterTokens: messageAfterTokens ?? bridgeAfterContextTokens, + afterTokens: messageAfterTokensWithInput, contextTokens: afterContextTokens, summaryTokens: compressionResult?.summaryTokens, verbatimCount: compressionResult?.verbatimCount, @@ -716,10 +806,8 @@ async function applyBridgeChunkAsync( replaceState(sessionMap, sessionId, 'compression.completed', payload) emit('compression.completed', payload) const usage = await calcAndUpdateUsage(sessionId, state, emit) - if (messageAfterTokens != null && getCachedBridgeContextOverhead(state) != null) { - updateMessageContextTokenUsage(sessionId, state, emit, messageAfterTokens, usage) - } else { - updateContextTokenUsage(sessionId, state, emit, afterContextTokens, usage) + if (messageAfterTokensWithInput != null) { + updateMessageContextTokenUsage(sessionId, state, emit, messageAfterTokensWithInput, usage) } } else if (evType === 'bridge.compression.failed') { const payload = { diff --git a/tests/client/chat-store-compression-state.test.ts b/tests/client/chat-store-compression-state.test.ts new file mode 100644 index 0000000..fc15ee7 --- /dev/null +++ b/tests/client/chat-store-compression-state.test.ts @@ -0,0 +1,96 @@ +// @vitest-environment jsdom +import { beforeEach, describe, expect, it, vi } from 'vitest' +import { createPinia, setActivePinia } from 'pinia' + +const chatApi = vi.hoisted(() => ({ + resumeSession: vi.fn(), + registerSessionHandlers: vi.fn(), + unregisterSessionHandlers: vi.fn(), +})) + +vi.mock('@/api/hermes/chat', () => ({ + startRunViaSocket: vi.fn(), + resumeSession: chatApi.resumeSession, + registerSessionHandlers: chatApi.registerSessionHandlers, + unregisterSessionHandlers: chatApi.unregisterSessionHandlers, + getChatRunSocket: vi.fn(() => ({ emit: vi.fn() })), + respondToolApproval: vi.fn(), + respondClarify: vi.fn(), + onPeerUserMessage: vi.fn(() => vi.fn()), + onSessionCommand: vi.fn(() => vi.fn()), +})) + +vi.mock('@/api/client', () => ({ + getActiveProfileName: () => 'default', +})) + +vi.mock('@/api/hermes/sessions', () => ({ + deleteSession: vi.fn(), + fetchSession: vi.fn(), + fetchSessions: vi.fn(), + setSessionModel: vi.fn(), +})) + +vi.mock('@/api/hermes/download', () => ({ + getDownloadUrl: (_path: string, name: string) => `/download/${name}`, +})) + +vi.mock('@/utils/completion-sound', () => ({ + primeCompletionSound: vi.fn(), + playCompletionSound: vi.fn(), +})) + +import { useChatStore, type Session } from '@/stores/hermes/chat' + +function makeSession(id: string): Session { + return { + id, + title: id, + messages: [], + createdAt: Date.now(), + updatedAt: Date.now(), + } +} + +describe('chat store compression state', () => { + beforeEach(() => { + vi.resetAllMocks() + setActivePinia(createPinia()) + chatApi.resumeSession.mockImplementation((sessionId: string, onResumed: (data: any) => void) => { + onResumed({ + session_id: sessionId, + messages: [], + isWorking: sessionId === 'session-1', + events: [], + }) + return {} as any + }) + }) + + it('does not show a background session compression indicator in the active session', async () => { + const store = useChatStore() + store.sessions = [makeSession('session-1'), makeSession('session-2')] + + await store.switchSession('session-1') + const handlers = chatApi.registerSessionHandlers.mock.calls.find(call => call[0] === 'session-1')?.[1] + expect(handlers).toBeTruthy() + + await store.switchSession('session-2') + handlers.onCompressionStarted({ + event: 'compression.started', + session_id: 'session-1', + message_count: 6, + token_count: 1234, + }) + + expect(store.activeSessionId).toBe('session-2') + expect(store.compressionState).toBeNull() + + await store.switchSession('session-1') + expect(store.compressionState).toEqual(expect.objectContaining({ + compressing: true, + messageCount: 6, + beforeTokens: 1234, + })) + }) +}) diff --git a/tests/server/agent-bridge-python-concurrency.test.ts b/tests/server/agent-bridge-python-concurrency.test.ts index fcffb52..f754ccf 100644 --- a/tests/server/agent-bridge-python-concurrency.test.ts +++ b/tests/server/agent-bridge-python-concurrency.test.ts @@ -292,9 +292,11 @@ except RuntimeError as exc: assert "already running" in str(exc) class FakeWorker: - def __init__(self, destroyed): + def __init__(self, destroyed, profile="default", key="default"): self.running = True self.destroyed = destroyed + self.profile = profile + self.key = key self.requests = [] self.stopped = False @@ -310,28 +312,41 @@ broker = bridge.BridgeBroker("ipc:///tmp/unused.sock") profile_worker = FakeWorker(2) broker._workers["default"] = profile_worker broker._run_profile["run-session-a"] = "default" +broker._run_worker_key["run-session-a"] = "default" broker._running_run_profile["run-session-a"] = "default" +broker._running_run_worker_key["run-session-a"] = "default" broker._session_profile["session-a"] = "default" +broker._session_worker_key["session-a"] = "default" broker._approval_profile["approval-a"] = "default" +broker._approval_worker_key["approval-a"] = "default" broker._compression_profile["compression-a"] = "default" +broker._compression_worker_key["compression-a"] = "default" destroy_profile_result = broker.handle({"action": "destroy_profile", "profile": "default"}) assert destroy_profile_result == {"profile": "default", "destroyed": 2} assert profile_worker.stopped assert "default" not in broker._workers assert broker._run_profile == {} +assert broker._run_worker_key == {} assert broker._running_run_profile == {} +assert broker._running_run_worker_key == {} assert broker._session_profile == {} +assert broker._session_worker_key == {} assert broker._approval_profile == {} +assert broker._approval_worker_key == {} assert broker._compression_profile == {} +assert broker._compression_worker_key == {} -worker_a = FakeWorker(1) -worker_b = FakeWorker(3) +worker_a = FakeWorker(1, "default", "a") +worker_b = FakeWorker(3, "work", "b") broker._workers["a"] = worker_a broker._workers["b"] = worker_b -broker._run_profile["run-a"] = "a" -broker._running_run_profile["run-a"] = "a" -broker._session_profile["session-b"] = "b" +broker._run_profile["run-a"] = "default" +broker._run_worker_key["run-a"] = "a" +broker._running_run_profile["run-a"] = "default" +broker._running_run_worker_key["run-a"] = "a" +broker._session_profile["session-b"] = "work" +broker._session_worker_key["session-b"] = "b" destroy_all_result = broker.handle({"action": "destroy_all"}) assert destroy_all_result == {"destroyed": 4} @@ -339,8 +354,11 @@ assert worker_a.stopped assert worker_b.stopped assert broker._workers == {} assert broker._run_profile == {} +assert broker._run_worker_key == {} assert broker._running_run_profile == {} +assert broker._running_run_worker_key == {} assert broker._session_profile == {} +assert broker._session_worker_key == {} `) }) @@ -372,6 +390,69 @@ assert resp["running_sessions_by_profile"] == {"default": 1} `) }) + it('routes worker-keyed broker requests without stopping the worker on session destroy', () => { + runPython(String.raw` +${harness} + +class RoutedWorker: + running = True + pid = 12345 + endpoint = "ipc:///tmp/worker.sock" + last_used_at = 12.5 + + def __init__(self, profile, key): + self.profile = profile + self.key = key + self.requests = [] + self.stopped = False + + def request(self, req): + self.requests.append(req) + action = req.get("action") + if action == "chat": + return {"ok": True, "run_id": "run-compress", "session_id": req["session_id"], "status": "running"} + if action == "get_output": + return {"ok": True, "run_id": req["run_id"], "session_id": "compress-temp", "status": "complete", "done": True} + if action == "destroy": + return {"ok": True, "session_id": req["session_id"], "destroyed": True} + raise AssertionError(f"unexpected action: {action}") + + def stop(self): + self.stopped = True + +broker = bridge.BridgeBroker("ipc:///tmp/unused.sock") +worker = RoutedWorker("default", "default:compression:session-a") +broker._workers[worker.key] = worker + +chat_resp = broker.handle({ + "action": "chat", + "session_id": "compress-temp", + "profile": "default", + "worker_key": worker.key, + "message": "summarize", +}) +assert chat_resp["run_id"] == "run-compress" +assert worker.requests[-1]["profile"] == "default" +assert "worker_key" not in worker.requests[-1] + +broker.handle({"action": "get_output", "run_id": "run-compress"}) +assert worker.requests[-1]["action"] == "get_output" + +destroy_resp = broker.handle({ + "action": "destroy", + "session_id": "compress-temp", + "profile": "default", + "worker_key": worker.key, +}) +assert destroy_resp["destroyed"] is True +assert worker.requests[-1]["action"] == "destroy" +assert not worker.stopped +assert worker.key in broker._workers +assert "compress-temp" not in broker._session_profile +assert "compress-temp" not in broker._session_worker_key +`) + }) + it('restores approval env and clears handlers when a run fails', () => { runPython(String.raw` ${harness} @@ -480,7 +561,7 @@ original_getpid = bridge.os.getpid try: bridge.subprocess.Popen = fake_popen bridge.os.getpid = lambda: 4242 - proc_worker = bridge.WorkerProcess("default", "ipc:///tmp/worker.sock", "/agent", "/home") + proc_worker = bridge.WorkerProcess("default:compression:session-a", "default", "ipc:///tmp/worker.sock", "/agent", "/home") proc_worker._pipe_stderr = lambda: None proc_worker._wait_ready = lambda: None proc_worker.start() diff --git a/tests/server/context-compressor.test.ts b/tests/server/context-compressor.test.ts index 8aa096c..4915cf6 100644 --- a/tests/server/context-compressor.test.ts +++ b/tests/server/context-compressor.test.ts @@ -153,6 +153,42 @@ describe('ChatContextCompressor', () => { expect(saveCompressionSnapshotMock).toHaveBeenCalledWith('s1', 'compressed summary', 6, 10) }) + it('routes summarization through the provided worker key and destroys only the temporary agent session', async () => { + const { ChatContextCompressor } = await import('../../packages/server/src/lib/context-compressor') + const compressor = new ChatContextCompressor({ + config: { headMessageCount: 0, tailMessageCount: 1, summaryBudget: 1000 }, + }) + const messages = [ + { role: 'user', content: 'old context' }, + { role: 'assistant', content: 'old response' }, + { role: 'user', content: 'tail' }, + ] + getCompressionSnapshotMock.mockReturnValue(null) + bridgeRequestMock.mockResolvedValue({ + status: 'completed', + result: { final_response: 'compressed summary' }, + }) + + await compressor.compress(messages, 'http://upstream', undefined, 's1', { + profile: 'default', + workerKey: 'default:compression:s1', + }) + + expect(bridgeRequestMock).toHaveBeenCalledWith(expect.objectContaining({ + action: 'chat', + profile: 'default', + worker_key: 'default:compression:s1', + wait: true, + }), expect.any(Object)) + const compressSessionId = bridgeRequestMock.mock.calls[0][0].session_id + expect(String(compressSessionId)).toMatch(/^compress_/) + expect(bridgeDestroyMock).toHaveBeenCalledWith( + compressSessionId, + 'default', + 'default:compression:s1', + ) + }) + it('does not pre-prune tool results before sending them to the summarizer', async () => { const { ChatContextCompressor } = await import('../../packages/server/src/lib/context-compressor') const compressor = new ChatContextCompressor({ diff --git a/tests/server/run-chat-bridge-final-context.test.ts b/tests/server/run-chat-bridge-final-context.test.ts index d68b79c..8ce0155 100644 --- a/tests/server/run-chat-bridge-final-context.test.ts +++ b/tests/server/run-chat-bridge-final-context.test.ts @@ -127,8 +127,18 @@ describe('bridge run final context usage', () => { buildSnapshotAwareHistoryMock.mockImplementation(async (_sessionId: string, _profile: string, history: any[]) => history) calcAndUpdateUsageMock.mockResolvedValue({ inputTokens: 11, outputTokens: 7 }) estimateUsageTokensFromMessagesMock.mockReturnValue({ inputTokens: 11, outputTokens: 7 }) - getCachedBridgeContextOverheadMock.mockReturnValue(undefined) - contextTokensWithCachedOverheadMock.mockImplementation((_state: any, messageTokens: number) => messageTokens) + getCachedBridgeContextOverheadMock.mockImplementation((state: any) => { + const fixed = state?.bridgeContext?.fixedContextTokens + return typeof fixed === 'number' ? fixed : undefined + }) + contextTokensWithCachedOverheadMock.mockImplementation((state: any, messageTokens: number) => { + const fixed = state?.bridgeContext?.fixedContextTokens + return typeof fixed === 'number' ? fixed + messageTokens : messageTokens + }) + updateMessageContextTokenUsageMock.mockImplementation((sid: string, state: any, emit: any, messageTokens: number, usage?: { inputTokens: number; outputTokens: number }) => { + const contextTokens = contextTokensWithCachedOverheadMock(state, messageTokens) + return updateContextTokenUsageMock(sid, state, emit, contextTokens, usage) + }) }) it('refreshes full context tokens when a bridge run completes', async () => { @@ -141,6 +151,7 @@ describe('bridge run final context usage', () => { chat: vi.fn().mockResolvedValue({ run_id: 'run-1', status: 'started' }), contextEstimate: vi.fn().mockResolvedValue({ token_count: 12345, + fixed_context_tokens: 12327, message_count: 2, tool_count: 4, system_prompt_chars: 13, @@ -165,10 +176,7 @@ describe('bridge run final context usage', () => { expect(bridge.contextEstimate).toHaveBeenCalledWith( 'session-1', - [ - { role: 'user', content: 'hello' }, - { role: 'assistant', content: 'done' }, - ], + [], expect.stringContaining('[Current Hermes profile: default]'), 'default', { model: 'gpt-test', provider: 'openai' }, @@ -326,14 +334,22 @@ describe('bridge run final context usage', () => { const nsp = makeNamespace(emit) const socket = makeSocket() const state = makeState() - state.bridgeContext = { fixedContextTokens: 20_000 } const sessionMap = new Map([['session-1', state]]) - getCachedBridgeContextOverheadMock.mockReturnValue(20_000) - updateMessageContextTokenUsageMock.mockImplementation((sid: string, targetState: any, targetEmit: any, messageTokens: number, usage?: { inputTokens: number; outputTokens: number }) => updateContextTokenUsageMock(sid, targetState, targetEmit, 20_000 + messageTokens, usage)) const bridge = { chat: vi.fn().mockResolvedValue({ run_id: 'run-1', status: 'started' }), contextEstimate: vi.fn(), streamOutput: vi.fn(async function* () { + yield { + run_id: 'run-1', + done: false, + status: 'running', + events: [{ + event: 'bridge.context.ready', + fixed_context_tokens: 20_000, + system_prompt_tokens: 3_000, + tool_tokens: 17_000, + }], + } yield { run_id: 'run-1', done: true, status: 'completed', output: 'done' } }), } as any @@ -365,6 +381,80 @@ describe('bridge run final context usage', () => { })) }) + it('keeps bridge context ready updates on the snapshot-aware token baseline', async () => { + const emit = vi.fn() + const nsp = makeNamespace(emit) + const socket = makeSocket() + const state = makeState() + const sessionMap = new Map([['session-1', state]]) + calcAndUpdateUsageMock.mockResolvedValue({ inputTokens: 28_000, outputTokens: 0 }) + buildDbHistoryMock.mockResolvedValue([ + { role: 'user', content: 'very large old context' }, + { role: 'assistant', content: 'large old response' }, + { role: 'user', content: 'hello' }, + ]) + buildSnapshotAwareHistoryMock.mockResolvedValue([ + { role: 'user', content: '[Previous context summary]\n\nsmall summary' }, + { role: 'user', content: 'hello' }, + ]) + estimateUsageTokensFromMessagesMock.mockImplementation((messages: any[]) => { + if (messages?.[0]?.content?.includes('small summary')) { + return { inputTokens: 9_000, outputTokens: 0 } + } + return { inputTokens: 28_000, outputTokens: 0 } + }) + const bridge = { + chat: vi.fn().mockResolvedValue({ run_id: 'run-1', status: 'started' }), + contextEstimate: vi.fn(), + streamOutput: vi.fn(async function* () { + yield { + run_id: 'run-1', + done: false, + status: 'running', + events: [{ + event: 'bridge.context.ready', + fixed_context_tokens: 10_000, + system_prompt_tokens: 2_000, + tool_tokens: 8_000, + }], + } + yield { run_id: 'run-1', done: true, status: 'completed', output: 'done' } + }), + } as any + + const { handleBridgeRun } = await import('../../packages/server/src/services/hermes/run-chat/handle-bridge-run') + await handleBridgeRun( + nsp, + socket, + { input: 'hello', session_id: 'session-1' }, + 'default', + sessionMap, + bridge, + false, + vi.fn(), + vi.fn(), + ) + + expect(updateMessageContextTokenUsageMock).toHaveBeenCalledWith( + 'session-1', + state, + expect.any(Function), + 9_000, + { inputTokens: 28_000, outputTokens: 0 }, + ) + expect(updateMessageContextTokenUsageMock).not.toHaveBeenCalledWith( + 'session-1', + state, + expect.any(Function), + 28_000, + { inputTokens: 28_000, outputTokens: 0 }, + ) + expect(state.contextTokens).toBe(19_000) + expect(emit).toHaveBeenCalledWith('run.completed', expect.objectContaining({ + contextTokens: 19_000, + })) + }) + it('persists pending tool marker text before a bridge run completes', async () => { const emit = vi.fn() const nsp = makeNamespace(emit) @@ -502,6 +592,7 @@ describe('bridge run final context usage', () => { chat: vi.fn().mockRejectedValue(new Error('bridge timeout')), contextEstimate: vi.fn().mockResolvedValue({ token_count: 54321, + fixed_context_tokens: 54303, message_count: 1, tool_count: 4, system_prompt_chars: 13, diff --git a/tests/server/run-chat-compression.test.ts b/tests/server/run-chat-compression.test.ts index 256c73f..5436577 100644 --- a/tests/server/run-chat-compression.test.ts +++ b/tests/server/run-chat-compression.test.ts @@ -175,7 +175,7 @@ describe('run chat compression trigger', () => { ) }) - it('uses full context estimates for compression threshold decisions', async () => { + it('uses local context estimates for compression threshold decisions', async () => { const messages = Array.from({ length: 10 }, (_, index) => ({ id: index + 1, session_id: 'session-1', @@ -191,7 +191,7 @@ describe('run chat compression trigger', () => { getSessionDetailMock.mockReturnValue({ messages }) calcAndUpdateUsageMock.mockResolvedValue({ inputTokens: 1_000, outputTokens: 0 }) compressorCompressMock.mockResolvedValue({ - messages: [{ role: 'user', content: 'compressed by full context estimate' }], + messages: [{ role: 'user', content: 'compressed by local context estimate' }], meta: { compressed: true, llmCompressed: true, @@ -215,7 +215,7 @@ describe('run chat compression trigger', () => { vi.fn(async () => 120_000), ) - expect(history).toEqual([{ role: 'user', content: 'compressed by full context estimate' }]) + expect(history).toEqual([{ role: 'user', content: 'compressed by local context estimate' }]) expect(compressorCompressMock).toHaveBeenCalledTimes(1) expect(updateMessageContextTokenUsageMock).toHaveBeenCalledWith( 'session-1', @@ -226,7 +226,7 @@ describe('run chat compression trigger', () => { ) }) - it('emits full context token usage when the full estimate is under threshold', async () => { + it('emits local context token usage when the local estimate is under threshold', async () => { const messages = Array.from({ length: 10 }, (_, index) => ({ id: index + 1, session_id: 'session-1', @@ -257,7 +257,10 @@ describe('run chat compression trigger', () => { ) expect(history).toHaveLength(9) - expect(contextTokenEstimator).toHaveBeenCalledWith(expect.arrayContaining([{ role: 'user', content: 'message 0' }])) + expect(contextTokenEstimator).toHaveBeenCalledWith( + expect.arrayContaining([{ role: 'user', content: 'message 0' }]), + 1_900, + ) expect(emit).toHaveBeenCalledWith('usage.updated', expect.objectContaining({ event: 'usage.updated', session_id: 'session-1', @@ -268,6 +271,108 @@ describe('run chat compression trigger', () => { expect(compressorCompressMock).not.toHaveBeenCalled() }) + it('includes current input tokens when estimating snapshot-aware context', async () => { + const messages = Array.from({ length: 10 }, (_, index) => ({ + id: index + 1, + session_id: 'session-1', + role: index === 9 ? 'user' : index % 2 === 0 ? 'user' : 'assistant', + content: `message ${index}`, + timestamp: index + 1, + tool_call_id: null, + tool_calls: null, + tool_name: null, + finish_reason: null, + reasoning_content: null, + })) + getSessionDetailMock.mockReturnValue({ messages }) + getCompressionSnapshotMock.mockReturnValue({ + summary: 'previous summary', + lastMessageIndex: 4, + messageCountAtTime: 5, + }) + calcAndUpdateUsageMock.mockResolvedValue({ inputTokens: 10, outputTokens: 0 }) + estimateUsageTokensFromMessagesMock.mockReturnValue({ inputTokens: 1_000, outputTokens: 0 }) + const emit = vi.fn() + const contextTokenEstimator = vi.fn(async (_messages, messageTokens: number) => 20_000 + messageTokens) + + const { buildCompressedHistory } = await import('../../packages/server/src/services/hermes/run-chat/compression') + await buildCompressedHistory( + 'session-1', + 'default', + 'http://upstream', + undefined, + emit, + new Map(), + {}, + contextTokenEstimator, + 700, + ) + + expect(contextTokenEstimator).toHaveBeenCalledWith(expect.any(Array), 1_700) + expect(emit).toHaveBeenCalledWith('usage.updated', expect.objectContaining({ + contextTokens: 21_700, + })) + expect(compressorCompressMock).not.toHaveBeenCalled() + }) + + it('keeps current input tokens in the compression completed context total', async () => { + const messages = Array.from({ length: 10 }, (_, index) => ({ + id: index + 1, + session_id: 'session-1', + role: index === 9 ? 'user' : index % 2 === 0 ? 'user' : 'assistant', + content: `message ${index}`, + timestamp: index + 1, + tool_call_id: null, + tool_calls: null, + tool_name: null, + finish_reason: null, + reasoning_content: null, + })) + getSessionDetailMock.mockReturnValue({ messages }) + calcAndUpdateUsageMock.mockResolvedValue({ inputTokens: 100, outputTokens: 0 }) + estimateUsageTokensFromMessagesMock.mockImplementation((items: any[]) => { + if (items?.[0]?.content === 'compressed result') return { inputTokens: 1_000, outputTokens: 0 } + return { inputTokens: 100, outputTokens: 0 } + }) + compressorCompressMock.mockResolvedValue({ + messages: [{ role: 'user', content: 'compressed result' }], + meta: { + compressed: true, + llmCompressed: true, + totalMessages: 9, + summaryTokenEstimate: 1, + verbatimCount: 0, + compressedStartIndex: 0, + }, + }) + const emit = vi.fn() + + const { buildCompressedHistory } = await import('../../packages/server/src/services/hermes/run-chat/compression') + await buildCompressedHistory( + 'session-1', + 'default', + 'http://upstream', + undefined, + emit, + new Map(), + {}, + vi.fn(async () => 120_000), + 700, + ) + + expect(updateMessageContextTokenUsageMock).toHaveBeenCalledWith( + 'session-1', + expect.any(Object), + emit, + 1_700, + { inputTokens: 100, outputTokens: 0 }, + ) + expect(emit).toHaveBeenCalledWith('compression.completed', expect.objectContaining({ + afterTokens: 1_700, + contextTokens: 1_700, + })) + }) + it('throws when fixed prompt and tool schemas exceed threshold before any history exists', async () => { getSessionDetailMock.mockReturnValue({ messages: [] }) const emit = vi.fn() diff --git a/tests/server/run-chat-load-state.test.ts b/tests/server/run-chat-load-state.test.ts new file mode 100644 index 0000000..0254783 --- /dev/null +++ b/tests/server/run-chat-load-state.test.ts @@ -0,0 +1,129 @@ +import { beforeEach, describe, expect, it, vi } from 'vitest' + +const getSessionMock = vi.fn() +const getSessionDetailPaginatedMock = vi.fn() +const getCompressionSnapshotMock = vi.fn() +const estimateUsageTokensFromMessagesMock = vi.fn() +const buildDbHistoryMock = vi.fn() +const buildSnapshotAwareHistoryMock = vi.fn() + +vi.mock('../../packages/server/src/db/hermes/session-store', () => ({ + getSession: getSessionMock, + createSession: vi.fn(), + addMessage: vi.fn(), + updateSessionStats: vi.fn(), + getSessionDetailPaginated: getSessionDetailPaginatedMock, +})) + +vi.mock('../../packages/server/src/db/hermes/usage-store', () => ({ + updateUsage: vi.fn(), +})) + +vi.mock('../../packages/server/src/db/hermes/compression-snapshot', () => ({ + getCompressionSnapshot: getCompressionSnapshotMock, +})) + +vi.mock('../../packages/server/src/lib/context-compressor', () => ({ + SUMMARY_PREFIX: '[Previous context summary]', + countTokens: vi.fn(() => 0), +})) + +vi.mock('../../packages/server/src/services/logger', () => ({ + logger: { info: vi.fn(), warn: vi.fn(), error: vi.fn(), debug: vi.fn() }, +})) + +vi.mock('../../packages/server/src/services/hermes/run-chat/compression', () => ({ + buildCompressedHistory: vi.fn(), + buildDbHistory: buildDbHistoryMock, + buildSnapshotAwareHistory: buildSnapshotAwareHistoryMock, + getOrCreateSession: vi.fn(), +})) + +vi.mock('../../packages/server/src/services/hermes/run-chat/usage', () => ({ + calcAndUpdateUsage: vi.fn(), + estimateUsageTokensFromMessages: estimateUsageTokensFromMessagesMock, +})) + +vi.mock('../../packages/server/src/services/hermes/run-chat/message-format', () => ({ + convertHistoryFormat: vi.fn((messages: any[]) => messages), + handleMessage: vi.fn((messages: any[]) => messages), +})) + +vi.mock('../../packages/server/src/services/hermes/run-chat/content-blocks', () => ({ + contentBlocksToString: vi.fn((value: any) => String(value || '')), + extractTextForPreview: vi.fn((value: any) => String(value || '')), + isContentBlockArray: vi.fn(() => false), + convertContentBlocks: vi.fn(), +})) + +vi.mock('../../packages/server/src/lib/llm-prompt', () => ({ + getSystemPrompt: vi.fn(() => 'system prompt'), +})) + +vi.mock('../../packages/server/src/services/hermes/run-chat/sse-utils', () => ({ + readSseFrames: vi.fn(), +})) + +vi.mock('../../packages/server/src/services/hermes/run-chat/response-utils', () => ({ + extractResponseText: vi.fn(), +})) + +vi.mock('../../packages/server/src/services/hermes/run-chat/response-stream', () => ({ + applyResponseStreamEvent: vi.fn(), + flushResponseRunToDb: vi.fn(), +})) + +describe('loadSessionStateFromDb', () => { + beforeEach(() => { + vi.clearAllMocks() + getSessionMock.mockReturnValue({ + id: 'session-1', + profile: 'default', + model: 'gpt-test', + provider: 'openai', + }) + getSessionDetailPaginatedMock.mockReturnValue({ + messages: [ + { role: 'user', content: 'old large context' }, + { role: 'assistant', content: 'old large answer' }, + { role: 'user', content: 'new tail' }, + ], + }) + getCompressionSnapshotMock.mockReturnValue({ + summary: 'small summary', + lastMessageIndex: 0, + messageCountAtTime: 1, + }) + buildDbHistoryMock.mockResolvedValue([ + { role: 'user', content: 'old large context' }, + { role: 'assistant', content: 'old large answer' }, + { role: 'user', content: 'new tail' }, + ]) + buildSnapshotAwareHistoryMock.mockResolvedValue([ + { role: 'user', content: '[Previous context summary]\n\nsmall summary' }, + { role: 'user', content: 'new tail' }, + ]) + estimateUsageTokensFromMessagesMock.mockImplementation((messages: any[]) => { + if (messages?.[0]?.content?.includes('small summary')) { + return { inputTokens: 9_000, outputTokens: 0 } + } + return { inputTokens: 28_000, outputTokens: 0 } + }) + }) + + it('hydrates contextTokens from the same snapshot-aware history used for bridge runs', async () => { + const { loadSessionStateFromDb } = await import('../../packages/server/src/services/hermes/run-chat/handle-api-run') + + const state = await loadSessionStateFromDb('session-1', new Map()) + + expect(buildSnapshotAwareHistoryMock).toHaveBeenCalledWith( + 'session-1', + 'default', + expect.any(Array), + { model: 'gpt-test', provider: 'openai' }, + ) + expect(state.inputTokens).toBe(28_000) + expect(state.outputTokens).toBe(0) + expect(state.contextTokens).toBe(9_000) + }) +})