fix context token resume (#1039)
This commit is contained in:
@@ -85,6 +85,15 @@ export interface Session {
|
||||
workspace?: string | null
|
||||
}
|
||||
|
||||
/**
 * Snapshot of a context-compression pass for a single chat session.
 *
 * Instances are built by the `compression.started` / `compression.completed`
 * event handlers and stored per session id.
 */
interface CompressionState {
  /** `true` once `compression.started` is seen, `false` after `compression.completed`. */
  compressing: boolean
  /** Message count reported by the compression event (`0` when the event omits it). */
  messageCount: number
  /** Token count reported for the context before compression (`0` when omitted). */
  beforeTokens: number
  /** Token count after compression — presumably `0` while still compressing; TODO confirm. */
  afterTokens: number
  /** Outcome flag — presumably `null` while the result is not yet known; TODO confirm against the handlers. */
  compressed: boolean | null
  /** Optional error text. NOTE(review): where this is populated is not visible in this view. */
  error?: string
}
|
||||
|
||||
/**
 * Build a short client-side identifier.
 *
 * The result is the current epoch milliseconds in base 36 followed by up to
 * six base-36 characters taken from `Math.random()`. The random part can be
 * shorter than six characters (or empty) for values with a short base-36
 * expansion, and it is not cryptographically secure.
 *
 * @returns A lowercase alphanumeric id string.
 */
function uid(): string {
  const timePart = Date.now().toString(36)
  // slice(2, 8) skips the leading "0." of the fractional base-36 rendering.
  const randomPart = Math.random().toString(36).slice(2, 8)
  return timePart + randomPart
}
|
||||
@@ -272,6 +281,8 @@ function mapHermesSession(s: SessionSummary): Session {
|
||||
model: s.model,
|
||||
provider: s.provider || (s as any).billing_provider || '',
|
||||
messageCount: s.message_count,
|
||||
inputTokens: s.input_tokens,
|
||||
outputTokens: s.output_tokens,
|
||||
endedAt: s.ended_at != null ? Math.round(s.ended_at * 1000) : null,
|
||||
lastActiveAt: s.last_active != null ? Math.round(s.last_active * 1000) : undefined,
|
||||
workspace: s.workspace || null,
|
||||
@@ -405,18 +416,20 @@ export const useChatStore = defineStore('chat', () => {
|
||||
const isLoadingMessages = ref(false)
|
||||
const isRunActive = computed(() => isStreaming.value)
|
||||
|
||||
// Compression state
|
||||
const compressionState = ref<{
|
||||
compressing: boolean
|
||||
messageCount: number
|
||||
beforeTokens: number
|
||||
afterTokens: number
|
||||
compressed: boolean | null
|
||||
error?: string
|
||||
} | null>(null)
|
||||
// Compression state is scoped per session because sockets can stay joined to
|
||||
// background sessions while another chat is active.
|
||||
const compressionStates = ref<Map<string, CompressionState>>(new Map())
|
||||
const compressionState = computed<CompressionState | null>(() => {
|
||||
const sid = activeSessionId.value
|
||||
return sid ? compressionStates.value.get(sid) || null : null
|
||||
})
|
||||
|
||||
/**
 * Store or clear the compression state of one session.
 *
 * @param sessionId Session whose entry is updated; a falsy id makes the call a no-op.
 * @param state New state for the session, or `null` to remove its entry.
 */
function setCompressionState(sessionId: string | null | undefined, state: CompressionState | null) {
  if (!sessionId) return
  // Work on a copy and swap it in, rather than mutating the stored Map in place.
  const next = new Map(compressionStates.value)
  if (state) next.set(sessionId, state)
  else next.delete(sessionId)
  compressionStates.value = next
}
|
||||
|
||||
const abortState = ref<{
|
||||
@@ -438,11 +451,12 @@ export const useChatStore = defineStore('chat', () => {
|
||||
}
|
||||
|
||||
/**
 * Reset everything tied to the currently active session: the active id and
 * session object, the focused message, the abort state, that session's
 * compression state, and the persisted item stored under `storageKey()`.
 */
function clearActiveSession() {
  // Capture the id first: it is nulled below but still needed to clear the
  // per-session compression entry.
  const sid = activeSessionId.value
  activeSessionId.value = null
  activeSession.value = null
  focusMessageId.value = null
  setAbortState(null)
  setCompressionState(sid, null)
  removeItem(storageKey())
}
|
||||
|
||||
@@ -453,10 +467,14 @@ export const useChatStore = defineStore('chat', () => {
|
||||
const fresh = list.map(mapHermesSession)
|
||||
// Preserve already-loaded messages for sessions that are still present,
|
||||
// so we don't blow away the active session's messages on refresh.
|
||||
const msgsByIdBefore = new Map(sessions.value.map(s => [s.id, s.messages]))
|
||||
const runtimeByIdBefore = new Map(sessions.value.map(s => [s.id, {
|
||||
messages: s.messages,
|
||||
contextTokens: s.contextTokens,
|
||||
}]))
|
||||
for (const s of fresh) {
|
||||
const prev = msgsByIdBefore.get(s.id)
|
||||
if (prev && prev.length) s.messages = prev
|
||||
const prev = runtimeByIdBefore.get(s.id)
|
||||
if (prev?.messages?.length) s.messages = prev.messages
|
||||
if (prev?.contextTokens != null) s.contextTokens = prev.contextTokens
|
||||
}
|
||||
sessions.value = fresh
|
||||
|
||||
@@ -594,6 +612,7 @@ export const useChatStore = defineStore('chat', () => {
|
||||
} else if (!data.isWorking) {
|
||||
setAbortState(null)
|
||||
}
|
||||
if (!data.isWorking) setCompressionState(sessionId, null)
|
||||
if (data.inputTokens != null) target.inputTokens = data.inputTokens
|
||||
if (data.outputTokens != null) target.outputTokens = data.outputTokens
|
||||
if ((data as any).contextTokens != null) target.contextTokens = (data as any).contextTokens
|
||||
@@ -613,7 +632,7 @@ export const useChatStore = defineStore('chat', () => {
|
||||
for (const evt of data.events) {
|
||||
const e = evt.data as any
|
||||
if (e.event === 'compression.started') {
|
||||
setCompressionState({
|
||||
setCompressionState(sessionId, {
|
||||
compressing: true,
|
||||
messageCount: e.message_count || 0,
|
||||
beforeTokens: e.token_count || 0,
|
||||
@@ -622,7 +641,7 @@ export const useChatStore = defineStore('chat', () => {
|
||||
})
|
||||
} else if (e.event === 'compression.completed') {
|
||||
const afterTokens = e.contextTokens || e.afterTokens || 0
|
||||
setCompressionState({
|
||||
setCompressionState(sessionId, {
|
||||
compressing: false,
|
||||
messageCount: e.totalMessages || 0,
|
||||
beforeTokens: e.beforeTokens || 0,
|
||||
@@ -1385,6 +1404,7 @@ export const useChatStore = defineStore('chat', () => {
|
||||
} else if (!data.isWorking) {
|
||||
setAbortState(null)
|
||||
}
|
||||
if (!data.isWorking) setCompressionState(sid, null)
|
||||
|
||||
if (data.inputTokens != null) target.inputTokens = data.inputTokens
|
||||
if (data.outputTokens != null) target.outputTokens = data.outputTokens
|
||||
@@ -1407,7 +1427,7 @@ export const useChatStore = defineStore('chat', () => {
|
||||
const e = evt.data as RunEvent
|
||||
switch (e.event) {
|
||||
case 'compression.started':
|
||||
setCompressionState({
|
||||
setCompressionState(sid, {
|
||||
compressing: true,
|
||||
messageCount: (e as any).message_count || 0,
|
||||
beforeTokens: (e as any).token_count || 0,
|
||||
@@ -1417,7 +1437,7 @@ export const useChatStore = defineStore('chat', () => {
|
||||
break
|
||||
case 'compression.completed': {
|
||||
const afterTokens = (e as any).contextTokens || (e as any).afterTokens || 0
|
||||
setCompressionState({
|
||||
setCompressionState(sid, {
|
||||
compressing: false,
|
||||
messageCount: (e as any).totalMessages || 0,
|
||||
beforeTokens: (e as any).beforeTokens || 0,
|
||||
@@ -1474,7 +1494,7 @@ export const useChatStore = defineStore('chat', () => {
|
||||
case 'run.started':
|
||||
clearAgentEventMessages(sid)
|
||||
setAbortState(null)
|
||||
setCompressionState(null)
|
||||
setCompressionState(sid, null)
|
||||
runProducedAssistantText = false
|
||||
runHadToolActivity = false
|
||||
closeStreamingAssistant()
|
||||
@@ -1502,7 +1522,7 @@ export const useChatStore = defineStore('chat', () => {
|
||||
}
|
||||
|
||||
case 'compression.started': {
|
||||
setCompressionState({
|
||||
setCompressionState(sid, {
|
||||
compressing: true,
|
||||
messageCount: (evt as any).message_count || 0,
|
||||
beforeTokens: (evt as any).token_count || 0,
|
||||
@@ -1514,7 +1534,7 @@ export const useChatStore = defineStore('chat', () => {
|
||||
|
||||
case 'compression.completed': {
|
||||
const afterTokens = (evt as any).contextTokens || (evt as any).afterTokens || 0
|
||||
setCompressionState({
|
||||
setCompressionState(sid, {
|
||||
compressing: false,
|
||||
messageCount: (evt as any).totalMessages || 0,
|
||||
beforeTokens: (evt as any).beforeTokens || 0,
|
||||
@@ -1528,8 +1548,9 @@ export const useChatStore = defineStore('chat', () => {
|
||||
}
|
||||
// Auto-clear after 5s
|
||||
setTimeout(() => {
|
||||
if (compressionState.value && !compressionState.value.compressing) {
|
||||
setCompressionState(null)
|
||||
const state = compressionStates.value.get(sid)
|
||||
if (state && !state.compressing) {
|
||||
setCompressionState(sid, null)
|
||||
}
|
||||
}, 5000)
|
||||
break
|
||||
@@ -1966,7 +1987,7 @@ export const useChatStore = defineStore('chat', () => {
|
||||
case 'run.started':
|
||||
clearAgentEventMessages(sid)
|
||||
setAbortState(null)
|
||||
setCompressionState(null)
|
||||
setCompressionState(sid, null)
|
||||
runProducedAssistantText = false
|
||||
runHadToolActivity = false
|
||||
closeStreamingAssistant()
|
||||
@@ -1979,7 +2000,7 @@ export const useChatStore = defineStore('chat', () => {
|
||||
break
|
||||
|
||||
case 'compression.started': {
|
||||
setCompressionState({
|
||||
setCompressionState(sid, {
|
||||
compressing: true,
|
||||
messageCount: (evt as any).message_count || 0,
|
||||
beforeTokens: (evt as any).token_count || 0,
|
||||
@@ -1991,7 +2012,7 @@ export const useChatStore = defineStore('chat', () => {
|
||||
|
||||
case 'compression.completed': {
|
||||
const afterTokens = (evt as any).contextTokens || (evt as any).afterTokens || 0
|
||||
setCompressionState({
|
||||
setCompressionState(sid, {
|
||||
compressing: false,
|
||||
messageCount: (evt as any).totalMessages || 0,
|
||||
beforeTokens: (evt as any).beforeTokens || 0,
|
||||
@@ -2004,8 +2025,9 @@ export const useChatStore = defineStore('chat', () => {
|
||||
if (target) target.contextTokens = (evt as any).contextTokens
|
||||
}
|
||||
setTimeout(() => {
|
||||
if (compressionState.value && !compressionState.value.compressing) {
|
||||
setCompressionState(null)
|
||||
const state = compressionStates.value.get(sid)
|
||||
if (state && !state.compressing) {
|
||||
setCompressionState(sid, null)
|
||||
}
|
||||
}, 5000)
|
||||
break
|
||||
@@ -2461,6 +2483,7 @@ export const useChatStore = defineStore('chat', () => {
|
||||
} else if (!data.isWorking) {
|
||||
setAbortState(null)
|
||||
}
|
||||
if (!data.isWorking) setCompressionState(sid, null)
|
||||
if (data.messages?.length && activeSession.value) {
|
||||
activeSession.value.messages = mapHermesMessages(data.messages as any[])
|
||||
}
|
||||
|
||||
@@ -79,6 +79,7 @@ export interface SummarizerOptions {
|
||||
profile?: string
|
||||
model?: string | null
|
||||
provider?: string | null
|
||||
workerKey?: string
|
||||
}
|
||||
|
||||
// ─── Token counting ─────────────────────────────────────
|
||||
@@ -454,6 +455,7 @@ export async function callSummarizer(
|
||||
|
||||
const bridge = new AgentBridgeClient({ timeoutMs: timeoutMs + 15_000 })
|
||||
const sessionId = `compress_${Date.now().toString(36)}_${randomUUID().replace(/-/g, '').slice(0, 12)}`
|
||||
const workerKey = options.workerKey || `${profile}:compression:${sessionId}`
|
||||
|
||||
try {
|
||||
const result = await bridge.request<AgentBridgeRunResult>({
|
||||
@@ -462,6 +464,7 @@ export async function callSummarizer(
|
||||
message: prompt,
|
||||
conversation_history: convHistory,
|
||||
profile,
|
||||
worker_key: workerKey,
|
||||
source: 'api_server',
|
||||
wait: true,
|
||||
timeout: Math.ceil(timeoutMs / 1000),
|
||||
@@ -482,7 +485,7 @@ export async function callSummarizer(
|
||||
if (!output) throw new Error('Empty summarization response')
|
||||
return output
|
||||
} finally {
|
||||
await bridge.destroy(sessionId, profile).catch(() => undefined)
|
||||
await bridge.destroy(sessionId, profile, workerKey).catch(() => undefined)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -166,7 +166,7 @@ export class AgentBridgeClient {
|
||||
private summarizePayload(payload: Record<string, unknown>): Record<string, unknown> {
|
||||
const action = String(payload.action || '')
|
||||
const summary: Record<string, unknown> = { action }
|
||||
for (const key of ['session_id', 'run_id', 'request_id', 'approval_id', 'profile']) {
|
||||
for (const key of ['session_id', 'run_id', 'request_id', 'approval_id', 'profile', 'worker_key']) {
|
||||
if (payload[key] != null) summary[key] = payload[key]
|
||||
}
|
||||
if (Array.isArray(payload.conversation_history)) summary.conversation_history_count = payload.conversation_history.length
|
||||
@@ -569,11 +569,12 @@ export class AgentBridgeClient {
|
||||
})
|
||||
}
|
||||
|
||||
/**
 * Send a `destroy` action for a session through the bridge.
 *
 * @param sessionId Session to destroy; sent as `session_id`.
 * @param profile Optional profile; included in the payload only when truthy.
 * @param workerKey Optional worker key; sent as `worker_key` only when truthy.
 * @returns The bridge response for the request.
 */
destroy(sessionId: string, profile?: string, workerKey?: string): Promise<AgentBridgeResponse> {
  return this.request({
    action: 'destroy',
    session_id: sessionId,
    ...(profile ? { profile } : {}),
    ...(workerKey ? { worker_key: workerKey } : {}),
  })
}
|
||||
|
||||
|
||||
@@ -2200,7 +2200,8 @@ class WorkerProcess:
|
||||
STARTUP_TIMEOUT_SECONDS = 120
|
||||
REQUEST_TIMEOUT_SECONDS = 120
|
||||
|
||||
def __init__(self, profile: str, endpoint: str, agent_root: str | None, hermes_home: str | None) -> None:
|
||||
def __init__(self, key: str, profile: str, endpoint: str, agent_root: str | None, hermes_home: str | None) -> None:
|
||||
self.key = key or profile or "default"
|
||||
self.profile = profile or "default"
|
||||
self.endpoint = endpoint
|
||||
self.agent_root = agent_root
|
||||
@@ -2263,14 +2264,14 @@ class WorkerProcess:
|
||||
for line in proc.stderr:
|
||||
text = line.rstrip()
|
||||
if text:
|
||||
print(f"[hermes-bridge-worker:{self.profile}] {text}", file=sys.stderr, flush=True)
|
||||
print(f"[hermes-bridge-worker:{self.key}] {text}", file=sys.stderr, flush=True)
|
||||
|
||||
threading.Thread(target=run, daemon=True, name=f"hermes-bridge-worker-stderr-{self.profile}").start()
|
||||
threading.Thread(target=run, daemon=True, name=f"hermes-bridge-worker-stderr-{self.key}").start()
|
||||
|
||||
def _wait_ready(self) -> None:
|
||||
proc = self.process
|
||||
if proc is None or proc.stdout is None:
|
||||
raise RuntimeError(f"profile worker {self.profile} did not start")
|
||||
raise RuntimeError(f"profile worker {self.key} did not start")
|
||||
lines: queue.Queue[str | None] = queue.Queue()
|
||||
ready_event = threading.Event()
|
||||
|
||||
@@ -2281,17 +2282,17 @@ class WorkerProcess:
|
||||
if ready_event.is_set():
|
||||
text = line.rstrip()
|
||||
if text:
|
||||
print(f"[hermes-bridge-worker:{self.profile}] {text}", file=sys.stderr, flush=True)
|
||||
print(f"[hermes-bridge-worker:{self.key}] {text}", file=sys.stderr, flush=True)
|
||||
else:
|
||||
lines.put(line)
|
||||
finally:
|
||||
lines.put(None)
|
||||
|
||||
threading.Thread(target=read_stdout, daemon=True, name=f"hermes-bridge-worker-stdout-{self.profile}").start()
|
||||
threading.Thread(target=read_stdout, daemon=True, name=f"hermes-bridge-worker-stdout-{self.key}").start()
|
||||
deadline = time.time() + self.STARTUP_TIMEOUT_SECONDS
|
||||
while time.time() < deadline:
|
||||
if proc.poll() is not None:
|
||||
raise RuntimeError(f"profile worker {self.profile} exited before ready")
|
||||
raise RuntimeError(f"profile worker {self.key} exited before ready")
|
||||
try:
|
||||
line = lines.get(timeout=0.1)
|
||||
except queue.Empty:
|
||||
@@ -2301,7 +2302,7 @@ class WorkerProcess:
|
||||
continue
|
||||
text = line.strip()
|
||||
if text:
|
||||
print(f"[hermes-bridge-worker:{self.profile}] {text}", file=sys.stderr, flush=True)
|
||||
print(f"[hermes-bridge-worker:{self.key}] {text}", file=sys.stderr, flush=True)
|
||||
try:
|
||||
data = json.loads(text)
|
||||
if data.get("event") == "ready":
|
||||
@@ -2310,7 +2311,7 @@ class WorkerProcess:
|
||||
except Exception:
|
||||
pass
|
||||
self.stop()
|
||||
raise RuntimeError(f"profile worker {self.profile} did not become ready within {self.STARTUP_TIMEOUT_SECONDS}s")
|
||||
raise RuntimeError(f"profile worker {self.key} did not become ready within {self.STARTUP_TIMEOUT_SECONDS}s")
|
||||
|
||||
def stop(self) -> None:
|
||||
with self._lock:
|
||||
@@ -2337,8 +2338,8 @@ class WorkerProcess:
|
||||
return _send_bridge_request(self.endpoint, req, self.REQUEST_TIMEOUT_SECONDS)
|
||||
|
||||
|
||||
def _worker_endpoint(profile: str) -> str:
|
||||
safe = hashlib.sha256(profile.encode("utf-8")).hexdigest()[:16]
|
||||
def _worker_endpoint(key: str) -> str:
|
||||
safe = hashlib.sha256(key.encode("utf-8")).hexdigest()[:16]
|
||||
if os.name == "nt":
|
||||
port_base = int(os.environ.get("HERMES_AGENT_BRIDGE_WORKER_PORT_BASE", "18780"))
|
||||
return f"tcp://127.0.0.1:{port_base + int(safe[:4], 16) % 1000}"
|
||||
@@ -2533,11 +2534,17 @@ class BridgeBroker:
|
||||
self.hermes_home = hermes_home
|
||||
self._workers: dict[str, WorkerProcess] = {}
|
||||
self._run_profile: dict[str, str] = {}
|
||||
self._run_worker_key: dict[str, str] = {}
|
||||
self._running_run_profile: dict[str, str] = {}
|
||||
self._running_run_worker_key: dict[str, str] = {}
|
||||
self._session_profile: dict[str, str] = {}
|
||||
self._session_worker_key: dict[str, str] = {}
|
||||
self._approval_profile: dict[str, str] = {}
|
||||
self._approval_worker_key: dict[str, str] = {}
|
||||
self._clarify_profile: dict[str, str] = {}
|
||||
self._clarify_worker_key: dict[str, str] = {}
|
||||
self._compression_profile: dict[str, str] = {}
|
||||
self._compression_worker_key: dict[str, str] = {}
|
||||
self._lock = threading.RLock()
|
||||
self._stop = threading.Event()
|
||||
self._last_gc = time.time()
|
||||
@@ -2546,58 +2553,73 @@ class BridgeBroker:
|
||||
profile = str(value or "").strip()
|
||||
return profile or "default"
|
||||
|
||||
def _normalize_worker_key(self, profile: str, value: Any = None) -> str:
    """Return the trimmed worker key, falling back to ``profile``.

    ``value`` is stringified and stripped; when the result is empty (``None``,
    ``""``, whitespace) the profile name itself is used as the key.
    """
    worker_key = str(value or "").strip()
    return worker_key or profile

def _worker_for_profile(self, profile: str, worker_key: str | None = None) -> WorkerProcess:
    """Return the worker registered under the normalized key, creating it if absent.

    Workers are cached in ``self._workers`` by worker key (which defaults to
    the profile name), so one profile may own several workers. A new worker
    is built with an endpoint derived from the key.
    """
    profile = self._normalize_profile(profile)
    key = self._normalize_worker_key(profile, worker_key)
    with self._lock:
        worker = self._workers.get(key)
        if worker is None:
            worker = WorkerProcess(key, profile, _worker_endpoint(key), self.agent_root, self.hermes_home)
            self._workers[key] = worker
        return worker
|
||||
|
||||
def _profile_for_run(self, run_id: str) -> str:
|
||||
def _route_for_run(self, run_id: str) -> tuple[str, str | None]:
|
||||
with self._lock:
|
||||
profile = self._run_profile.get(run_id)
|
||||
worker_key = self._run_worker_key.get(run_id)
|
||||
if not profile:
|
||||
raise KeyError(f"unknown run: {run_id}")
|
||||
return profile
|
||||
return profile, worker_key
|
||||
|
||||
def _profile_for_session(self, session_id: str, fallback_profile: Any = None) -> str:
|
||||
def _route_for_session(self, session_id: str, fallback_profile: Any = None, worker_key: Any = None) -> tuple[str, str | None]:
|
||||
with self._lock:
|
||||
profile = self._session_profile.get(session_id)
|
||||
stored_worker_key = self._session_worker_key.get(session_id)
|
||||
if not profile:
|
||||
fallback = self._normalize_profile(fallback_profile)
|
||||
if fallback_profile is not None and fallback:
|
||||
return fallback
|
||||
return fallback, self._normalize_worker_key(fallback, worker_key)
|
||||
raise KeyError(f"unknown session: {session_id}")
|
||||
return profile
|
||||
return profile, self._normalize_worker_key(profile, worker_key) if worker_key is not None else stored_worker_key
|
||||
|
||||
def _record_response_routes(self, profile: str, resp: dict[str, Any]) -> None:
|
||||
def _record_response_routes(self, profile: str, worker_key: str, resp: dict[str, Any]) -> None:
|
||||
run_id = str(resp.get("run_id") or "")
|
||||
session_id = str(resp.get("session_id") or "")
|
||||
with self._lock:
|
||||
if run_id:
|
||||
self._run_profile[run_id] = profile
|
||||
self._run_worker_key[run_id] = worker_key
|
||||
if resp.get("status") == "running":
|
||||
self._running_run_profile[run_id] = profile
|
||||
self._running_run_worker_key[run_id] = worker_key
|
||||
else:
|
||||
self._running_run_profile.pop(run_id, None)
|
||||
self._running_run_worker_key.pop(run_id, None)
|
||||
if session_id:
|
||||
self._session_profile[session_id] = profile
|
||||
self._session_worker_key[session_id] = worker_key
|
||||
for event in resp.get("events") or []:
|
||||
if not isinstance(event, dict):
|
||||
continue
|
||||
approval_id = str(event.get("approval_id") or "")
|
||||
if approval_id:
|
||||
self._approval_profile[approval_id] = profile
|
||||
self._approval_worker_key[approval_id] = worker_key
|
||||
clarify_id = str(event.get("clarify_id") or "")
|
||||
if clarify_id:
|
||||
self._clarify_profile[clarify_id] = profile
|
||||
self._clarify_worker_key[clarify_id] = worker_key
|
||||
request_id = str(event.get("request_id") or "")
|
||||
if event.get("event") == "bridge.compression.requested" and request_id:
|
||||
self._compression_profile[request_id] = profile
|
||||
self._compression_worker_key[request_id] = worker_key
|
||||
if event.get("event") in {"bridge.compression.completed", "bridge.compression.failed"} and request_id:
|
||||
self._compression_profile.pop(request_id, None)
|
||||
self._compression_worker_key.pop(request_id, None)
|
||||
|
||||
def stop(self) -> None:
|
||||
self._stop.set()
|
||||
@@ -2605,20 +2627,29 @@ class BridgeBroker:
|
||||
workers = list(self._workers.values())
|
||||
self._workers.clear()
|
||||
self._run_profile.clear()
|
||||
self._run_worker_key.clear()
|
||||
self._running_run_profile.clear()
|
||||
self._running_run_worker_key.clear()
|
||||
self._session_profile.clear()
|
||||
self._session_worker_key.clear()
|
||||
self._approval_profile.clear()
|
||||
self._approval_worker_key.clear()
|
||||
self._clarify_profile.clear()
|
||||
self._clarify_worker_key.clear()
|
||||
self._compression_profile.clear()
|
||||
self._compression_worker_key.clear()
|
||||
for worker in workers:
|
||||
worker.stop()
|
||||
|
||||
def _forward(self, profile: str, req: dict[str, Any]) -> dict[str, Any]:
|
||||
worker = self._worker_for_profile(profile)
|
||||
def _forward(self, profile: str, req: dict[str, Any], worker_key: str | None = None) -> dict[str, Any]:
|
||||
profile = self._normalize_profile(profile)
|
||||
key = self._normalize_worker_key(profile, worker_key)
|
||||
worker = self._worker_for_profile(profile, key)
|
||||
forwarded = dict(req)
|
||||
forwarded["profile"] = profile
|
||||
forwarded.pop("worker_key", None)
|
||||
resp = worker.request(forwarded)
|
||||
self._record_response_routes(profile, resp)
|
||||
self._record_response_routes(profile, key, resp)
|
||||
return resp
|
||||
|
||||
def handle(self, req: dict[str, Any]) -> dict[str, Any]:
|
||||
@@ -2629,15 +2660,16 @@ class BridgeBroker:
|
||||
if action == "ping":
|
||||
with self._lock:
|
||||
worker_details = {
|
||||
profile: {
|
||||
key: {
|
||||
"running": worker.running,
|
||||
"pid": worker.pid,
|
||||
"endpoint": worker.endpoint,
|
||||
"profile": getattr(worker, "profile", key),
|
||||
"last_used_at": worker.last_used_at,
|
||||
}
|
||||
for profile, worker in self._workers.items()
|
||||
for key, worker in self._workers.items()
|
||||
}
|
||||
workers = {profile: details["running"] for profile, details in worker_details.items()}
|
||||
workers = {key: details["running"] for key, details in worker_details.items()}
|
||||
sessions_by_profile: dict[str, int] = {}
|
||||
for profile in self._session_profile.values():
|
||||
sessions_by_profile[profile] = sessions_by_profile.get(profile, 0) + 1
|
||||
@@ -2664,29 +2696,32 @@ class BridgeBroker:
|
||||
|
||||
if action == "worker_ping":
|
||||
profile = self._normalize_profile(req.get("profile"))
|
||||
resp = self._forward(profile, {"action": "ping"})
|
||||
worker_key = self._normalize_worker_key(profile, req.get("worker_key"))
|
||||
resp = self._forward(profile, {"action": "ping"}, worker_key)
|
||||
resp["worker_profile"] = profile
|
||||
resp["worker_key"] = worker_key
|
||||
return resp
|
||||
|
||||
if action == "chat":
|
||||
profile = self._normalize_profile(req.get("profile"))
|
||||
return self._forward(profile, req)
|
||||
return self._forward(profile, req, self._normalize_worker_key(profile, req.get("worker_key")))
|
||||
|
||||
if action == "context_estimate":
|
||||
profile = self._normalize_profile(req.get("profile"))
|
||||
return self._forward(profile, req)
|
||||
return self._forward(profile, req, self._normalize_worker_key(profile, req.get("worker_key")))
|
||||
|
||||
if action in {"get_result", "get_output"}:
|
||||
profile = self._profile_for_run(str(req.get("run_id") or ""))
|
||||
return self._forward(profile, req)
|
||||
profile, worker_key = self._route_for_run(str(req.get("run_id") or ""))
|
||||
return self._forward(profile, req, worker_key)
|
||||
|
||||
if action in {"interrupt", "steer", "command", "goal_evaluate", "goal_pause", "status", "get_history", "destroy"}:
|
||||
session_id = str(req.get("session_id") or "")
|
||||
profile = self._profile_for_session(session_id, req.get("profile"))
|
||||
resp = self._forward(profile, req)
|
||||
profile, worker_key = self._route_for_session(session_id, req.get("profile"), req.get("worker_key") if "worker_key" in req else None)
|
||||
resp = self._forward(profile, req, worker_key)
|
||||
if action == "destroy":
|
||||
with self._lock:
|
||||
self._session_profile.pop(session_id, None)
|
||||
self._session_worker_key.pop(session_id, None)
|
||||
return resp
|
||||
|
||||
if action == "approval_respond":
|
||||
@@ -2695,9 +2730,10 @@ class BridgeBroker:
|
||||
raise ValueError("approval_id is required")
|
||||
with self._lock:
|
||||
profile = self._approval_profile.get(approval_id)
|
||||
worker_key = self._approval_worker_key.get(approval_id)
|
||||
if not profile:
|
||||
raise KeyError(f"unknown approval request: {approval_id}")
|
||||
return self._forward(profile, req)
|
||||
return self._forward(profile, req, worker_key)
|
||||
|
||||
if action == "clarify_respond":
|
||||
clarify_id = str(req.get("clarify_id") or "").strip()
|
||||
@@ -2705,9 +2741,10 @@ class BridgeBroker:
|
||||
raise ValueError("clarify_id is required")
|
||||
with self._lock:
|
||||
profile = self._clarify_profile.get(clarify_id)
|
||||
worker_key = self._clarify_worker_key.get(clarify_id)
|
||||
if not profile:
|
||||
raise KeyError(f"unknown clarify request: {clarify_id}")
|
||||
return self._forward(profile, req)
|
||||
return self._forward(profile, req, worker_key)
|
||||
|
||||
if action == "compression_respond":
|
||||
request_id = str(req.get("request_id") or "").strip()
|
||||
@@ -2715,20 +2752,27 @@ class BridgeBroker:
|
||||
raise ValueError("request_id is required")
|
||||
with self._lock:
|
||||
profile = self._compression_profile.get(request_id)
|
||||
worker_key = self._compression_worker_key.get(request_id)
|
||||
if not profile:
|
||||
raise KeyError(f"unknown compression request: {request_id}")
|
||||
return self._forward(profile, req)
|
||||
return self._forward(profile, req, worker_key)
|
||||
|
||||
if action == "destroy_all":
|
||||
with self._lock:
|
||||
workers = list(self._workers.values())
|
||||
self._workers.clear()
|
||||
self._run_profile.clear()
|
||||
self._run_worker_key.clear()
|
||||
self._running_run_profile.clear()
|
||||
self._running_run_worker_key.clear()
|
||||
self._session_profile.clear()
|
||||
self._session_worker_key.clear()
|
||||
self._approval_profile.clear()
|
||||
self._approval_worker_key.clear()
|
||||
self._clarify_profile.clear()
|
||||
self._clarify_worker_key.clear()
|
||||
self._compression_profile.clear()
|
||||
self._compression_worker_key.clear()
|
||||
destroyed = 0
|
||||
for worker in workers:
|
||||
try:
|
||||
@@ -2744,40 +2788,56 @@ class BridgeBroker:
|
||||
if action == "destroy_profile":
|
||||
profile = self._normalize_profile(req.get("profile"))
|
||||
with self._lock:
|
||||
worker = self._workers.pop(profile, None)
|
||||
workers = [
|
||||
worker
|
||||
for key, worker in list(self._workers.items())
|
||||
if getattr(worker, "profile", key) == profile
|
||||
]
|
||||
for worker in workers:
|
||||
self._workers.pop(worker.key, None)
|
||||
self._run_profile = {key: value for key, value in self._run_profile.items() if value != profile}
|
||||
self._run_worker_key = {key: value for key, value in self._run_worker_key.items() if key in self._run_profile}
|
||||
self._running_run_profile = {key: value for key, value in self._running_run_profile.items() if value != profile}
|
||||
self._running_run_worker_key = {key: value for key, value in self._running_run_worker_key.items() if key in self._running_run_profile}
|
||||
self._session_profile = {key: value for key, value in self._session_profile.items() if value != profile}
|
||||
self._session_worker_key = {key: value for key, value in self._session_worker_key.items() if key in self._session_profile}
|
||||
self._approval_profile = {key: value for key, value in self._approval_profile.items() if value != profile}
|
||||
self._approval_worker_key = {key: value for key, value in self._approval_worker_key.items() if key in self._approval_profile}
|
||||
self._clarify_profile = {key: value for key, value in self._clarify_profile.items() if value != profile}
|
||||
self._clarify_worker_key = {key: value for key, value in self._clarify_worker_key.items() if key in self._clarify_profile}
|
||||
self._compression_profile = {key: value for key, value in self._compression_profile.items() if value != profile}
|
||||
self._compression_worker_key = {key: value for key, value in self._compression_worker_key.items() if key in self._compression_profile}
|
||||
|
||||
if worker is None or not worker.running:
|
||||
if worker is not None:
|
||||
worker.stop()
|
||||
if not workers:
|
||||
return {"profile": profile, "destroyed": 0}
|
||||
|
||||
try:
|
||||
resp = worker.request({"action": "destroy_all"})
|
||||
destroyed = int(resp.get("destroyed") or 0)
|
||||
except Exception:
|
||||
destroyed = 0
|
||||
finally:
|
||||
worker.stop()
|
||||
destroyed = 0
|
||||
for worker in workers:
|
||||
if not worker.running:
|
||||
worker.stop()
|
||||
continue
|
||||
try:
|
||||
resp = worker.request({"action": "destroy_all"})
|
||||
destroyed += int(resp.get("destroyed") or 0)
|
||||
except Exception:
|
||||
pass
|
||||
finally:
|
||||
worker.stop()
|
||||
return {"profile": profile, "destroyed": destroyed}
|
||||
|
||||
if action == "list":
|
||||
sessions: list[Any] = []
|
||||
with self._lock:
|
||||
workers = list(self._workers.items())
|
||||
for profile, worker in workers:
|
||||
for key, worker in workers:
|
||||
if not worker.running:
|
||||
continue
|
||||
try:
|
||||
resp = worker.request({"action": "list"})
|
||||
for session in resp.get("sessions") or []:
|
||||
if isinstance(session, dict):
|
||||
session.setdefault("profile", profile)
|
||||
session.setdefault("profile", getattr(worker, "profile", key))
|
||||
session.setdefault("worker_key", key)
|
||||
sessions.append(session)
|
||||
except Exception:
|
||||
pass
|
||||
@@ -2826,12 +2886,12 @@ class BridgeBroker:
|
||||
self._last_gc = now
|
||||
with self._lock:
|
||||
idle = [
|
||||
profile for profile, worker in self._workers.items()
|
||||
key for key, worker in self._workers.items()
|
||||
if worker.running and now - worker.last_used_at > self.IDLE_TIMEOUT_SECONDS
|
||||
]
|
||||
for profile in idle:
|
||||
for key in idle:
|
||||
with self._lock:
|
||||
worker = self._workers.pop(profile, None)
|
||||
worker = self._workers.pop(key, None)
|
||||
if worker:
|
||||
worker.stop()
|
||||
|
||||
|
||||
@@ -195,7 +195,8 @@ export async function buildCompressedHistory(
|
||||
emit: (event: string, payload: any) => void,
|
||||
sessionMap: Map<string, SessionState>,
|
||||
modelContext: { model?: string | null; provider?: string | null } = {},
|
||||
contextTokenEstimator?: (messages: ChatMessage[]) => Promise<number | null | undefined>,
|
||||
contextTokenEstimator?: (messages: ChatMessage[], messageTokens: number) => Promise<number | null | undefined>,
|
||||
currentInputTokens = 0,
|
||||
): Promise<ChatMessage[]> {
|
||||
try {
|
||||
let history = await buildDbHistory(sessionId, { excludeLastUser: true })
|
||||
@@ -213,14 +214,18 @@ export async function buildCompressedHistory(
|
||||
}
|
||||
const cState = getOrCreateSession(sessionMap, sessionId)
|
||||
const assembledTokens = await calcAndUpdateUsage(sessionId, cState, emit)
|
||||
const estimateFullContextTokens = async (messages: ChatMessage[], fallback: number) => {
|
||||
const currentRunInputTokens = typeof currentInputTokens === 'number' && Number.isFinite(currentInputTokens) && currentInputTokens > 0
|
||||
? Math.floor(currentInputTokens)
|
||||
: 0
|
||||
const estimateLocalContextTokens = async (messages: ChatMessage[], messageTokens: number) => {
|
||||
const localMessageTokens = Math.max(0, Math.floor(messageTokens))
|
||||
try {
|
||||
const estimate = await contextTokenEstimator?.(messages)
|
||||
const estimate = await contextTokenEstimator?.(messages, localMessageTokens)
|
||||
if (typeof estimate === 'number' && Number.isFinite(estimate) && estimate > 0) return Math.floor(estimate)
|
||||
} catch (err) {
|
||||
logger.warn(err, '[context-compress] session=%s: full context token estimate failed; using message-only estimate', sessionId)
|
||||
logger.warn(err, '[context-compress] session=%s: fixed context token estimate failed; using message-only estimate', sessionId)
|
||||
}
|
||||
return fallback
|
||||
return localMessageTokens
|
||||
}
|
||||
const emitContextUsage = (contextTokens: number) => {
|
||||
cState.contextTokens = contextTokens
|
||||
@@ -236,10 +241,10 @@ export async function buildCompressedHistory(
|
||||
let totalTokens = messageOnlyTotalTokens
|
||||
|
||||
if (history.length === 0) {
|
||||
totalTokens = await estimateFullContextTokens([], 0)
|
||||
totalTokens = await estimateLocalContextTokens([], Math.max(currentRunInputTokens, messageOnlyTotalTokens))
|
||||
if (totalTokens > triggerTokens) {
|
||||
throw new ContextWindowTooSmallError(
|
||||
`Context window is too small: system prompt and tool schemas already use ~${totalTokens} tokens, exceeding compression threshold ${triggerTokens}. Increase model context length, raise compression.threshold, or disable some tools.`,
|
||||
`Context window is too small: fixed prompt/tool overhead plus the current input uses ~${totalTokens} tokens, exceeding compression threshold ${triggerTokens}. Increase model context length, raise compression.threshold, shorten the input, or disable some tools.`,
|
||||
)
|
||||
}
|
||||
if (totalTokens > 0) emitContextUsage(totalTokens)
|
||||
@@ -254,13 +259,15 @@ export async function buildCompressedHistory(
|
||||
sessionId, snapshot.lastMessageIndex, history.length)
|
||||
const staleHistory = buildSnapshotHistory(snapshot, history, compressionConfig.compressor) || history
|
||||
const staleUsage = estimateUsageTokensFromMessages(staleHistory)
|
||||
totalTokens = await estimateFullContextTokens(staleHistory, staleUsage.inputTokens + staleUsage.outputTokens)
|
||||
const staleMessageTokens = staleUsage.inputTokens + staleUsage.outputTokens
|
||||
const staleRunMessageTokens = Math.max(staleMessageTokens + currentRunInputTokens, messageOnlyTotalTokens)
|
||||
totalTokens = await estimateLocalContextTokens(staleHistory, staleRunMessageTokens)
|
||||
emitContextUsage(totalTokens)
|
||||
logger.info({
|
||||
sessionId,
|
||||
profile,
|
||||
messages: staleHistory.length,
|
||||
messageOnlyTokens: staleUsage.inputTokens + staleUsage.outputTokens,
|
||||
messageOnlyTokens: staleRunMessageTokens,
|
||||
fullContextTokens: totalTokens,
|
||||
triggerTokens,
|
||||
decision: totalTokens > triggerTokens ? 'compress' : 'skip',
|
||||
@@ -272,13 +279,15 @@ export async function buildCompressedHistory(
|
||||
const newMessages = history.slice(snapshot.lastMessageIndex + 1)
|
||||
const snapshotHistory = buildSnapshotHistory(snapshot, history, compressionConfig.compressor) || history
|
||||
const snapshotUsage = estimateUsageTokensFromMessages(snapshotHistory)
|
||||
totalTokens = await estimateFullContextTokens(snapshotHistory, snapshotUsage.inputTokens + snapshotUsage.outputTokens)
|
||||
const snapshotMessageTokens = snapshotUsage.inputTokens + snapshotUsage.outputTokens
|
||||
const snapshotRunMessageTokens = Math.max(snapshotMessageTokens + currentRunInputTokens, messageOnlyTotalTokens)
|
||||
totalTokens = await estimateLocalContextTokens(snapshotHistory, snapshotRunMessageTokens)
|
||||
emitContextUsage(totalTokens)
|
||||
logger.info({
|
||||
sessionId,
|
||||
profile,
|
||||
messages: snapshotHistory.length,
|
||||
messageOnlyTokens: snapshotUsage.inputTokens + snapshotUsage.outputTokens,
|
||||
messageOnlyTokens: snapshotRunMessageTokens,
|
||||
fullContextTokens: totalTokens,
|
||||
triggerTokens,
|
||||
decision: totalTokens > triggerTokens ? 'compress' : 'skip',
|
||||
@@ -289,22 +298,25 @@ export async function buildCompressedHistory(
|
||||
if (totalTokens <= triggerTokens) {
|
||||
history = snapshotHistory
|
||||
} else {
|
||||
history = await compressHistory(history, newMessages, sessionId, upstream, apiKey, cState, totalTokens, emit, sessionMap, modelContext, compressionConfig.compressor)
|
||||
history = await compressHistory(history, newMessages, sessionId, upstream, apiKey, cState, totalTokens, emit, sessionMap, modelContext, compressionConfig.compressor, currentRunInputTokens)
|
||||
}
|
||||
} else if (snapshot && staleSnapshot) {
|
||||
if (totalTokens <= triggerTokens) {
|
||||
history = buildSnapshotHistory(snapshot, history, compressionConfig.compressor) || history
|
||||
} else {
|
||||
history = await compressHistory(history, null, sessionId, upstream, apiKey, cState, totalTokens, emit, sessionMap, modelContext, compressionConfig.compressor)
|
||||
history = await compressHistory(history, null, sessionId, upstream, apiKey, cState, totalTokens, emit, sessionMap, modelContext, compressionConfig.compressor, currentRunInputTokens)
|
||||
}
|
||||
} else {
|
||||
totalTokens = await estimateFullContextTokens(history, totalTokens)
|
||||
const historyUsage = estimateUsageTokensFromMessages(history)
|
||||
const historyMessageTokens = historyUsage.inputTokens + historyUsage.outputTokens
|
||||
const runMessageTokens = Math.max(historyMessageTokens + currentRunInputTokens, messageOnlyTotalTokens)
|
||||
totalTokens = await estimateLocalContextTokens(history, runMessageTokens)
|
||||
emitContextUsage(totalTokens)
|
||||
logger.info({
|
||||
sessionId,
|
||||
profile,
|
||||
messages: history.length,
|
||||
messageOnlyTokens: messageOnlyTotalTokens,
|
||||
messageOnlyTokens: runMessageTokens,
|
||||
fullContextTokens: totalTokens,
|
||||
triggerTokens,
|
||||
decision: totalTokens > triggerTokens ? 'compress' : 'skip',
|
||||
@@ -318,7 +330,7 @@ export async function buildCompressedHistory(
|
||||
if (totalTokens <= triggerTokens) {
|
||||
logger.info('[context-compress] session=%s: %d messages, ~%d tokens — under threshold, skip', sessionId, history.length, totalTokens)
|
||||
} else {
|
||||
history = await compressHistory(history, null, sessionId, upstream, apiKey, cState, totalTokens, emit, sessionMap, modelContext, compressionConfig.compressor)
|
||||
history = await compressHistory(history, null, sessionId, upstream, apiKey, cState, totalTokens, emit, sessionMap, modelContext, compressionConfig.compressor, currentRunInputTokens)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -342,8 +354,12 @@ export async function compressHistory(
|
||||
sessionMap: Map<string, SessionState>,
|
||||
modelContext: { model?: string | null; provider?: string | null } = {},
|
||||
compressionConfig?: Partial<CompressorConfig>,
|
||||
currentInputTokens = 0,
|
||||
): Promise<ChatMessage[]> {
|
||||
const msgCount = newMessagesOnly ? newMessagesOnly.length : history.length
|
||||
const currentRunInputTokens = typeof currentInputTokens === 'number' && Number.isFinite(currentInputTokens) && currentInputTokens > 0
|
||||
? Math.floor(currentInputTokens)
|
||||
: 0
|
||||
pushState(sessionMap, sessionId, 'compression.started', {
|
||||
event: 'compression.started', message_count: msgCount, token_count: totalTokens,
|
||||
})
|
||||
@@ -353,14 +369,22 @@ export async function compressHistory(
|
||||
|
||||
try {
|
||||
const session = getSession(sessionId)
|
||||
const summarizerProfile = session?.profile || 'default'
|
||||
const compressor = new ChatContextCompressor({ config: compressionConfig })
|
||||
const result = await compressor.compress(history, upstream, apiKey, sessionId, {
|
||||
profile: session?.profile,
|
||||
profile: summarizerProfile,
|
||||
model: modelContext.model || session?.model,
|
||||
provider: modelContext.provider || session?.provider,
|
||||
workerKey: `${summarizerProfile}:compression:${sessionId}`,
|
||||
})
|
||||
const afterTokens = await calcAndUpdateUsage(sessionId, cState, emit)
|
||||
const compressedAfterTokens = afterTokens.inputTokens + afterTokens.outputTokens
|
||||
const resultUsage = estimateUsageTokensFromMessages(result.messages)
|
||||
const resultMessageTokens = resultUsage.inputTokens + resultUsage.outputTokens
|
||||
const compressedRunMessageTokens = Math.max(
|
||||
compressedAfterTokens,
|
||||
resultMessageTokens + currentRunInputTokens,
|
||||
)
|
||||
const compressedMeta: any = {
|
||||
event: 'compression.completed' as const,
|
||||
compressed: result.meta.compressed,
|
||||
@@ -368,15 +392,15 @@ export async function compressHistory(
|
||||
totalMessages: result.meta.totalMessages,
|
||||
resultMessages: result.messages.length,
|
||||
beforeTokens: totalTokens,
|
||||
afterTokens: compressedAfterTokens,
|
||||
afterTokens: compressedRunMessageTokens,
|
||||
summaryTokens: result.meta.summaryTokenEstimate,
|
||||
verbatimCount: result.meta.verbatimCount,
|
||||
compressedStartIndex: result.meta.compressedStartIndex,
|
||||
}
|
||||
replaceState(sessionMap, sessionId, 'compression.completed', compressedMeta)
|
||||
logger.info('[context-compress] AFTER session=%s: %d messages, ~%d tokens (was %d)',
|
||||
sessionId, result.messages.length, compressedAfterTokens, totalTokens)
|
||||
const compressedContextTokens = updateMessageContextTokenUsage(sessionId, cState, emit, compressedAfterTokens, afterTokens)
|
||||
sessionId, result.messages.length, compressedRunMessageTokens, totalTokens)
|
||||
const compressedContextTokens = updateMessageContextTokenUsage(sessionId, cState, emit, compressedRunMessageTokens, afterTokens)
|
||||
if (compressedContextTokens != null) {
|
||||
compressedMeta.contextTokens = compressedContextTokens
|
||||
}
|
||||
@@ -403,6 +427,7 @@ export async function compressHistory(
|
||||
resultMessages: msgCount,
|
||||
beforeTokens: totalTokens,
|
||||
afterTokens: totalTokens,
|
||||
contextTokens: totalTokens,
|
||||
summaryTokens: 0,
|
||||
verbatimCount: msgCount,
|
||||
compressedStartIndex: -1,
|
||||
@@ -458,10 +483,12 @@ export async function forceCompressBridgeHistory(
|
||||
}, '[chat-run-socket] bridge forced compression started')
|
||||
|
||||
const compressor = new ChatContextCompressor({ config: compressionConfig.compressor })
|
||||
const summarizerProfile = session?.profile || profile || 'default'
|
||||
const result = await compressor.compress(history, upstream, apiKey, sessionId, {
|
||||
profile: session?.profile || profile,
|
||||
profile: summarizerProfile,
|
||||
model: session?.model,
|
||||
provider: session?.provider,
|
||||
workerKey: `${summarizerProfile}:compression:${sessionId}`,
|
||||
})
|
||||
const compressedMessages = result.messages.map(m => {
|
||||
const msg: any = { role: m.role, content: m.content }
|
||||
|
||||
@@ -18,7 +18,7 @@ import { convertHistoryFormat } from './message-format'
|
||||
import { readSseFrames } from './sse-utils'
|
||||
import { extractResponseText } from './response-utils'
|
||||
import { applyResponseStreamEvent, flushResponseRunToDb } from './response-stream'
|
||||
import { buildCompressedHistory, getOrCreateSession } from './compression'
|
||||
import { buildCompressedHistory, buildDbHistory, buildSnapshotAwareHistory, getOrCreateSession } from './compression'
|
||||
import { calcAndUpdateUsage, estimateUsageTokensFromMessages } from './usage'
|
||||
import { handleMessage } from './message-format'
|
||||
import { countTokens, SUMMARY_PREFIX } from '../../../lib/context-compressor'
|
||||
@@ -37,6 +37,7 @@ export async function loadSessionStateFromDb(sid: string, _sessionMap: Map<strin
|
||||
|
||||
let inputTokens: number
|
||||
let outputTokens: number
|
||||
let contextTokens: number | undefined
|
||||
const snapshot = getCompressionSnapshot(sid)
|
||||
if (snapshot && snapshot.lastMessageIndex >= 0 && snapshot.lastMessageIndex < messages.length) {
|
||||
const newMessages = messages.slice(snapshot.lastMessageIndex + 1)
|
||||
@@ -49,6 +50,20 @@ export async function loadSessionStateFromDb(sid: string, _sessionMap: Map<strin
|
||||
inputTokens = usage.inputTokens
|
||||
outputTokens = usage.outputTokens
|
||||
}
|
||||
try {
|
||||
const session = getSession(sid)
|
||||
const dbHistory = await buildDbHistory(sid, { excludeLastUser: false })
|
||||
const snapshotHistory = await buildSnapshotAwareHistory(
|
||||
sid,
|
||||
session?.profile || 'default',
|
||||
dbHistory,
|
||||
{ model: session?.model, provider: session?.provider },
|
||||
)
|
||||
const contextUsage = estimateUsageTokensFromMessages(snapshotHistory)
|
||||
contextTokens = contextUsage.inputTokens + contextUsage.outputTokens
|
||||
} catch (err) {
|
||||
logger.warn(err, '[chat-run-socket] failed to calculate snapshot-aware context tokens for session %s', sid)
|
||||
}
|
||||
|
||||
logger.info('[chat-run-socket] loaded session %s from DB (%d messages)', sid, messages.length)
|
||||
return {
|
||||
@@ -57,6 +72,7 @@ export async function loadSessionStateFromDb(sid: string, _sessionMap: Map<strin
|
||||
events: [],
|
||||
inputTokens,
|
||||
outputTokens,
|
||||
contextTokens,
|
||||
queue: [],
|
||||
}
|
||||
} catch (err) {
|
||||
|
||||
@@ -16,7 +16,6 @@ import {
|
||||
contextTokensWithCachedOverhead,
|
||||
estimateUsageTokensFromMessages,
|
||||
getCachedBridgeContextOverhead,
|
||||
updateContextTokenUsage,
|
||||
updateMessageContextTokenUsage,
|
||||
} from './usage'
|
||||
import {
|
||||
@@ -123,6 +122,66 @@ function cacheBridgeContext(state: SessionState, data: Record<string, unknown> |
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Reports whether the bridge context cached on `state` may be reused for the
 * identity the caller is about to run with.
 *
 * A field only disqualifies the cache when both sides carry a value and the
 * values differ: a cached profile is compared against `expected.profile`, and
 * model/provider are compared only when set on both `expected` and the cache.
 * Missing values are treated as compatible.
 *
 * @param state - Session state whose `bridgeContext` is inspected.
 * @param expected - Profile/model/provider of the upcoming run.
 * @returns `false` when nothing is cached or a compared field differs, otherwise `true`.
 */
function bridgeContextMatches(
  state: SessionState,
  expected: { profile: string; model?: string | null; provider?: string | null },
): boolean {
  const context = state.bridgeContext
  if (!context) return false
  if (context.profile && context.profile !== expected.profile) return false
  if (expected.model && context.model && context.model !== expected.model) return false
  if (expected.provider && context.provider && context.provider !== expected.provider) return false
  return true
}
|
||||
|
||||
/**
 * Resolves the fixed context overhead for a bridge session, asking the bridge
 * for a fresh estimate when needed.
 *
 * The cached overhead is only considered when the cached bridge context matches
 * the requested profile/model/provider. Unless `refresh` is set, a matching
 * cached value is returned without contacting the bridge. Otherwise the bridge
 * is queried with an empty message list, the result is stored on `state`, and
 * the overhead is re-read from the cache.
 *
 * @param args.sessionId - Session the estimate is requested for.
 * @param args.profile - Profile passed to the bridge and used for cache matching.
 * @param args.model - Optional model; `null` is sent to the bridge as `undefined`.
 * @param args.provider - Optional provider; `null` is sent to the bridge as `undefined`.
 * @param args.instructions - Instructions forwarded to the bridge estimate call.
 * @param args.state - Session state holding the cached bridge context.
 * @param args.bridge - Bridge client used to request the estimate.
 * @param args.refresh - When true, query the bridge even if a matching cached value exists.
 * @returns The overhead read from the cache after a successful estimate; on
 *   failure, the previously matching cached value, which may be `undefined`.
 */
async function ensureBridgeFixedContext(args: {
  sessionId: string
  profile: string
  model?: string | null
  provider?: string | null
  instructions: string
  state: SessionState
  bridge: AgentBridgeClient
  refresh?: boolean
}): Promise<number | undefined> {
  // Only trust the cache when it was produced for the same profile/model/provider.
  const cached = bridgeContextMatches(args.state, args)
    ? getCachedBridgeContextOverhead(args.state)
    : undefined
  if (!args.refresh && cached != null) return cached

  try {
    const estimate = await args.bridge.contextEstimate(
      args.sessionId,
      [],
      args.instructions,
      args.profile,
      { model: args.model ?? undefined, provider: args.provider ?? undefined },
    )
    cacheBridgeContext(args.state, estimate)
    const fixedContextTokens = getCachedBridgeContextOverhead(args.state)
    bridgeLogger.info({
      sessionId: args.sessionId,
      profile: args.profile,
      model: args.model,
      provider: args.provider,
      toolCount: estimate.tool_count,
      systemPromptChars: estimate.system_prompt_chars,
      fixedContextTokens,
    }, '[chat-run-socket] fixed context estimate')
    return fixedContextTokens
  } catch (err) {
    // Best effort: a failed estimate must not break the run, so fall back to
    // whatever matching cached value was available before the call.
    bridgeLogger.warn({
      err: err instanceof Error ? { message: err.message, name: err.name } : err,
      sessionId: args.sessionId,
      profile: args.profile,
      model: args.model,
      provider: args.provider,
      cachedFixedContextTokens: cached,
    }, '[chat-run-socket] fixed context estimate failed')
    return cached
  }
}
|
||||
|
||||
export async function handleBridgeRun(
|
||||
nsp: ReturnType<Server['of']>,
|
||||
socket: Socket,
|
||||
@@ -195,6 +254,9 @@ export async function handleBridgeRun(
|
||||
|
||||
const displayInput = data.display_input === undefined ? input : data.display_input
|
||||
const inputStr = displayInput == null ? '' : contentBlocksToString(displayInput)
|
||||
const actualInputStr = contentBlocksToString(input)
|
||||
const currentInputUsage = estimateUsageTokensFromMessages([{ role: 'user', content: actualInputStr }])
|
||||
const currentInputTokens = currentInputUsage.inputTokens
|
||||
const shouldPersistUserMessage = !skipUserMessage && displayInput !== null
|
||||
const displayRole = data.display_role === 'command' ? 'command' : 'user'
|
||||
let messageId: number | string | undefined
|
||||
@@ -257,33 +319,32 @@ export async function handleBridgeRun(
|
||||
emit,
|
||||
sessionMap,
|
||||
{ model: resolvedModel, provider: resolvedProvider },
|
||||
async (messages) => {
|
||||
const cachedOverhead = getCachedBridgeContextOverhead(state)
|
||||
if (cachedOverhead != null) {
|
||||
const messageUsage = estimateUsageTokensFromMessages(messages)
|
||||
return cachedOverhead + messageUsage.inputTokens + messageUsage.outputTokens
|
||||
}
|
||||
const estimate = await bridge.contextEstimate(
|
||||
session_id,
|
||||
messages,
|
||||
fullInstructions,
|
||||
async (_messages, localMessageTokens) => {
|
||||
const fixedContextTokens = await ensureBridgeFixedContext({
|
||||
sessionId: session_id,
|
||||
profile,
|
||||
{ model: resolvedModel, provider: resolvedProvider },
|
||||
)
|
||||
cacheBridgeContext(state, estimate)
|
||||
model: resolvedModel,
|
||||
provider: resolvedProvider,
|
||||
instructions: fullInstructions,
|
||||
state,
|
||||
bridge,
|
||||
refresh: true,
|
||||
})
|
||||
const contextTokens = fixedContextTokens == null
|
||||
? localMessageTokens
|
||||
: fixedContextTokens + localMessageTokens
|
||||
bridgeLogger.info({
|
||||
sessionId: session_id,
|
||||
profile,
|
||||
model: resolvedModel,
|
||||
provider: resolvedProvider,
|
||||
messages: estimate.message_count,
|
||||
toolCount: estimate.tool_count,
|
||||
systemPromptChars: estimate.system_prompt_chars,
|
||||
fixedContextTokens: estimate.fixed_context_tokens,
|
||||
fullContextTokens: estimate.token_count,
|
||||
}, '[chat-run-socket] full context estimate')
|
||||
return estimate.token_count
|
||||
fixedContextTokens,
|
||||
messageTokens: localMessageTokens,
|
||||
contextTokens,
|
||||
}, '[chat-run-socket] local context estimate')
|
||||
return contextTokens
|
||||
},
|
||||
currentInputTokens,
|
||||
)
|
||||
const bridgeHistory = history
|
||||
|
||||
@@ -349,6 +410,8 @@ export async function handleBridgeRun(
|
||||
dequeueNextQueuedRun,
|
||||
fullInstructions,
|
||||
{ model: resolvedModel, provider: resolvedProvider },
|
||||
currentInputTokens,
|
||||
shouldPersistUserMessage && displayRole === 'user',
|
||||
data.model_groups,
|
||||
)
|
||||
if (chunk.done) break
|
||||
@@ -417,61 +480,68 @@ async function refreshFinalContextUsage(args: {
|
||||
)
|
||||
const finalMessageUsage = estimateUsageTokensFromMessages(finalHistory)
|
||||
const finalMessageTokens = finalMessageUsage.inputTokens + finalMessageUsage.outputTokens
|
||||
if (getCachedBridgeContextOverhead(args.state) != null) {
|
||||
const contextTokens = updateMessageContextTokenUsage(
|
||||
args.sessionId,
|
||||
args.state,
|
||||
args.emit,
|
||||
finalMessageTokens,
|
||||
args.usage,
|
||||
)
|
||||
bridgeLogger.info({
|
||||
sessionId: args.sessionId,
|
||||
profile: args.profile,
|
||||
model: args.model,
|
||||
provider: args.provider,
|
||||
messages: finalHistory.length,
|
||||
fixedContextTokens: args.state.bridgeContext?.fixedContextTokens,
|
||||
messageTokens: finalMessageTokens,
|
||||
fullContextTokens: contextTokens,
|
||||
}, '[chat-run-socket] final cached context estimate')
|
||||
return contextTokens
|
||||
}
|
||||
const estimate = await args.bridge.contextEstimate(
|
||||
await ensureBridgeFixedContext({
|
||||
sessionId: args.sessionId,
|
||||
profile: args.profile,
|
||||
model: args.model,
|
||||
provider: args.provider,
|
||||
instructions: args.instructions,
|
||||
state: args.state,
|
||||
bridge: args.bridge,
|
||||
})
|
||||
const contextTokens = updateMessageContextTokenUsage(
|
||||
args.sessionId,
|
||||
finalHistory,
|
||||
args.instructions,
|
||||
args.profile,
|
||||
{ model: args.model ?? undefined, provider: args.provider ?? undefined },
|
||||
args.state,
|
||||
args.emit,
|
||||
finalMessageTokens,
|
||||
args.usage,
|
||||
)
|
||||
cacheBridgeContext(args.state, estimate)
|
||||
const contextTokens = typeof estimate.token_count === 'number' && Number.isFinite(estimate.token_count) && estimate.token_count > 0
|
||||
? Math.floor(estimate.token_count)
|
||||
: undefined
|
||||
if (contextTokens == null) return args.state.contextTokens
|
||||
|
||||
updateContextTokenUsage(args.sessionId, args.state, args.emit, contextTokens, args.usage)
|
||||
bridgeLogger.info({
|
||||
sessionId: args.sessionId,
|
||||
profile: args.profile,
|
||||
model: args.model,
|
||||
provider: args.provider,
|
||||
messages: estimate.message_count,
|
||||
toolCount: estimate.tool_count,
|
||||
systemPromptChars: estimate.system_prompt_chars,
|
||||
fullContextTokens: contextTokens,
|
||||
}, '[chat-run-socket] final full context estimate')
|
||||
messages: finalHistory.length,
|
||||
fixedContextTokens: args.state.bridgeContext?.fixedContextTokens,
|
||||
messageTokens: finalMessageTokens,
|
||||
contextTokens,
|
||||
}, '[chat-run-socket] final local context estimate')
|
||||
return contextTokens
|
||||
} catch (err) {
|
||||
bridgeLogger.warn({
|
||||
err: err instanceof Error ? { message: err.message, name: err.name } : err,
|
||||
sessionId: args.sessionId,
|
||||
profile: args.profile,
|
||||
}, '[chat-run-socket] final full context estimate failed')
|
||||
}, '[chat-run-socket] final local context estimate failed')
|
||||
return args.state.contextTokens
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Estimates message tokens for a session from its snapshot-aware history.
 *
 * Loads the DB history (including the last user message), passes it through
 * `buildSnapshotAwareHistory`, and sums the estimated input and output tokens
 * of the result. When the current input is not yet part of the DB history, its
 * token count is added on top.
 *
 * @param args.sessionId - Session whose history is loaded.
 * @param args.profile - Profile forwarded to the snapshot-aware history builder.
 * @param args.model - Optional model forwarded to the history builder.
 * @param args.provider - Optional provider forwarded to the history builder.
 * @param args.currentInputTokens - Token estimate of the current run input.
 * @param args.currentInputIncludedInDb - When truthy, the current input is not added again.
 * @returns The token total and the number of messages in the snapshot-aware history.
 */
async function estimateSnapshotAwareMessageTokens(args: {
  sessionId: string
  profile: string
  model?: string | null
  provider?: string | null
  currentInputTokens?: number
  currentInputIncludedInDb?: boolean
}): Promise<{ messageTokens: number; messages: number }> {
  const dbHistory = await buildDbHistory(args.sessionId, { excludeLastUser: false })
  const snapshotHistory = await buildSnapshotAwareHistory(
    args.sessionId,
    args.profile,
    dbHistory,
    { model: args.model, provider: args.provider },
  )
  const usage = estimateUsageTokensFromMessages(snapshotHistory)
  // Avoid double counting: only add the current input when the DB history lacks it.
  const extraInputTokens = args.currentInputIncludedInDb
    ? 0
    : finiteToken(args.currentInputTokens) ?? 0
  return {
    messageTokens: usage.inputTokens + usage.outputTokens + extraInputTokens,
    messages: snapshotHistory.length,
  }
}
|
||||
|
||||
async function applyBridgeChunkAsync(
|
||||
nsp: ReturnType<Server['of']>,
|
||||
socket: Socket,
|
||||
@@ -486,6 +556,8 @@ async function applyBridgeChunkAsync(
|
||||
dequeueNextQueuedRun: (socket: Socket, sessionId: string, fallbackProfile?: string) => void,
|
||||
instructions: string,
|
||||
modelContext: { model?: string | null; provider?: string | null },
|
||||
currentInputTokens = 0,
|
||||
currentInputIncludedInDb = true,
|
||||
modelGroups?: RunModelGroup[],
|
||||
): Promise<void> {
|
||||
if (state.activeRunMarker !== runMarker) {
|
||||
@@ -505,11 +577,19 @@ async function applyBridgeChunkAsync(
|
||||
if (evType === 'bridge.context.ready') {
|
||||
cacheBridgeContext(state, ev)
|
||||
const usage = await calcAndUpdateUsage(sessionId, state, emit)
|
||||
const snapshotAware = await estimateSnapshotAwareMessageTokens({
|
||||
sessionId,
|
||||
profile,
|
||||
model: modelContext.model,
|
||||
provider: modelContext.provider,
|
||||
currentInputTokens,
|
||||
currentInputIncludedInDb,
|
||||
})
|
||||
updateMessageContextTokenUsage(
|
||||
sessionId,
|
||||
state,
|
||||
emit,
|
||||
usage.inputTokens + usage.outputTokens,
|
||||
snapshotAware.messageTokens,
|
||||
usage,
|
||||
)
|
||||
} else if (evType === 'tool.started') {
|
||||
@@ -646,17 +726,22 @@ async function applyBridgeChunkAsync(
|
||||
const bridgeHistory = await buildDbHistory(sessionId, { excludeLastUser: true })
|
||||
const bridgeUsage = estimateUsageTokensFromMessages(bridgeHistory)
|
||||
const messageOnlyTokens = bridgeUsage.inputTokens + bridgeUsage.outputTokens
|
||||
const tokenCount = typeof ev.approx_tokens === 'number' && Number.isFinite(ev.approx_tokens) && ev.approx_tokens > 0
|
||||
? ev.approx_tokens
|
||||
: messageOnlyTokens
|
||||
const runInputTokens = typeof currentInputTokens === 'number' && Number.isFinite(currentInputTokens) && currentInputTokens > 0
|
||||
? Math.floor(currentInputTokens)
|
||||
: 0
|
||||
const runMessageTokens = messageOnlyTokens + runInputTokens
|
||||
const tokenCount = contextTokensWithCachedOverhead(state, runMessageTokens)
|
||||
bridgeLogger.info({
|
||||
sessionId,
|
||||
profile,
|
||||
bridgeMessages: ev.message_count,
|
||||
dbMessages: bridgeHistory.length,
|
||||
messageOnlyTokens,
|
||||
fullContextTokens: tokenCount,
|
||||
source: typeof ev.approx_tokens === 'number' ? 'bridge' : 'message-only-fallback',
|
||||
currentInputTokens: runInputTokens,
|
||||
fixedContextTokens: state.bridgeContext?.fixedContextTokens,
|
||||
contextTokens: tokenCount,
|
||||
bridgeApproxTokens: ev.approx_tokens,
|
||||
source: 'local',
|
||||
}, '[chat-run-socket] bridge compression token estimate')
|
||||
const payload = {
|
||||
event: 'compression.started',
|
||||
@@ -674,7 +759,7 @@ async function applyBridgeChunkAsync(
|
||||
sessionId,
|
||||
profile,
|
||||
ev.messages as ChatMessage[],
|
||||
typeof ev.approx_tokens === 'number' ? ev.approx_tokens : undefined,
|
||||
tokenCount,
|
||||
)
|
||||
state.bridgeCompressionResults = state.bridgeCompressionResults || {}
|
||||
state.bridgeCompressionResults[String(ev.request_id)] = compressed
|
||||
@@ -689,11 +774,16 @@ async function applyBridgeChunkAsync(
|
||||
const compressionResult = ev.request_id
|
||||
? state.bridgeCompressionResults?.[String(ev.request_id)]
|
||||
: undefined
|
||||
const bridgeAfterContextTokens = finiteToken(ev.result_approx_tokens)
|
||||
const messageAfterTokens = finiteToken(compressionResult?.afterTokens)
|
||||
const afterContextTokens = messageAfterTokens != null && getCachedBridgeContextOverhead(state) != null
|
||||
? contextTokensWithCachedOverhead(state, messageAfterTokens)
|
||||
: bridgeAfterContextTokens ?? messageAfterTokens
|
||||
const runInputTokens = typeof currentInputTokens === 'number' && Number.isFinite(currentInputTokens) && currentInputTokens > 0
|
||||
? Math.floor(currentInputTokens)
|
||||
: 0
|
||||
const messageAfterTokensWithInput = messageAfterTokens != null
|
||||
? messageAfterTokens + runInputTokens
|
||||
: undefined
|
||||
const afterContextTokens = messageAfterTokensWithInput != null
|
||||
? contextTokensWithCachedOverhead(state, messageAfterTokensWithInput)
|
||||
: undefined
|
||||
const payload = {
|
||||
event: 'compression.completed',
|
||||
run_id: chunk.run_id,
|
||||
@@ -703,7 +793,7 @@ async function applyBridgeChunkAsync(
|
||||
totalMessages: compressionResult?.beforeMessages ?? ev.message_count,
|
||||
resultMessages: compressionResult?.resultMessages ?? ev.result_messages,
|
||||
beforeTokens: compressionResult?.beforeTokens ?? ev.approx_tokens,
|
||||
afterTokens: messageAfterTokens ?? bridgeAfterContextTokens,
|
||||
afterTokens: messageAfterTokensWithInput,
|
||||
contextTokens: afterContextTokens,
|
||||
summaryTokens: compressionResult?.summaryTokens,
|
||||
verbatimCount: compressionResult?.verbatimCount,
|
||||
@@ -716,10 +806,8 @@ async function applyBridgeChunkAsync(
|
||||
replaceState(sessionMap, sessionId, 'compression.completed', payload)
|
||||
emit('compression.completed', payload)
|
||||
const usage = await calcAndUpdateUsage(sessionId, state, emit)
|
||||
if (messageAfterTokens != null && getCachedBridgeContextOverhead(state) != null) {
|
||||
updateMessageContextTokenUsage(sessionId, state, emit, messageAfterTokens, usage)
|
||||
} else {
|
||||
updateContextTokenUsage(sessionId, state, emit, afterContextTokens, usage)
|
||||
if (messageAfterTokensWithInput != null) {
|
||||
updateMessageContextTokenUsage(sessionId, state, emit, messageAfterTokensWithInput, usage)
|
||||
}
|
||||
} else if (evType === 'bridge.compression.failed') {
|
||||
const payload = {
|
||||
|
||||
@@ -0,0 +1,96 @@
|
||||
// @vitest-environment jsdom
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import { createPinia, setActivePinia } from 'pinia'
|
||||
|
||||
// Mock functions created via vi.hoisted so the vi.mock factory below can
// reference them and the tests can inspect or re-implement them.
const chatApi = vi.hoisted(() => ({
  resumeSession: vi.fn(),
  registerSessionHandlers: vi.fn(),
  unregisterSessionHandlers: vi.fn(),
}))

// Chat transport: the three hoisted mocks are exposed, everything else is an inert stub.
vi.mock('@/api/hermes/chat', () => ({
  startRunViaSocket: vi.fn(),
  resumeSession: chatApi.resumeSession,
  registerSessionHandlers: chatApi.registerSessionHandlers,
  unregisterSessionHandlers: chatApi.unregisterSessionHandlers,
  getChatRunSocket: vi.fn(() => ({ emit: vi.fn() })),
  respondToolApproval: vi.fn(),
  respondClarify: vi.fn(),
  onPeerUserMessage: vi.fn(() => vi.fn()),
  onSessionCommand: vi.fn(() => vi.fn()),
}))

vi.mock('@/api/client', () => ({
  getActiveProfileName: () => 'default',
}))

vi.mock('@/api/hermes/sessions', () => ({
  deleteSession: vi.fn(),
  fetchSession: vi.fn(),
  fetchSessions: vi.fn(),
  setSessionModel: vi.fn(),
}))

vi.mock('@/api/hermes/download', () => ({
  getDownloadUrl: (_path: string, name: string) => `/download/${name}`,
}))

vi.mock('@/utils/completion-sound', () => ({
  primeCompletionSound: vi.fn(),
  playCompletionSound: vi.fn(),
}))
|
||||
|
||||
import { useChatStore, type Session } from '@/stores/hermes/chat'
|
||||
|
||||
/**
 * Builds a minimal session fixture for the store tests.
 *
 * @param id - Used as both the session id and its title.
 * @returns A session with no messages and creation/update timestamps set to now.
 */
function makeSession(id: string): Session {
  return {
    id,
    title: id,
    messages: [],
    createdAt: Date.now(),
    updatedAt: Date.now(),
  }
}
|
||||
|
||||
describe('chat store compression state', () => {
  beforeEach(() => {
    vi.resetAllMocks()
    setActivePinia(createPinia())
    // Resume synchronously with an empty history; only session-1 reports a run in progress.
    chatApi.resumeSession.mockImplementation((sessionId: string, onResumed: (data: any) => void) => {
      onResumed({
        session_id: sessionId,
        messages: [],
        isWorking: sessionId === 'session-1',
        events: [],
      })
      return {} as any
    })
  })

  it('does not show a background session compression indicator in the active session', async () => {
    const store = useChatStore()
    store.sessions = [makeSession('session-1'), makeSession('session-2')]

    // Capture the handlers registered for session-1 so its events can be
    // delivered after another session has become active.
    await store.switchSession('session-1')
    const handlers = chatApi.registerSessionHandlers.mock.calls.find(call => call[0] === 'session-1')?.[1]
    expect(handlers).toBeTruthy()

    await store.switchSession('session-2')
    handlers.onCompressionStarted({
      event: 'compression.started',
      session_id: 'session-1',
      message_count: 6,
      token_count: 1234,
    })

    // The event belongs to the background session, so the active one shows nothing.
    expect(store.activeSessionId).toBe('session-2')
    expect(store.compressionState).toBeNull()

    // Switching back reveals the state recorded for session-1.
    await store.switchSession('session-1')
    expect(store.compressionState).toEqual(expect.objectContaining({
      compressing: true,
      messageCount: 6,
      beforeTokens: 1234,
    }))
  })
})
|
||||
@@ -292,9 +292,11 @@ except RuntimeError as exc:
|
||||
assert "already running" in str(exc)
|
||||
|
||||
class FakeWorker:
|
||||
def __init__(self, destroyed):
|
||||
def __init__(self, destroyed, profile="default", key="default"):
|
||||
self.running = True
|
||||
self.destroyed = destroyed
|
||||
self.profile = profile
|
||||
self.key = key
|
||||
self.requests = []
|
||||
self.stopped = False
|
||||
|
||||
@@ -310,28 +312,41 @@ broker = bridge.BridgeBroker("ipc:///tmp/unused.sock")
|
||||
profile_worker = FakeWorker(2)
|
||||
broker._workers["default"] = profile_worker
|
||||
broker._run_profile["run-session-a"] = "default"
|
||||
broker._run_worker_key["run-session-a"] = "default"
|
||||
broker._running_run_profile["run-session-a"] = "default"
|
||||
broker._running_run_worker_key["run-session-a"] = "default"
|
||||
broker._session_profile["session-a"] = "default"
|
||||
broker._session_worker_key["session-a"] = "default"
|
||||
broker._approval_profile["approval-a"] = "default"
|
||||
broker._approval_worker_key["approval-a"] = "default"
|
||||
broker._compression_profile["compression-a"] = "default"
|
||||
broker._compression_worker_key["compression-a"] = "default"
|
||||
|
||||
destroy_profile_result = broker.handle({"action": "destroy_profile", "profile": "default"})
|
||||
assert destroy_profile_result == {"profile": "default", "destroyed": 2}
|
||||
assert profile_worker.stopped
|
||||
assert "default" not in broker._workers
|
||||
assert broker._run_profile == {}
|
||||
assert broker._run_worker_key == {}
|
||||
assert broker._running_run_profile == {}
|
||||
assert broker._running_run_worker_key == {}
|
||||
assert broker._session_profile == {}
|
||||
assert broker._session_worker_key == {}
|
||||
assert broker._approval_profile == {}
|
||||
assert broker._approval_worker_key == {}
|
||||
assert broker._compression_profile == {}
|
||||
assert broker._compression_worker_key == {}
|
||||
|
||||
worker_a = FakeWorker(1)
|
||||
worker_b = FakeWorker(3)
|
||||
worker_a = FakeWorker(1, "default", "a")
|
||||
worker_b = FakeWorker(3, "work", "b")
|
||||
broker._workers["a"] = worker_a
|
||||
broker._workers["b"] = worker_b
|
||||
broker._run_profile["run-a"] = "a"
|
||||
broker._running_run_profile["run-a"] = "a"
|
||||
broker._session_profile["session-b"] = "b"
|
||||
broker._run_profile["run-a"] = "default"
|
||||
broker._run_worker_key["run-a"] = "a"
|
||||
broker._running_run_profile["run-a"] = "default"
|
||||
broker._running_run_worker_key["run-a"] = "a"
|
||||
broker._session_profile["session-b"] = "work"
|
||||
broker._session_worker_key["session-b"] = "b"
|
||||
|
||||
destroy_all_result = broker.handle({"action": "destroy_all"})
|
||||
assert destroy_all_result == {"destroyed": 4}
|
||||
@@ -339,8 +354,11 @@ assert worker_a.stopped
|
||||
assert worker_b.stopped
|
||||
assert broker._workers == {}
|
||||
assert broker._run_profile == {}
|
||||
assert broker._run_worker_key == {}
|
||||
assert broker._running_run_profile == {}
|
||||
assert broker._running_run_worker_key == {}
|
||||
assert broker._session_profile == {}
|
||||
assert broker._session_worker_key == {}
|
||||
`)
|
||||
})
|
||||
|
||||
@@ -372,6 +390,69 @@ assert resp["running_sessions_by_profile"] == {"default": 1}
|
||||
`)
|
||||
})
|
||||
|
||||
// Drives BridgeBroker with a fake worker registered under an explicit worker key and checks that
// chat / get_output / destroy reach that worker, that `worker_key` is not forwarded to it, and that
// destroying the session leaves the worker registered and running.
// The embedded Python is indentation-sensitive, so its block structure is written out explicitly.
it('routes worker-keyed broker requests without stopping the worker on session destroy', () => {
  runPython(String.raw`
${harness}

class RoutedWorker:
    running = True
    pid = 12345
    endpoint = "ipc:///tmp/worker.sock"
    last_used_at = 12.5

    def __init__(self, profile, key):
        self.profile = profile
        self.key = key
        self.requests = []
        self.stopped = False

    def request(self, req):
        self.requests.append(req)
        action = req.get("action")
        if action == "chat":
            return {"ok": True, "run_id": "run-compress", "session_id": req["session_id"], "status": "running"}
        if action == "get_output":
            return {"ok": True, "run_id": req["run_id"], "session_id": "compress-temp", "status": "complete", "done": True}
        if action == "destroy":
            return {"ok": True, "session_id": req["session_id"], "destroyed": True}
        raise AssertionError(f"unexpected action: {action}")

    def stop(self):
        self.stopped = True

broker = bridge.BridgeBroker("ipc:///tmp/unused.sock")
worker = RoutedWorker("default", "default:compression:session-a")
broker._workers[worker.key] = worker

chat_resp = broker.handle({
    "action": "chat",
    "session_id": "compress-temp",
    "profile": "default",
    "worker_key": worker.key,
    "message": "summarize",
})
assert chat_resp["run_id"] == "run-compress"
assert worker.requests[-1]["profile"] == "default"
assert "worker_key" not in worker.requests[-1]

broker.handle({"action": "get_output", "run_id": "run-compress"})
assert worker.requests[-1]["action"] == "get_output"

destroy_resp = broker.handle({
    "action": "destroy",
    "session_id": "compress-temp",
    "profile": "default",
    "worker_key": worker.key,
})
assert destroy_resp["destroyed"] is True
assert worker.requests[-1]["action"] == "destroy"
assert not worker.stopped
assert worker.key in broker._workers
assert "compress-temp" not in broker._session_profile
assert "compress-temp" not in broker._session_worker_key
`)
})
|
||||
|
||||
it('restores approval env and clears handlers when a run fails', () => {
|
||||
runPython(String.raw`
|
||||
${harness}
|
||||
@@ -480,7 +561,7 @@ original_getpid = bridge.os.getpid
|
||||
try:
|
||||
bridge.subprocess.Popen = fake_popen
|
||||
bridge.os.getpid = lambda: 4242
|
||||
proc_worker = bridge.WorkerProcess("default", "ipc:///tmp/worker.sock", "/agent", "/home")
|
||||
proc_worker = bridge.WorkerProcess("default:compression:session-a", "default", "ipc:///tmp/worker.sock", "/agent", "/home")
|
||||
proc_worker._pipe_stderr = lambda: None
|
||||
proc_worker._wait_ready = lambda: None
|
||||
proc_worker.start()
|
||||
|
||||
@@ -153,6 +153,42 @@ describe('ChatContextCompressor', () => {
|
||||
expect(saveCompressionSnapshotMock).toHaveBeenCalledWith('s1', 'compressed summary', 6, 10)
|
||||
})
|
||||
|
||||
it('routes summarization through the provided worker key and destroys only the temporary agent session', async () => {
  const { ChatContextCompressor } = await import('../../packages/server/src/lib/context-compressor')
  const compressor = new ChatContextCompressor({
    config: { headMessageCount: 0, tailMessageCount: 1, summaryBudget: 1000 },
  })
  const profile = 'default'
  const workerKey = 'default:compression:s1'
  const messages = [
    { role: 'user', content: 'old context' },
    { role: 'assistant', content: 'old response' },
    { role: 'user', content: 'tail' },
  ]
  // No stored snapshot; the bridge answers the summarization request with a finished run.
  getCompressionSnapshotMock.mockReturnValue(null)
  bridgeRequestMock.mockResolvedValue({
    status: 'completed',
    result: { final_response: 'compressed summary' },
  })

  await compressor.compress(messages, 'http://upstream', undefined, 's1', { profile, workerKey })

  // The chat request must carry both the profile and the worker key.
  const expectedRequest = expect.objectContaining({
    action: 'chat',
    profile: 'default',
    worker_key: 'default:compression:s1',
    wait: true,
  })
  expect(bridgeRequestMock).toHaveBeenCalledWith(expectedRequest, expect.any(Object))

  // The session id sent to the bridge is a temporary `compress_…` id, and that is the id destroyed.
  const [firstRequest] = bridgeRequestMock.mock.calls[0]
  const compressSessionId = firstRequest.session_id
  expect(String(compressSessionId)).toMatch(/^compress_/)
  expect(bridgeDestroyMock).toHaveBeenCalledWith(compressSessionId, 'default', 'default:compression:s1')
})
|
||||
|
||||
it('does not pre-prune tool results before sending them to the summarizer', async () => {
|
||||
const { ChatContextCompressor } = await import('../../packages/server/src/lib/context-compressor')
|
||||
const compressor = new ChatContextCompressor({
|
||||
|
||||
@@ -127,8 +127,18 @@ describe('bridge run final context usage', () => {
|
||||
buildSnapshotAwareHistoryMock.mockImplementation(async (_sessionId: string, _profile: string, history: any[]) => history)
|
||||
calcAndUpdateUsageMock.mockResolvedValue({ inputTokens: 11, outputTokens: 7 })
|
||||
estimateUsageTokensFromMessagesMock.mockReturnValue({ inputTokens: 11, outputTokens: 7 })
|
||||
getCachedBridgeContextOverheadMock.mockReturnValue(undefined)
|
||||
contextTokensWithCachedOverheadMock.mockImplementation((_state: any, messageTokens: number) => messageTokens)
|
||||
getCachedBridgeContextOverheadMock.mockImplementation((state: any) => {
|
||||
const fixed = state?.bridgeContext?.fixedContextTokens
|
||||
return typeof fixed === 'number' ? fixed : undefined
|
||||
})
|
||||
contextTokensWithCachedOverheadMock.mockImplementation((state: any, messageTokens: number) => {
|
||||
const fixed = state?.bridgeContext?.fixedContextTokens
|
||||
return typeof fixed === 'number' ? fixed + messageTokens : messageTokens
|
||||
})
|
||||
updateMessageContextTokenUsageMock.mockImplementation((sid: string, state: any, emit: any, messageTokens: number, usage?: { inputTokens: number; outputTokens: number }) => {
|
||||
const contextTokens = contextTokensWithCachedOverheadMock(state, messageTokens)
|
||||
return updateContextTokenUsageMock(sid, state, emit, contextTokens, usage)
|
||||
})
|
||||
})
|
||||
|
||||
it('refreshes full context tokens when a bridge run completes', async () => {
|
||||
@@ -141,6 +151,7 @@ describe('bridge run final context usage', () => {
|
||||
chat: vi.fn().mockResolvedValue({ run_id: 'run-1', status: 'started' }),
|
||||
contextEstimate: vi.fn().mockResolvedValue({
|
||||
token_count: 12345,
|
||||
fixed_context_tokens: 12327,
|
||||
message_count: 2,
|
||||
tool_count: 4,
|
||||
system_prompt_chars: 13,
|
||||
@@ -165,10 +176,7 @@ describe('bridge run final context usage', () => {
|
||||
|
||||
expect(bridge.contextEstimate).toHaveBeenCalledWith(
|
||||
'session-1',
|
||||
[
|
||||
{ role: 'user', content: 'hello' },
|
||||
{ role: 'assistant', content: 'done' },
|
||||
],
|
||||
[],
|
||||
expect.stringContaining('[Current Hermes profile: default]'),
|
||||
'default',
|
||||
{ model: 'gpt-test', provider: 'openai' },
|
||||
@@ -326,14 +334,22 @@ describe('bridge run final context usage', () => {
|
||||
const nsp = makeNamespace(emit)
|
||||
const socket = makeSocket()
|
||||
const state = makeState()
|
||||
state.bridgeContext = { fixedContextTokens: 20_000 }
|
||||
const sessionMap = new Map([['session-1', state]])
|
||||
getCachedBridgeContextOverheadMock.mockReturnValue(20_000)
|
||||
updateMessageContextTokenUsageMock.mockImplementation((sid: string, targetState: any, targetEmit: any, messageTokens: number, usage?: { inputTokens: number; outputTokens: number }) => updateContextTokenUsageMock(sid, targetState, targetEmit, 20_000 + messageTokens, usage))
|
||||
const bridge = {
|
||||
chat: vi.fn().mockResolvedValue({ run_id: 'run-1', status: 'started' }),
|
||||
contextEstimate: vi.fn(),
|
||||
streamOutput: vi.fn(async function* () {
|
||||
yield {
|
||||
run_id: 'run-1',
|
||||
done: false,
|
||||
status: 'running',
|
||||
events: [{
|
||||
event: 'bridge.context.ready',
|
||||
fixed_context_tokens: 20_000,
|
||||
system_prompt_tokens: 3_000,
|
||||
tool_tokens: 17_000,
|
||||
}],
|
||||
}
|
||||
yield { run_id: 'run-1', done: true, status: 'completed', output: 'done' }
|
||||
}),
|
||||
} as any
|
||||
@@ -365,6 +381,80 @@ describe('bridge run final context usage', () => {
|
||||
}))
|
||||
})
|
||||
|
||||
it('keeps bridge context ready updates on the snapshot-aware token baseline', async () => {
  // Fresh usage objects per call so nothing is aliased between mocks and expectations.
  const fullHistoryUsage = () => ({ inputTokens: 28_000, outputTokens: 0 })
  const snapshotUsage = () => ({ inputTokens: 9_000, outputTokens: 0 })

  const emit = vi.fn()
  const nsp = makeNamespace(emit)
  const socket = makeSocket()
  const state = makeState()
  const sessionMap = new Map([['session-1', state]])

  calcAndUpdateUsageMock.mockResolvedValue(fullHistoryUsage())
  buildDbHistoryMock.mockResolvedValue([
    { role: 'user', content: 'very large old context' },
    { role: 'assistant', content: 'large old response' },
    { role: 'user', content: 'hello' },
  ])
  buildSnapshotAwareHistoryMock.mockResolvedValue([
    { role: 'user', content: '[Previous context summary]\n\nsmall summary' },
    { role: 'user', content: 'hello' },
  ])
  // Histories that begin with the summary estimate small; anything else estimates large.
  estimateUsageTokensFromMessagesMock.mockImplementation((messages: any[]) =>
    messages?.[0]?.content?.includes('small summary') ? snapshotUsage() : fullHistoryUsage(),
  )

  // The stream emits one `bridge.context.ready` event and then completes.
  const contextReadyEvent = {
    event: 'bridge.context.ready',
    fixed_context_tokens: 10_000,
    system_prompt_tokens: 2_000,
    tool_tokens: 8_000,
  }
  const bridge = {
    chat: vi.fn().mockResolvedValue({ run_id: 'run-1', status: 'started' }),
    contextEstimate: vi.fn(),
    streamOutput: vi.fn(async function* () {
      yield {
        run_id: 'run-1',
        done: false,
        status: 'running',
        events: [contextReadyEvent],
      }
      yield { run_id: 'run-1', done: true, status: 'completed', output: 'done' }
    }),
  } as any

  const { handleBridgeRun } = await import('../../packages/server/src/services/hermes/run-chat/handle-bridge-run')
  const payload = { input: 'hello', session_id: 'session-1' }
  await handleBridgeRun(nsp, socket, payload, 'default', sessionMap, bridge, false, vi.fn(), vi.fn())

  // Message tokens must come from the snapshot-aware history (9_000), never the raw history (28_000).
  expect(updateMessageContextTokenUsageMock).toHaveBeenCalledWith(
    'session-1',
    state,
    expect.any(Function),
    9_000,
    fullHistoryUsage(),
  )
  expect(updateMessageContextTokenUsageMock).not.toHaveBeenCalledWith(
    'session-1',
    state,
    expect.any(Function),
    28_000,
    fullHistoryUsage(),
  )
  // 19_000 = 9_000 message tokens + the 10_000 fixed context tokens carried by the event.
  expect(state.contextTokens).toBe(19_000)
  expect(emit).toHaveBeenCalledWith('run.completed', expect.objectContaining({ contextTokens: 19_000 }))
})
|
||||
|
||||
it('persists pending tool marker text before a bridge run completes', async () => {
|
||||
const emit = vi.fn()
|
||||
const nsp = makeNamespace(emit)
|
||||
@@ -502,6 +592,7 @@ describe('bridge run final context usage', () => {
|
||||
chat: vi.fn().mockRejectedValue(new Error('bridge timeout')),
|
||||
contextEstimate: vi.fn().mockResolvedValue({
|
||||
token_count: 54321,
|
||||
fixed_context_tokens: 54303,
|
||||
message_count: 1,
|
||||
tool_count: 4,
|
||||
system_prompt_chars: 13,
|
||||
|
||||
@@ -175,7 +175,7 @@ describe('run chat compression trigger', () => {
|
||||
)
|
||||
})
|
||||
|
||||
it('uses full context estimates for compression threshold decisions', async () => {
|
||||
it('uses local context estimates for compression threshold decisions', async () => {
|
||||
const messages = Array.from({ length: 10 }, (_, index) => ({
|
||||
id: index + 1,
|
||||
session_id: 'session-1',
|
||||
@@ -191,7 +191,7 @@ describe('run chat compression trigger', () => {
|
||||
getSessionDetailMock.mockReturnValue({ messages })
|
||||
calcAndUpdateUsageMock.mockResolvedValue({ inputTokens: 1_000, outputTokens: 0 })
|
||||
compressorCompressMock.mockResolvedValue({
|
||||
messages: [{ role: 'user', content: 'compressed by full context estimate' }],
|
||||
messages: [{ role: 'user', content: 'compressed by local context estimate' }],
|
||||
meta: {
|
||||
compressed: true,
|
||||
llmCompressed: true,
|
||||
@@ -215,7 +215,7 @@ describe('run chat compression trigger', () => {
|
||||
vi.fn(async () => 120_000),
|
||||
)
|
||||
|
||||
expect(history).toEqual([{ role: 'user', content: 'compressed by full context estimate' }])
|
||||
expect(history).toEqual([{ role: 'user', content: 'compressed by local context estimate' }])
|
||||
expect(compressorCompressMock).toHaveBeenCalledTimes(1)
|
||||
expect(updateMessageContextTokenUsageMock).toHaveBeenCalledWith(
|
||||
'session-1',
|
||||
@@ -226,7 +226,7 @@ describe('run chat compression trigger', () => {
|
||||
)
|
||||
})
|
||||
|
||||
it('emits full context token usage when the full estimate is under threshold', async () => {
|
||||
it('emits local context token usage when the local estimate is under threshold', async () => {
|
||||
const messages = Array.from({ length: 10 }, (_, index) => ({
|
||||
id: index + 1,
|
||||
session_id: 'session-1',
|
||||
@@ -257,7 +257,10 @@ describe('run chat compression trigger', () => {
|
||||
)
|
||||
|
||||
expect(history).toHaveLength(9)
|
||||
expect(contextTokenEstimator).toHaveBeenCalledWith(expect.arrayContaining([{ role: 'user', content: 'message 0' }]))
|
||||
expect(contextTokenEstimator).toHaveBeenCalledWith(
|
||||
expect.arrayContaining([{ role: 'user', content: 'message 0' }]),
|
||||
1_900,
|
||||
)
|
||||
expect(emit).toHaveBeenCalledWith('usage.updated', expect.objectContaining({
|
||||
event: 'usage.updated',
|
||||
session_id: 'session-1',
|
||||
@@ -268,6 +271,108 @@ describe('run chat compression trigger', () => {
|
||||
expect(compressorCompressMock).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('includes current input tokens when estimating snapshot-aware context', async () => {
  // Ten stored rows; even indexes and the final row are user turns, the rest assistant turns.
  const toRow = (index: number) => ({
    id: index + 1,
    session_id: 'session-1',
    role: index === 9 || index % 2 === 0 ? 'user' : 'assistant',
    content: `message ${index}`,
    timestamp: index + 1,
    tool_call_id: null,
    tool_calls: null,
    tool_name: null,
    finish_reason: null,
    reasoning_content: null,
  })
  const messages = Array.from({ length: 10 }, (_, index) => toRow(index))

  getSessionDetailMock.mockReturnValue({ messages })
  getCompressionSnapshotMock.mockReturnValue({
    summary: 'previous summary',
    lastMessageIndex: 4,
    messageCountAtTime: 5,
  })
  calcAndUpdateUsageMock.mockResolvedValue({ inputTokens: 10, outputTokens: 0 })
  estimateUsageTokensFromMessagesMock.mockReturnValue({ inputTokens: 1_000, outputTokens: 0 })

  const emit = vi.fn()
  // The estimator adds a flat 20_000 on top of whatever message-token count it is handed.
  const contextTokenEstimator = vi.fn(async (_messages, messageTokens: number) => 20_000 + messageTokens)

  const { buildCompressedHistory } = await import('../../packages/server/src/services/hermes/run-chat/compression')
  const currentInputTokens = 700
  await buildCompressedHistory(
    'session-1',
    'default',
    'http://upstream',
    undefined,
    emit,
    new Map(),
    {},
    contextTokenEstimator,
    currentInputTokens,
  )

  // 1_700 = 1_000 estimated message tokens + the trailing 700 argument; 21_700 adds the estimator's 20_000.
  expect(contextTokenEstimator).toHaveBeenCalledWith(expect.any(Array), 1_700)
  expect(emit).toHaveBeenCalledWith('usage.updated', expect.objectContaining({ contextTokens: 21_700 }))
  expect(compressorCompressMock).not.toHaveBeenCalled()
})
|
||||
|
||||
it('keeps current input tokens in the compression completed context total', async () => {
  // Ten stored rows; even indexes and the final row are user turns, the rest assistant turns.
  const toRow = (index: number) => ({
    id: index + 1,
    session_id: 'session-1',
    role: index === 9 || index % 2 === 0 ? 'user' : 'assistant',
    content: `message ${index}`,
    timestamp: index + 1,
    tool_call_id: null,
    tool_calls: null,
    tool_name: null,
    finish_reason: null,
    reasoning_content: null,
  })
  const messages = Array.from({ length: 10 }, (_, index) => toRow(index))

  getSessionDetailMock.mockReturnValue({ messages })
  calcAndUpdateUsageMock.mockResolvedValue({ inputTokens: 100, outputTokens: 0 })
  // Only a history that starts with the compressed message estimates to 1_000 tokens.
  estimateUsageTokensFromMessagesMock.mockImplementation((items: any[]) => {
    const startsWithCompressed = items?.[0]?.content === 'compressed result'
    return startsWithCompressed
      ? { inputTokens: 1_000, outputTokens: 0 }
      : { inputTokens: 100, outputTokens: 0 }
  })
  const compressionMeta = {
    compressed: true,
    llmCompressed: true,
    totalMessages: 9,
    summaryTokenEstimate: 1,
    verbatimCount: 0,
    compressedStartIndex: 0,
  }
  compressorCompressMock.mockResolvedValue({
    messages: [{ role: 'user', content: 'compressed result' }],
    meta: compressionMeta,
  })
  const emit = vi.fn()

  const { buildCompressedHistory } = await import('../../packages/server/src/services/hermes/run-chat/compression')
  const contextTokenEstimator = vi.fn(async () => 120_000)
  await buildCompressedHistory(
    'session-1',
    'default',
    'http://upstream',
    undefined,
    emit,
    new Map(),
    {},
    contextTokenEstimator,
    700,
  )

  // 1_700 = 1_000 tokens for the compressed history + the trailing 700 argument.
  expect(updateMessageContextTokenUsageMock).toHaveBeenCalledWith(
    'session-1',
    expect.any(Object),
    emit,
    1_700,
    { inputTokens: 100, outputTokens: 0 },
  )
  const completedPayload = expect.objectContaining({
    afterTokens: 1_700,
    contextTokens: 1_700,
  })
  expect(emit).toHaveBeenCalledWith('compression.completed', completedPayload)
})
|
||||
|
||||
it('throws when fixed prompt and tool schemas exceed threshold before any history exists', async () => {
|
||||
getSessionDetailMock.mockReturnValue({ messages: [] })
|
||||
const emit = vi.fn()
|
||||
|
||||
@@ -0,0 +1,129 @@
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
// Spy handles shared between the vi.mock factories below and the test body.
const getSessionMock = vi.fn()
const getSessionDetailPaginatedMock = vi.fn()
const getCompressionSnapshotMock = vi.fn()
const estimateUsageTokensFromMessagesMock = vi.fn()
const buildDbHistoryMock = vi.fn()
const buildSnapshotAwareHistoryMock = vi.fn()

// --- Persistence layer -------------------------------------------------------
vi.mock('../../packages/server/src/db/hermes/session-store', () => {
  return {
    getSession: getSessionMock,
    createSession: vi.fn(),
    addMessage: vi.fn(),
    updateSessionStats: vi.fn(),
    getSessionDetailPaginated: getSessionDetailPaginatedMock,
  }
})

vi.mock('../../packages/server/src/db/hermes/usage-store', () => {
  return { updateUsage: vi.fn() }
})

vi.mock('../../packages/server/src/db/hermes/compression-snapshot', () => {
  return { getCompressionSnapshot: getCompressionSnapshotMock }
})

// --- Shared libraries --------------------------------------------------------
vi.mock('../../packages/server/src/lib/context-compressor', () => {
  return {
    SUMMARY_PREFIX: '[Previous context summary]',
    countTokens: vi.fn(() => 0),
  }
})

vi.mock('../../packages/server/src/services/logger', () => {
  const logger = { info: vi.fn(), warn: vi.fn(), error: vi.fn(), debug: vi.fn() }
  return { logger }
})

// --- run-chat collaborators --------------------------------------------------
vi.mock('../../packages/server/src/services/hermes/run-chat/compression', () => {
  return {
    buildCompressedHistory: vi.fn(),
    buildDbHistory: buildDbHistoryMock,
    buildSnapshotAwareHistory: buildSnapshotAwareHistoryMock,
    getOrCreateSession: vi.fn(),
  }
})

vi.mock('../../packages/server/src/services/hermes/run-chat/usage', () => {
  return {
    calcAndUpdateUsage: vi.fn(),
    estimateUsageTokensFromMessages: estimateUsageTokensFromMessagesMock,
  }
})

// Both message-format helpers hand their input back unchanged.
vi.mock('../../packages/server/src/services/hermes/run-chat/message-format', () => {
  const passThrough = (messages: any[]) => messages
  return {
    convertHistoryFormat: vi.fn(passThrough),
    handleMessage: vi.fn(passThrough),
  }
})

// Content-block helpers stringify their input (falsy values become '') and never report a block array.
vi.mock('../../packages/server/src/services/hermes/run-chat/content-blocks', () => {
  const stringify = (value: any) => String(value || '')
  return {
    contentBlocksToString: vi.fn(stringify),
    extractTextForPreview: vi.fn(stringify),
    isContentBlockArray: vi.fn(() => false),
    convertContentBlocks: vi.fn(),
  }
})

vi.mock('../../packages/server/src/lib/llm-prompt', () => {
  return { getSystemPrompt: vi.fn(() => 'system prompt') }
})

vi.mock('../../packages/server/src/services/hermes/run-chat/sse-utils', () => {
  return { readSseFrames: vi.fn() }
})

vi.mock('../../packages/server/src/services/hermes/run-chat/response-utils', () => {
  return { extractResponseText: vi.fn() }
})

vi.mock('../../packages/server/src/services/hermes/run-chat/response-stream', () => {
  return {
    applyResponseStreamEvent: vi.fn(),
    flushResponseRunToDb: vi.fn(),
  }
})
|
||||
|
||||
describe('loadSessionStateFromDb', () => {
  // Fresh copies on every call so the store mock and the history mock never share an array.
  const fullTranscript = () => [
    { role: 'user', content: 'old large context' },
    { role: 'assistant', content: 'old large answer' },
    { role: 'user', content: 'new tail' },
  ]

  beforeEach(() => {
    vi.clearAllMocks()

    // Stored session row, its messages, and a compression snapshot.
    getSessionMock.mockReturnValue({
      id: 'session-1',
      profile: 'default',
      model: 'gpt-test',
      provider: 'openai',
    })
    getSessionDetailPaginatedMock.mockReturnValue({ messages: fullTranscript() })
    getCompressionSnapshotMock.mockReturnValue({
      summary: 'small summary',
      lastMessageIndex: 0,
      messageCountAtTime: 1,
    })

    // Raw history is the whole transcript; snapshot-aware history is the summary plus the tail.
    buildDbHistoryMock.mockResolvedValue(fullTranscript())
    buildSnapshotAwareHistoryMock.mockResolvedValue([
      { role: 'user', content: '[Previous context summary]\n\nsmall summary' },
      { role: 'user', content: 'new tail' },
    ])

    // Histories that start with the summary estimate to 9_000 tokens, everything else to 28_000.
    estimateUsageTokensFromMessagesMock.mockImplementation((messages: any[]) => {
      const startsWithSummary = messages?.[0]?.content?.includes('small summary')
      return startsWithSummary
        ? { inputTokens: 9_000, outputTokens: 0 }
        : { inputTokens: 28_000, outputTokens: 0 }
    })
  })

  it('hydrates contextTokens from the same snapshot-aware history used for bridge runs', async () => {
    const { loadSessionStateFromDb } = await import('../../packages/server/src/services/hermes/run-chat/handle-api-run')

    const state = await loadSessionStateFromDb('session-1', new Map())

    const expectedModelInfo = { model: 'gpt-test', provider: 'openai' }
    expect(buildSnapshotAwareHistoryMock).toHaveBeenCalledWith(
      'session-1',
      'default',
      expect.any(Array),
      expectedModelInfo,
    )
    // Usage totals reflect the raw history, while contextTokens reflects the snapshot-aware one.
    expect(state.inputTokens).toBe(28_000)
    expect(state.outputTokens).toBe(0)
    expect(state.contextTokens).toBe(9_000)
  })
})
|
||||
Reference in New Issue
Block a user