fix context token resume (#1039)
This commit is contained in:
@@ -85,6 +85,15 @@ export interface Session {
|
|||||||
workspace?: string | null
|
workspace?: string | null
|
||||||
}
|
}
|
||||||
|
|
||||||
|
interface CompressionState {
|
||||||
|
compressing: boolean
|
||||||
|
messageCount: number
|
||||||
|
beforeTokens: number
|
||||||
|
afterTokens: number
|
||||||
|
compressed: boolean | null
|
||||||
|
error?: string
|
||||||
|
}
|
||||||
|
|
||||||
function uid(): string {
|
function uid(): string {
|
||||||
return Date.now().toString(36) + Math.random().toString(36).slice(2, 8)
|
return Date.now().toString(36) + Math.random().toString(36).slice(2, 8)
|
||||||
}
|
}
|
||||||
@@ -272,6 +281,8 @@ function mapHermesSession(s: SessionSummary): Session {
|
|||||||
model: s.model,
|
model: s.model,
|
||||||
provider: s.provider || (s as any).billing_provider || '',
|
provider: s.provider || (s as any).billing_provider || '',
|
||||||
messageCount: s.message_count,
|
messageCount: s.message_count,
|
||||||
|
inputTokens: s.input_tokens,
|
||||||
|
outputTokens: s.output_tokens,
|
||||||
endedAt: s.ended_at != null ? Math.round(s.ended_at * 1000) : null,
|
endedAt: s.ended_at != null ? Math.round(s.ended_at * 1000) : null,
|
||||||
lastActiveAt: s.last_active != null ? Math.round(s.last_active * 1000) : undefined,
|
lastActiveAt: s.last_active != null ? Math.round(s.last_active * 1000) : undefined,
|
||||||
workspace: s.workspace || null,
|
workspace: s.workspace || null,
|
||||||
@@ -405,18 +416,20 @@ export const useChatStore = defineStore('chat', () => {
|
|||||||
const isLoadingMessages = ref(false)
|
const isLoadingMessages = ref(false)
|
||||||
const isRunActive = computed(() => isStreaming.value)
|
const isRunActive = computed(() => isStreaming.value)
|
||||||
|
|
||||||
// Compression state
|
// Compression state is scoped per session because sockets can stay joined to
|
||||||
const compressionState = ref<{
|
// background sessions while another chat is active.
|
||||||
compressing: boolean
|
const compressionStates = ref<Map<string, CompressionState>>(new Map())
|
||||||
messageCount: number
|
const compressionState = computed<CompressionState | null>(() => {
|
||||||
beforeTokens: number
|
const sid = activeSessionId.value
|
||||||
afterTokens: number
|
return sid ? compressionStates.value.get(sid) || null : null
|
||||||
compressed: boolean | null
|
})
|
||||||
error?: string
|
|
||||||
} | null>(null)
|
|
||||||
|
|
||||||
function setCompressionState(state: typeof compressionState.value) {
|
function setCompressionState(sessionId: string | null | undefined, state: CompressionState | null) {
|
||||||
compressionState.value = state
|
if (!sessionId) return
|
||||||
|
const next = new Map(compressionStates.value)
|
||||||
|
if (state) next.set(sessionId, state)
|
||||||
|
else next.delete(sessionId)
|
||||||
|
compressionStates.value = next
|
||||||
}
|
}
|
||||||
|
|
||||||
const abortState = ref<{
|
const abortState = ref<{
|
||||||
@@ -438,11 +451,12 @@ export const useChatStore = defineStore('chat', () => {
|
|||||||
}
|
}
|
||||||
|
|
||||||
function clearActiveSession() {
|
function clearActiveSession() {
|
||||||
|
const sid = activeSessionId.value
|
||||||
activeSessionId.value = null
|
activeSessionId.value = null
|
||||||
activeSession.value = null
|
activeSession.value = null
|
||||||
focusMessageId.value = null
|
focusMessageId.value = null
|
||||||
setAbortState(null)
|
setAbortState(null)
|
||||||
setCompressionState(null)
|
setCompressionState(sid, null)
|
||||||
removeItem(storageKey())
|
removeItem(storageKey())
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -453,10 +467,14 @@ export const useChatStore = defineStore('chat', () => {
|
|||||||
const fresh = list.map(mapHermesSession)
|
const fresh = list.map(mapHermesSession)
|
||||||
// Preserve already-loaded messages for sessions that are still present,
|
// Preserve already-loaded messages for sessions that are still present,
|
||||||
// so we don't blow away the active session's messages on refresh.
|
// so we don't blow away the active session's messages on refresh.
|
||||||
const msgsByIdBefore = new Map(sessions.value.map(s => [s.id, s.messages]))
|
const runtimeByIdBefore = new Map(sessions.value.map(s => [s.id, {
|
||||||
|
messages: s.messages,
|
||||||
|
contextTokens: s.contextTokens,
|
||||||
|
}]))
|
||||||
for (const s of fresh) {
|
for (const s of fresh) {
|
||||||
const prev = msgsByIdBefore.get(s.id)
|
const prev = runtimeByIdBefore.get(s.id)
|
||||||
if (prev && prev.length) s.messages = prev
|
if (prev?.messages?.length) s.messages = prev.messages
|
||||||
|
if (prev?.contextTokens != null) s.contextTokens = prev.contextTokens
|
||||||
}
|
}
|
||||||
sessions.value = fresh
|
sessions.value = fresh
|
||||||
|
|
||||||
@@ -594,6 +612,7 @@ export const useChatStore = defineStore('chat', () => {
|
|||||||
} else if (!data.isWorking) {
|
} else if (!data.isWorking) {
|
||||||
setAbortState(null)
|
setAbortState(null)
|
||||||
}
|
}
|
||||||
|
if (!data.isWorking) setCompressionState(sessionId, null)
|
||||||
if (data.inputTokens != null) target.inputTokens = data.inputTokens
|
if (data.inputTokens != null) target.inputTokens = data.inputTokens
|
||||||
if (data.outputTokens != null) target.outputTokens = data.outputTokens
|
if (data.outputTokens != null) target.outputTokens = data.outputTokens
|
||||||
if ((data as any).contextTokens != null) target.contextTokens = (data as any).contextTokens
|
if ((data as any).contextTokens != null) target.contextTokens = (data as any).contextTokens
|
||||||
@@ -613,7 +632,7 @@ export const useChatStore = defineStore('chat', () => {
|
|||||||
for (const evt of data.events) {
|
for (const evt of data.events) {
|
||||||
const e = evt.data as any
|
const e = evt.data as any
|
||||||
if (e.event === 'compression.started') {
|
if (e.event === 'compression.started') {
|
||||||
setCompressionState({
|
setCompressionState(sessionId, {
|
||||||
compressing: true,
|
compressing: true,
|
||||||
messageCount: e.message_count || 0,
|
messageCount: e.message_count || 0,
|
||||||
beforeTokens: e.token_count || 0,
|
beforeTokens: e.token_count || 0,
|
||||||
@@ -622,7 +641,7 @@ export const useChatStore = defineStore('chat', () => {
|
|||||||
})
|
})
|
||||||
} else if (e.event === 'compression.completed') {
|
} else if (e.event === 'compression.completed') {
|
||||||
const afterTokens = e.contextTokens || e.afterTokens || 0
|
const afterTokens = e.contextTokens || e.afterTokens || 0
|
||||||
setCompressionState({
|
setCompressionState(sessionId, {
|
||||||
compressing: false,
|
compressing: false,
|
||||||
messageCount: e.totalMessages || 0,
|
messageCount: e.totalMessages || 0,
|
||||||
beforeTokens: e.beforeTokens || 0,
|
beforeTokens: e.beforeTokens || 0,
|
||||||
@@ -1385,6 +1404,7 @@ export const useChatStore = defineStore('chat', () => {
|
|||||||
} else if (!data.isWorking) {
|
} else if (!data.isWorking) {
|
||||||
setAbortState(null)
|
setAbortState(null)
|
||||||
}
|
}
|
||||||
|
if (!data.isWorking) setCompressionState(sid, null)
|
||||||
|
|
||||||
if (data.inputTokens != null) target.inputTokens = data.inputTokens
|
if (data.inputTokens != null) target.inputTokens = data.inputTokens
|
||||||
if (data.outputTokens != null) target.outputTokens = data.outputTokens
|
if (data.outputTokens != null) target.outputTokens = data.outputTokens
|
||||||
@@ -1407,7 +1427,7 @@ export const useChatStore = defineStore('chat', () => {
|
|||||||
const e = evt.data as RunEvent
|
const e = evt.data as RunEvent
|
||||||
switch (e.event) {
|
switch (e.event) {
|
||||||
case 'compression.started':
|
case 'compression.started':
|
||||||
setCompressionState({
|
setCompressionState(sid, {
|
||||||
compressing: true,
|
compressing: true,
|
||||||
messageCount: (e as any).message_count || 0,
|
messageCount: (e as any).message_count || 0,
|
||||||
beforeTokens: (e as any).token_count || 0,
|
beforeTokens: (e as any).token_count || 0,
|
||||||
@@ -1417,7 +1437,7 @@ export const useChatStore = defineStore('chat', () => {
|
|||||||
break
|
break
|
||||||
case 'compression.completed': {
|
case 'compression.completed': {
|
||||||
const afterTokens = (e as any).contextTokens || (e as any).afterTokens || 0
|
const afterTokens = (e as any).contextTokens || (e as any).afterTokens || 0
|
||||||
setCompressionState({
|
setCompressionState(sid, {
|
||||||
compressing: false,
|
compressing: false,
|
||||||
messageCount: (e as any).totalMessages || 0,
|
messageCount: (e as any).totalMessages || 0,
|
||||||
beforeTokens: (e as any).beforeTokens || 0,
|
beforeTokens: (e as any).beforeTokens || 0,
|
||||||
@@ -1474,7 +1494,7 @@ export const useChatStore = defineStore('chat', () => {
|
|||||||
case 'run.started':
|
case 'run.started':
|
||||||
clearAgentEventMessages(sid)
|
clearAgentEventMessages(sid)
|
||||||
setAbortState(null)
|
setAbortState(null)
|
||||||
setCompressionState(null)
|
setCompressionState(sid, null)
|
||||||
runProducedAssistantText = false
|
runProducedAssistantText = false
|
||||||
runHadToolActivity = false
|
runHadToolActivity = false
|
||||||
closeStreamingAssistant()
|
closeStreamingAssistant()
|
||||||
@@ -1502,7 +1522,7 @@ export const useChatStore = defineStore('chat', () => {
|
|||||||
}
|
}
|
||||||
|
|
||||||
case 'compression.started': {
|
case 'compression.started': {
|
||||||
setCompressionState({
|
setCompressionState(sid, {
|
||||||
compressing: true,
|
compressing: true,
|
||||||
messageCount: (evt as any).message_count || 0,
|
messageCount: (evt as any).message_count || 0,
|
||||||
beforeTokens: (evt as any).token_count || 0,
|
beforeTokens: (evt as any).token_count || 0,
|
||||||
@@ -1514,7 +1534,7 @@ export const useChatStore = defineStore('chat', () => {
|
|||||||
|
|
||||||
case 'compression.completed': {
|
case 'compression.completed': {
|
||||||
const afterTokens = (evt as any).contextTokens || (evt as any).afterTokens || 0
|
const afterTokens = (evt as any).contextTokens || (evt as any).afterTokens || 0
|
||||||
setCompressionState({
|
setCompressionState(sid, {
|
||||||
compressing: false,
|
compressing: false,
|
||||||
messageCount: (evt as any).totalMessages || 0,
|
messageCount: (evt as any).totalMessages || 0,
|
||||||
beforeTokens: (evt as any).beforeTokens || 0,
|
beforeTokens: (evt as any).beforeTokens || 0,
|
||||||
@@ -1528,8 +1548,9 @@ export const useChatStore = defineStore('chat', () => {
|
|||||||
}
|
}
|
||||||
// Auto-clear after 5s
|
// Auto-clear after 5s
|
||||||
setTimeout(() => {
|
setTimeout(() => {
|
||||||
if (compressionState.value && !compressionState.value.compressing) {
|
const state = compressionStates.value.get(sid)
|
||||||
setCompressionState(null)
|
if (state && !state.compressing) {
|
||||||
|
setCompressionState(sid, null)
|
||||||
}
|
}
|
||||||
}, 5000)
|
}, 5000)
|
||||||
break
|
break
|
||||||
@@ -1966,7 +1987,7 @@ export const useChatStore = defineStore('chat', () => {
|
|||||||
case 'run.started':
|
case 'run.started':
|
||||||
clearAgentEventMessages(sid)
|
clearAgentEventMessages(sid)
|
||||||
setAbortState(null)
|
setAbortState(null)
|
||||||
setCompressionState(null)
|
setCompressionState(sid, null)
|
||||||
runProducedAssistantText = false
|
runProducedAssistantText = false
|
||||||
runHadToolActivity = false
|
runHadToolActivity = false
|
||||||
closeStreamingAssistant()
|
closeStreamingAssistant()
|
||||||
@@ -1979,7 +2000,7 @@ export const useChatStore = defineStore('chat', () => {
|
|||||||
break
|
break
|
||||||
|
|
||||||
case 'compression.started': {
|
case 'compression.started': {
|
||||||
setCompressionState({
|
setCompressionState(sid, {
|
||||||
compressing: true,
|
compressing: true,
|
||||||
messageCount: (evt as any).message_count || 0,
|
messageCount: (evt as any).message_count || 0,
|
||||||
beforeTokens: (evt as any).token_count || 0,
|
beforeTokens: (evt as any).token_count || 0,
|
||||||
@@ -1991,7 +2012,7 @@ export const useChatStore = defineStore('chat', () => {
|
|||||||
|
|
||||||
case 'compression.completed': {
|
case 'compression.completed': {
|
||||||
const afterTokens = (evt as any).contextTokens || (evt as any).afterTokens || 0
|
const afterTokens = (evt as any).contextTokens || (evt as any).afterTokens || 0
|
||||||
setCompressionState({
|
setCompressionState(sid, {
|
||||||
compressing: false,
|
compressing: false,
|
||||||
messageCount: (evt as any).totalMessages || 0,
|
messageCount: (evt as any).totalMessages || 0,
|
||||||
beforeTokens: (evt as any).beforeTokens || 0,
|
beforeTokens: (evt as any).beforeTokens || 0,
|
||||||
@@ -2004,8 +2025,9 @@ export const useChatStore = defineStore('chat', () => {
|
|||||||
if (target) target.contextTokens = (evt as any).contextTokens
|
if (target) target.contextTokens = (evt as any).contextTokens
|
||||||
}
|
}
|
||||||
setTimeout(() => {
|
setTimeout(() => {
|
||||||
if (compressionState.value && !compressionState.value.compressing) {
|
const state = compressionStates.value.get(sid)
|
||||||
setCompressionState(null)
|
if (state && !state.compressing) {
|
||||||
|
setCompressionState(sid, null)
|
||||||
}
|
}
|
||||||
}, 5000)
|
}, 5000)
|
||||||
break
|
break
|
||||||
@@ -2461,6 +2483,7 @@ export const useChatStore = defineStore('chat', () => {
|
|||||||
} else if (!data.isWorking) {
|
} else if (!data.isWorking) {
|
||||||
setAbortState(null)
|
setAbortState(null)
|
||||||
}
|
}
|
||||||
|
if (!data.isWorking) setCompressionState(sid, null)
|
||||||
if (data.messages?.length && activeSession.value) {
|
if (data.messages?.length && activeSession.value) {
|
||||||
activeSession.value.messages = mapHermesMessages(data.messages as any[])
|
activeSession.value.messages = mapHermesMessages(data.messages as any[])
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -79,6 +79,7 @@ export interface SummarizerOptions {
|
|||||||
profile?: string
|
profile?: string
|
||||||
model?: string | null
|
model?: string | null
|
||||||
provider?: string | null
|
provider?: string | null
|
||||||
|
workerKey?: string
|
||||||
}
|
}
|
||||||
|
|
||||||
// ─── Token counting ─────────────────────────────────────
|
// ─── Token counting ─────────────────────────────────────
|
||||||
@@ -454,6 +455,7 @@ export async function callSummarizer(
|
|||||||
|
|
||||||
const bridge = new AgentBridgeClient({ timeoutMs: timeoutMs + 15_000 })
|
const bridge = new AgentBridgeClient({ timeoutMs: timeoutMs + 15_000 })
|
||||||
const sessionId = `compress_${Date.now().toString(36)}_${randomUUID().replace(/-/g, '').slice(0, 12)}`
|
const sessionId = `compress_${Date.now().toString(36)}_${randomUUID().replace(/-/g, '').slice(0, 12)}`
|
||||||
|
const workerKey = options.workerKey || `${profile}:compression:${sessionId}`
|
||||||
|
|
||||||
try {
|
try {
|
||||||
const result = await bridge.request<AgentBridgeRunResult>({
|
const result = await bridge.request<AgentBridgeRunResult>({
|
||||||
@@ -462,6 +464,7 @@ export async function callSummarizer(
|
|||||||
message: prompt,
|
message: prompt,
|
||||||
conversation_history: convHistory,
|
conversation_history: convHistory,
|
||||||
profile,
|
profile,
|
||||||
|
worker_key: workerKey,
|
||||||
source: 'api_server',
|
source: 'api_server',
|
||||||
wait: true,
|
wait: true,
|
||||||
timeout: Math.ceil(timeoutMs / 1000),
|
timeout: Math.ceil(timeoutMs / 1000),
|
||||||
@@ -482,7 +485,7 @@ export async function callSummarizer(
|
|||||||
if (!output) throw new Error('Empty summarization response')
|
if (!output) throw new Error('Empty summarization response')
|
||||||
return output
|
return output
|
||||||
} finally {
|
} finally {
|
||||||
await bridge.destroy(sessionId, profile).catch(() => undefined)
|
await bridge.destroy(sessionId, profile, workerKey).catch(() => undefined)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -166,7 +166,7 @@ export class AgentBridgeClient {
|
|||||||
private summarizePayload(payload: Record<string, unknown>): Record<string, unknown> {
|
private summarizePayload(payload: Record<string, unknown>): Record<string, unknown> {
|
||||||
const action = String(payload.action || '')
|
const action = String(payload.action || '')
|
||||||
const summary: Record<string, unknown> = { action }
|
const summary: Record<string, unknown> = { action }
|
||||||
for (const key of ['session_id', 'run_id', 'request_id', 'approval_id', 'profile']) {
|
for (const key of ['session_id', 'run_id', 'request_id', 'approval_id', 'profile', 'worker_key']) {
|
||||||
if (payload[key] != null) summary[key] = payload[key]
|
if (payload[key] != null) summary[key] = payload[key]
|
||||||
}
|
}
|
||||||
if (Array.isArray(payload.conversation_history)) summary.conversation_history_count = payload.conversation_history.length
|
if (Array.isArray(payload.conversation_history)) summary.conversation_history_count = payload.conversation_history.length
|
||||||
@@ -569,11 +569,12 @@ export class AgentBridgeClient {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
destroy(sessionId: string, profile?: string): Promise<AgentBridgeResponse> {
|
destroy(sessionId: string, profile?: string, workerKey?: string): Promise<AgentBridgeResponse> {
|
||||||
return this.request({
|
return this.request({
|
||||||
action: 'destroy',
|
action: 'destroy',
|
||||||
session_id: sessionId,
|
session_id: sessionId,
|
||||||
...(profile ? { profile } : {}),
|
...(profile ? { profile } : {}),
|
||||||
|
...(workerKey ? { worker_key: workerKey } : {}),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -2200,7 +2200,8 @@ class WorkerProcess:
|
|||||||
STARTUP_TIMEOUT_SECONDS = 120
|
STARTUP_TIMEOUT_SECONDS = 120
|
||||||
REQUEST_TIMEOUT_SECONDS = 120
|
REQUEST_TIMEOUT_SECONDS = 120
|
||||||
|
|
||||||
def __init__(self, profile: str, endpoint: str, agent_root: str | None, hermes_home: str | None) -> None:
|
def __init__(self, key: str, profile: str, endpoint: str, agent_root: str | None, hermes_home: str | None) -> None:
|
||||||
|
self.key = key or profile or "default"
|
||||||
self.profile = profile or "default"
|
self.profile = profile or "default"
|
||||||
self.endpoint = endpoint
|
self.endpoint = endpoint
|
||||||
self.agent_root = agent_root
|
self.agent_root = agent_root
|
||||||
@@ -2263,14 +2264,14 @@ class WorkerProcess:
|
|||||||
for line in proc.stderr:
|
for line in proc.stderr:
|
||||||
text = line.rstrip()
|
text = line.rstrip()
|
||||||
if text:
|
if text:
|
||||||
print(f"[hermes-bridge-worker:{self.profile}] {text}", file=sys.stderr, flush=True)
|
print(f"[hermes-bridge-worker:{self.key}] {text}", file=sys.stderr, flush=True)
|
||||||
|
|
||||||
threading.Thread(target=run, daemon=True, name=f"hermes-bridge-worker-stderr-{self.profile}").start()
|
threading.Thread(target=run, daemon=True, name=f"hermes-bridge-worker-stderr-{self.key}").start()
|
||||||
|
|
||||||
def _wait_ready(self) -> None:
|
def _wait_ready(self) -> None:
|
||||||
proc = self.process
|
proc = self.process
|
||||||
if proc is None or proc.stdout is None:
|
if proc is None or proc.stdout is None:
|
||||||
raise RuntimeError(f"profile worker {self.profile} did not start")
|
raise RuntimeError(f"profile worker {self.key} did not start")
|
||||||
lines: queue.Queue[str | None] = queue.Queue()
|
lines: queue.Queue[str | None] = queue.Queue()
|
||||||
ready_event = threading.Event()
|
ready_event = threading.Event()
|
||||||
|
|
||||||
@@ -2281,17 +2282,17 @@ class WorkerProcess:
|
|||||||
if ready_event.is_set():
|
if ready_event.is_set():
|
||||||
text = line.rstrip()
|
text = line.rstrip()
|
||||||
if text:
|
if text:
|
||||||
print(f"[hermes-bridge-worker:{self.profile}] {text}", file=sys.stderr, flush=True)
|
print(f"[hermes-bridge-worker:{self.key}] {text}", file=sys.stderr, flush=True)
|
||||||
else:
|
else:
|
||||||
lines.put(line)
|
lines.put(line)
|
||||||
finally:
|
finally:
|
||||||
lines.put(None)
|
lines.put(None)
|
||||||
|
|
||||||
threading.Thread(target=read_stdout, daemon=True, name=f"hermes-bridge-worker-stdout-{self.profile}").start()
|
threading.Thread(target=read_stdout, daemon=True, name=f"hermes-bridge-worker-stdout-{self.key}").start()
|
||||||
deadline = time.time() + self.STARTUP_TIMEOUT_SECONDS
|
deadline = time.time() + self.STARTUP_TIMEOUT_SECONDS
|
||||||
while time.time() < deadline:
|
while time.time() < deadline:
|
||||||
if proc.poll() is not None:
|
if proc.poll() is not None:
|
||||||
raise RuntimeError(f"profile worker {self.profile} exited before ready")
|
raise RuntimeError(f"profile worker {self.key} exited before ready")
|
||||||
try:
|
try:
|
||||||
line = lines.get(timeout=0.1)
|
line = lines.get(timeout=0.1)
|
||||||
except queue.Empty:
|
except queue.Empty:
|
||||||
@@ -2301,7 +2302,7 @@ class WorkerProcess:
|
|||||||
continue
|
continue
|
||||||
text = line.strip()
|
text = line.strip()
|
||||||
if text:
|
if text:
|
||||||
print(f"[hermes-bridge-worker:{self.profile}] {text}", file=sys.stderr, flush=True)
|
print(f"[hermes-bridge-worker:{self.key}] {text}", file=sys.stderr, flush=True)
|
||||||
try:
|
try:
|
||||||
data = json.loads(text)
|
data = json.loads(text)
|
||||||
if data.get("event") == "ready":
|
if data.get("event") == "ready":
|
||||||
@@ -2310,7 +2311,7 @@ class WorkerProcess:
|
|||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
self.stop()
|
self.stop()
|
||||||
raise RuntimeError(f"profile worker {self.profile} did not become ready within {self.STARTUP_TIMEOUT_SECONDS}s")
|
raise RuntimeError(f"profile worker {self.key} did not become ready within {self.STARTUP_TIMEOUT_SECONDS}s")
|
||||||
|
|
||||||
def stop(self) -> None:
|
def stop(self) -> None:
|
||||||
with self._lock:
|
with self._lock:
|
||||||
@@ -2337,8 +2338,8 @@ class WorkerProcess:
|
|||||||
return _send_bridge_request(self.endpoint, req, self.REQUEST_TIMEOUT_SECONDS)
|
return _send_bridge_request(self.endpoint, req, self.REQUEST_TIMEOUT_SECONDS)
|
||||||
|
|
||||||
|
|
||||||
def _worker_endpoint(profile: str) -> str:
|
def _worker_endpoint(key: str) -> str:
|
||||||
safe = hashlib.sha256(profile.encode("utf-8")).hexdigest()[:16]
|
safe = hashlib.sha256(key.encode("utf-8")).hexdigest()[:16]
|
||||||
if os.name == "nt":
|
if os.name == "nt":
|
||||||
port_base = int(os.environ.get("HERMES_AGENT_BRIDGE_WORKER_PORT_BASE", "18780"))
|
port_base = int(os.environ.get("HERMES_AGENT_BRIDGE_WORKER_PORT_BASE", "18780"))
|
||||||
return f"tcp://127.0.0.1:{port_base + int(safe[:4], 16) % 1000}"
|
return f"tcp://127.0.0.1:{port_base + int(safe[:4], 16) % 1000}"
|
||||||
@@ -2533,11 +2534,17 @@ class BridgeBroker:
|
|||||||
self.hermes_home = hermes_home
|
self.hermes_home = hermes_home
|
||||||
self._workers: dict[str, WorkerProcess] = {}
|
self._workers: dict[str, WorkerProcess] = {}
|
||||||
self._run_profile: dict[str, str] = {}
|
self._run_profile: dict[str, str] = {}
|
||||||
|
self._run_worker_key: dict[str, str] = {}
|
||||||
self._running_run_profile: dict[str, str] = {}
|
self._running_run_profile: dict[str, str] = {}
|
||||||
|
self._running_run_worker_key: dict[str, str] = {}
|
||||||
self._session_profile: dict[str, str] = {}
|
self._session_profile: dict[str, str] = {}
|
||||||
|
self._session_worker_key: dict[str, str] = {}
|
||||||
self._approval_profile: dict[str, str] = {}
|
self._approval_profile: dict[str, str] = {}
|
||||||
|
self._approval_worker_key: dict[str, str] = {}
|
||||||
self._clarify_profile: dict[str, str] = {}
|
self._clarify_profile: dict[str, str] = {}
|
||||||
|
self._clarify_worker_key: dict[str, str] = {}
|
||||||
self._compression_profile: dict[str, str] = {}
|
self._compression_profile: dict[str, str] = {}
|
||||||
|
self._compression_worker_key: dict[str, str] = {}
|
||||||
self._lock = threading.RLock()
|
self._lock = threading.RLock()
|
||||||
self._stop = threading.Event()
|
self._stop = threading.Event()
|
||||||
self._last_gc = time.time()
|
self._last_gc = time.time()
|
||||||
@@ -2546,58 +2553,73 @@ class BridgeBroker:
|
|||||||
profile = str(value or "").strip()
|
profile = str(value or "").strip()
|
||||||
return profile or "default"
|
return profile or "default"
|
||||||
|
|
||||||
def _worker_for_profile(self, profile: str) -> WorkerProcess:
|
def _normalize_worker_key(self, profile: str, value: Any = None) -> str:
|
||||||
|
worker_key = str(value or "").strip()
|
||||||
|
return worker_key or profile
|
||||||
|
|
||||||
|
def _worker_for_profile(self, profile: str, worker_key: str | None = None) -> WorkerProcess:
|
||||||
profile = self._normalize_profile(profile)
|
profile = self._normalize_profile(profile)
|
||||||
|
key = self._normalize_worker_key(profile, worker_key)
|
||||||
with self._lock:
|
with self._lock:
|
||||||
worker = self._workers.get(profile)
|
worker = self._workers.get(key)
|
||||||
if worker is None:
|
if worker is None:
|
||||||
worker = WorkerProcess(profile, _worker_endpoint(profile), self.agent_root, self.hermes_home)
|
worker = WorkerProcess(key, profile, _worker_endpoint(key), self.agent_root, self.hermes_home)
|
||||||
self._workers[profile] = worker
|
self._workers[key] = worker
|
||||||
return worker
|
return worker
|
||||||
|
|
||||||
def _profile_for_run(self, run_id: str) -> str:
|
def _route_for_run(self, run_id: str) -> tuple[str, str | None]:
|
||||||
with self._lock:
|
with self._lock:
|
||||||
profile = self._run_profile.get(run_id)
|
profile = self._run_profile.get(run_id)
|
||||||
|
worker_key = self._run_worker_key.get(run_id)
|
||||||
if not profile:
|
if not profile:
|
||||||
raise KeyError(f"unknown run: {run_id}")
|
raise KeyError(f"unknown run: {run_id}")
|
||||||
return profile
|
return profile, worker_key
|
||||||
|
|
||||||
def _profile_for_session(self, session_id: str, fallback_profile: Any = None) -> str:
|
def _route_for_session(self, session_id: str, fallback_profile: Any = None, worker_key: Any = None) -> tuple[str, str | None]:
|
||||||
with self._lock:
|
with self._lock:
|
||||||
profile = self._session_profile.get(session_id)
|
profile = self._session_profile.get(session_id)
|
||||||
|
stored_worker_key = self._session_worker_key.get(session_id)
|
||||||
if not profile:
|
if not profile:
|
||||||
fallback = self._normalize_profile(fallback_profile)
|
fallback = self._normalize_profile(fallback_profile)
|
||||||
if fallback_profile is not None and fallback:
|
if fallback_profile is not None and fallback:
|
||||||
return fallback
|
return fallback, self._normalize_worker_key(fallback, worker_key)
|
||||||
raise KeyError(f"unknown session: {session_id}")
|
raise KeyError(f"unknown session: {session_id}")
|
||||||
return profile
|
return profile, self._normalize_worker_key(profile, worker_key) if worker_key is not None else stored_worker_key
|
||||||
|
|
||||||
def _record_response_routes(self, profile: str, resp: dict[str, Any]) -> None:
|
def _record_response_routes(self, profile: str, worker_key: str, resp: dict[str, Any]) -> None:
|
||||||
run_id = str(resp.get("run_id") or "")
|
run_id = str(resp.get("run_id") or "")
|
||||||
session_id = str(resp.get("session_id") or "")
|
session_id = str(resp.get("session_id") or "")
|
||||||
with self._lock:
|
with self._lock:
|
||||||
if run_id:
|
if run_id:
|
||||||
self._run_profile[run_id] = profile
|
self._run_profile[run_id] = profile
|
||||||
|
self._run_worker_key[run_id] = worker_key
|
||||||
if resp.get("status") == "running":
|
if resp.get("status") == "running":
|
||||||
self._running_run_profile[run_id] = profile
|
self._running_run_profile[run_id] = profile
|
||||||
|
self._running_run_worker_key[run_id] = worker_key
|
||||||
else:
|
else:
|
||||||
self._running_run_profile.pop(run_id, None)
|
self._running_run_profile.pop(run_id, None)
|
||||||
|
self._running_run_worker_key.pop(run_id, None)
|
||||||
if session_id:
|
if session_id:
|
||||||
self._session_profile[session_id] = profile
|
self._session_profile[session_id] = profile
|
||||||
|
self._session_worker_key[session_id] = worker_key
|
||||||
for event in resp.get("events") or []:
|
for event in resp.get("events") or []:
|
||||||
if not isinstance(event, dict):
|
if not isinstance(event, dict):
|
||||||
continue
|
continue
|
||||||
approval_id = str(event.get("approval_id") or "")
|
approval_id = str(event.get("approval_id") or "")
|
||||||
if approval_id:
|
if approval_id:
|
||||||
self._approval_profile[approval_id] = profile
|
self._approval_profile[approval_id] = profile
|
||||||
|
self._approval_worker_key[approval_id] = worker_key
|
||||||
clarify_id = str(event.get("clarify_id") or "")
|
clarify_id = str(event.get("clarify_id") or "")
|
||||||
if clarify_id:
|
if clarify_id:
|
||||||
self._clarify_profile[clarify_id] = profile
|
self._clarify_profile[clarify_id] = profile
|
||||||
|
self._clarify_worker_key[clarify_id] = worker_key
|
||||||
request_id = str(event.get("request_id") or "")
|
request_id = str(event.get("request_id") or "")
|
||||||
if event.get("event") == "bridge.compression.requested" and request_id:
|
if event.get("event") == "bridge.compression.requested" and request_id:
|
||||||
self._compression_profile[request_id] = profile
|
self._compression_profile[request_id] = profile
|
||||||
|
self._compression_worker_key[request_id] = worker_key
|
||||||
if event.get("event") in {"bridge.compression.completed", "bridge.compression.failed"} and request_id:
|
if event.get("event") in {"bridge.compression.completed", "bridge.compression.failed"} and request_id:
|
||||||
self._compression_profile.pop(request_id, None)
|
self._compression_profile.pop(request_id, None)
|
||||||
|
self._compression_worker_key.pop(request_id, None)
|
||||||
|
|
||||||
def stop(self) -> None:
|
def stop(self) -> None:
|
||||||
self._stop.set()
|
self._stop.set()
|
||||||
@@ -2605,20 +2627,29 @@ class BridgeBroker:
|
|||||||
workers = list(self._workers.values())
|
workers = list(self._workers.values())
|
||||||
self._workers.clear()
|
self._workers.clear()
|
||||||
self._run_profile.clear()
|
self._run_profile.clear()
|
||||||
|
self._run_worker_key.clear()
|
||||||
self._running_run_profile.clear()
|
self._running_run_profile.clear()
|
||||||
|
self._running_run_worker_key.clear()
|
||||||
self._session_profile.clear()
|
self._session_profile.clear()
|
||||||
|
self._session_worker_key.clear()
|
||||||
self._approval_profile.clear()
|
self._approval_profile.clear()
|
||||||
|
self._approval_worker_key.clear()
|
||||||
self._clarify_profile.clear()
|
self._clarify_profile.clear()
|
||||||
|
self._clarify_worker_key.clear()
|
||||||
self._compression_profile.clear()
|
self._compression_profile.clear()
|
||||||
|
self._compression_worker_key.clear()
|
||||||
for worker in workers:
|
for worker in workers:
|
||||||
worker.stop()
|
worker.stop()
|
||||||
|
|
||||||
def _forward(self, profile: str, req: dict[str, Any]) -> dict[str, Any]:
|
def _forward(self, profile: str, req: dict[str, Any], worker_key: str | None = None) -> dict[str, Any]:
|
||||||
worker = self._worker_for_profile(profile)
|
profile = self._normalize_profile(profile)
|
||||||
|
key = self._normalize_worker_key(profile, worker_key)
|
||||||
|
worker = self._worker_for_profile(profile, key)
|
||||||
forwarded = dict(req)
|
forwarded = dict(req)
|
||||||
forwarded["profile"] = profile
|
forwarded["profile"] = profile
|
||||||
|
forwarded.pop("worker_key", None)
|
||||||
resp = worker.request(forwarded)
|
resp = worker.request(forwarded)
|
||||||
self._record_response_routes(profile, resp)
|
self._record_response_routes(profile, key, resp)
|
||||||
return resp
|
return resp
|
||||||
|
|
||||||
def handle(self, req: dict[str, Any]) -> dict[str, Any]:
|
def handle(self, req: dict[str, Any]) -> dict[str, Any]:
|
||||||
@@ -2629,15 +2660,16 @@ class BridgeBroker:
|
|||||||
if action == "ping":
|
if action == "ping":
|
||||||
with self._lock:
|
with self._lock:
|
||||||
worker_details = {
|
worker_details = {
|
||||||
profile: {
|
key: {
|
||||||
"running": worker.running,
|
"running": worker.running,
|
||||||
"pid": worker.pid,
|
"pid": worker.pid,
|
||||||
"endpoint": worker.endpoint,
|
"endpoint": worker.endpoint,
|
||||||
|
"profile": getattr(worker, "profile", key),
|
||||||
"last_used_at": worker.last_used_at,
|
"last_used_at": worker.last_used_at,
|
||||||
}
|
}
|
||||||
for profile, worker in self._workers.items()
|
for key, worker in self._workers.items()
|
||||||
}
|
}
|
||||||
workers = {profile: details["running"] for profile, details in worker_details.items()}
|
workers = {key: details["running"] for key, details in worker_details.items()}
|
||||||
sessions_by_profile: dict[str, int] = {}
|
sessions_by_profile: dict[str, int] = {}
|
||||||
for profile in self._session_profile.values():
|
for profile in self._session_profile.values():
|
||||||
sessions_by_profile[profile] = sessions_by_profile.get(profile, 0) + 1
|
sessions_by_profile[profile] = sessions_by_profile.get(profile, 0) + 1
|
||||||
@@ -2664,29 +2696,32 @@ class BridgeBroker:
|
|||||||
|
|
||||||
if action == "worker_ping":
|
if action == "worker_ping":
|
||||||
profile = self._normalize_profile(req.get("profile"))
|
profile = self._normalize_profile(req.get("profile"))
|
||||||
resp = self._forward(profile, {"action": "ping"})
|
worker_key = self._normalize_worker_key(profile, req.get("worker_key"))
|
||||||
|
resp = self._forward(profile, {"action": "ping"}, worker_key)
|
||||||
resp["worker_profile"] = profile
|
resp["worker_profile"] = profile
|
||||||
|
resp["worker_key"] = worker_key
|
||||||
return resp
|
return resp
|
||||||
|
|
||||||
if action == "chat":
|
if action == "chat":
|
||||||
profile = self._normalize_profile(req.get("profile"))
|
profile = self._normalize_profile(req.get("profile"))
|
||||||
return self._forward(profile, req)
|
return self._forward(profile, req, self._normalize_worker_key(profile, req.get("worker_key")))
|
||||||
|
|
||||||
if action == "context_estimate":
|
if action == "context_estimate":
|
||||||
profile = self._normalize_profile(req.get("profile"))
|
profile = self._normalize_profile(req.get("profile"))
|
||||||
return self._forward(profile, req)
|
return self._forward(profile, req, self._normalize_worker_key(profile, req.get("worker_key")))
|
||||||
|
|
||||||
if action in {"get_result", "get_output"}:
|
if action in {"get_result", "get_output"}:
|
||||||
profile = self._profile_for_run(str(req.get("run_id") or ""))
|
profile, worker_key = self._route_for_run(str(req.get("run_id") or ""))
|
||||||
return self._forward(profile, req)
|
return self._forward(profile, req, worker_key)
|
||||||
|
|
||||||
if action in {"interrupt", "steer", "command", "goal_evaluate", "goal_pause", "status", "get_history", "destroy"}:
|
if action in {"interrupt", "steer", "command", "goal_evaluate", "goal_pause", "status", "get_history", "destroy"}:
|
||||||
session_id = str(req.get("session_id") or "")
|
session_id = str(req.get("session_id") or "")
|
||||||
profile = self._profile_for_session(session_id, req.get("profile"))
|
profile, worker_key = self._route_for_session(session_id, req.get("profile"), req.get("worker_key") if "worker_key" in req else None)
|
||||||
resp = self._forward(profile, req)
|
resp = self._forward(profile, req, worker_key)
|
||||||
if action == "destroy":
|
if action == "destroy":
|
||||||
with self._lock:
|
with self._lock:
|
||||||
self._session_profile.pop(session_id, None)
|
self._session_profile.pop(session_id, None)
|
||||||
|
self._session_worker_key.pop(session_id, None)
|
||||||
return resp
|
return resp
|
||||||
|
|
||||||
if action == "approval_respond":
|
if action == "approval_respond":
|
||||||
@@ -2695,9 +2730,10 @@ class BridgeBroker:
|
|||||||
raise ValueError("approval_id is required")
|
raise ValueError("approval_id is required")
|
||||||
with self._lock:
|
with self._lock:
|
||||||
profile = self._approval_profile.get(approval_id)
|
profile = self._approval_profile.get(approval_id)
|
||||||
|
worker_key = self._approval_worker_key.get(approval_id)
|
||||||
if not profile:
|
if not profile:
|
||||||
raise KeyError(f"unknown approval request: {approval_id}")
|
raise KeyError(f"unknown approval request: {approval_id}")
|
||||||
return self._forward(profile, req)
|
return self._forward(profile, req, worker_key)
|
||||||
|
|
||||||
if action == "clarify_respond":
|
if action == "clarify_respond":
|
||||||
clarify_id = str(req.get("clarify_id") or "").strip()
|
clarify_id = str(req.get("clarify_id") or "").strip()
|
||||||
@@ -2705,9 +2741,10 @@ class BridgeBroker:
|
|||||||
raise ValueError("clarify_id is required")
|
raise ValueError("clarify_id is required")
|
||||||
with self._lock:
|
with self._lock:
|
||||||
profile = self._clarify_profile.get(clarify_id)
|
profile = self._clarify_profile.get(clarify_id)
|
||||||
|
worker_key = self._clarify_worker_key.get(clarify_id)
|
||||||
if not profile:
|
if not profile:
|
||||||
raise KeyError(f"unknown clarify request: {clarify_id}")
|
raise KeyError(f"unknown clarify request: {clarify_id}")
|
||||||
return self._forward(profile, req)
|
return self._forward(profile, req, worker_key)
|
||||||
|
|
||||||
if action == "compression_respond":
|
if action == "compression_respond":
|
||||||
request_id = str(req.get("request_id") or "").strip()
|
request_id = str(req.get("request_id") or "").strip()
|
||||||
@@ -2715,20 +2752,27 @@ class BridgeBroker:
|
|||||||
raise ValueError("request_id is required")
|
raise ValueError("request_id is required")
|
||||||
with self._lock:
|
with self._lock:
|
||||||
profile = self._compression_profile.get(request_id)
|
profile = self._compression_profile.get(request_id)
|
||||||
|
worker_key = self._compression_worker_key.get(request_id)
|
||||||
if not profile:
|
if not profile:
|
||||||
raise KeyError(f"unknown compression request: {request_id}")
|
raise KeyError(f"unknown compression request: {request_id}")
|
||||||
return self._forward(profile, req)
|
return self._forward(profile, req, worker_key)
|
||||||
|
|
||||||
if action == "destroy_all":
|
if action == "destroy_all":
|
||||||
with self._lock:
|
with self._lock:
|
||||||
workers = list(self._workers.values())
|
workers = list(self._workers.values())
|
||||||
self._workers.clear()
|
self._workers.clear()
|
||||||
self._run_profile.clear()
|
self._run_profile.clear()
|
||||||
|
self._run_worker_key.clear()
|
||||||
self._running_run_profile.clear()
|
self._running_run_profile.clear()
|
||||||
|
self._running_run_worker_key.clear()
|
||||||
self._session_profile.clear()
|
self._session_profile.clear()
|
||||||
|
self._session_worker_key.clear()
|
||||||
self._approval_profile.clear()
|
self._approval_profile.clear()
|
||||||
|
self._approval_worker_key.clear()
|
||||||
self._clarify_profile.clear()
|
self._clarify_profile.clear()
|
||||||
|
self._clarify_worker_key.clear()
|
||||||
self._compression_profile.clear()
|
self._compression_profile.clear()
|
||||||
|
self._compression_worker_key.clear()
|
||||||
destroyed = 0
|
destroyed = 0
|
||||||
for worker in workers:
|
for worker in workers:
|
||||||
try:
|
try:
|
||||||
@@ -2744,24 +2788,39 @@ class BridgeBroker:
|
|||||||
if action == "destroy_profile":
|
if action == "destroy_profile":
|
||||||
profile = self._normalize_profile(req.get("profile"))
|
profile = self._normalize_profile(req.get("profile"))
|
||||||
with self._lock:
|
with self._lock:
|
||||||
worker = self._workers.pop(profile, None)
|
workers = [
|
||||||
|
worker
|
||||||
|
for key, worker in list(self._workers.items())
|
||||||
|
if getattr(worker, "profile", key) == profile
|
||||||
|
]
|
||||||
|
for worker in workers:
|
||||||
|
self._workers.pop(worker.key, None)
|
||||||
self._run_profile = {key: value for key, value in self._run_profile.items() if value != profile}
|
self._run_profile = {key: value for key, value in self._run_profile.items() if value != profile}
|
||||||
|
self._run_worker_key = {key: value for key, value in self._run_worker_key.items() if key in self._run_profile}
|
||||||
self._running_run_profile = {key: value for key, value in self._running_run_profile.items() if value != profile}
|
self._running_run_profile = {key: value for key, value in self._running_run_profile.items() if value != profile}
|
||||||
|
self._running_run_worker_key = {key: value for key, value in self._running_run_worker_key.items() if key in self._running_run_profile}
|
||||||
self._session_profile = {key: value for key, value in self._session_profile.items() if value != profile}
|
self._session_profile = {key: value for key, value in self._session_profile.items() if value != profile}
|
||||||
|
self._session_worker_key = {key: value for key, value in self._session_worker_key.items() if key in self._session_profile}
|
||||||
self._approval_profile = {key: value for key, value in self._approval_profile.items() if value != profile}
|
self._approval_profile = {key: value for key, value in self._approval_profile.items() if value != profile}
|
||||||
|
self._approval_worker_key = {key: value for key, value in self._approval_worker_key.items() if key in self._approval_profile}
|
||||||
self._clarify_profile = {key: value for key, value in self._clarify_profile.items() if value != profile}
|
self._clarify_profile = {key: value for key, value in self._clarify_profile.items() if value != profile}
|
||||||
|
self._clarify_worker_key = {key: value for key, value in self._clarify_worker_key.items() if key in self._clarify_profile}
|
||||||
self._compression_profile = {key: value for key, value in self._compression_profile.items() if value != profile}
|
self._compression_profile = {key: value for key, value in self._compression_profile.items() if value != profile}
|
||||||
|
self._compression_worker_key = {key: value for key, value in self._compression_worker_key.items() if key in self._compression_profile}
|
||||||
|
|
||||||
if worker is None or not worker.running:
|
if not workers:
|
||||||
if worker is not None:
|
|
||||||
worker.stop()
|
|
||||||
return {"profile": profile, "destroyed": 0}
|
return {"profile": profile, "destroyed": 0}
|
||||||
|
|
||||||
|
destroyed = 0
|
||||||
|
for worker in workers:
|
||||||
|
if not worker.running:
|
||||||
|
worker.stop()
|
||||||
|
continue
|
||||||
try:
|
try:
|
||||||
resp = worker.request({"action": "destroy_all"})
|
resp = worker.request({"action": "destroy_all"})
|
||||||
destroyed = int(resp.get("destroyed") or 0)
|
destroyed += int(resp.get("destroyed") or 0)
|
||||||
except Exception:
|
except Exception:
|
||||||
destroyed = 0
|
pass
|
||||||
finally:
|
finally:
|
||||||
worker.stop()
|
worker.stop()
|
||||||
return {"profile": profile, "destroyed": destroyed}
|
return {"profile": profile, "destroyed": destroyed}
|
||||||
@@ -2770,14 +2829,15 @@ class BridgeBroker:
|
|||||||
sessions: list[Any] = []
|
sessions: list[Any] = []
|
||||||
with self._lock:
|
with self._lock:
|
||||||
workers = list(self._workers.items())
|
workers = list(self._workers.items())
|
||||||
for profile, worker in workers:
|
for key, worker in workers:
|
||||||
if not worker.running:
|
if not worker.running:
|
||||||
continue
|
continue
|
||||||
try:
|
try:
|
||||||
resp = worker.request({"action": "list"})
|
resp = worker.request({"action": "list"})
|
||||||
for session in resp.get("sessions") or []:
|
for session in resp.get("sessions") or []:
|
||||||
if isinstance(session, dict):
|
if isinstance(session, dict):
|
||||||
session.setdefault("profile", profile)
|
session.setdefault("profile", getattr(worker, "profile", key))
|
||||||
|
session.setdefault("worker_key", key)
|
||||||
sessions.append(session)
|
sessions.append(session)
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
@@ -2826,12 +2886,12 @@ class BridgeBroker:
|
|||||||
self._last_gc = now
|
self._last_gc = now
|
||||||
with self._lock:
|
with self._lock:
|
||||||
idle = [
|
idle = [
|
||||||
profile for profile, worker in self._workers.items()
|
key for key, worker in self._workers.items()
|
||||||
if worker.running and now - worker.last_used_at > self.IDLE_TIMEOUT_SECONDS
|
if worker.running and now - worker.last_used_at > self.IDLE_TIMEOUT_SECONDS
|
||||||
]
|
]
|
||||||
for profile in idle:
|
for key in idle:
|
||||||
with self._lock:
|
with self._lock:
|
||||||
worker = self._workers.pop(profile, None)
|
worker = self._workers.pop(key, None)
|
||||||
if worker:
|
if worker:
|
||||||
worker.stop()
|
worker.stop()
|
||||||
|
|
||||||
|
|||||||
@@ -195,7 +195,8 @@ export async function buildCompressedHistory(
|
|||||||
emit: (event: string, payload: any) => void,
|
emit: (event: string, payload: any) => void,
|
||||||
sessionMap: Map<string, SessionState>,
|
sessionMap: Map<string, SessionState>,
|
||||||
modelContext: { model?: string | null; provider?: string | null } = {},
|
modelContext: { model?: string | null; provider?: string | null } = {},
|
||||||
contextTokenEstimator?: (messages: ChatMessage[]) => Promise<number | null | undefined>,
|
contextTokenEstimator?: (messages: ChatMessage[], messageTokens: number) => Promise<number | null | undefined>,
|
||||||
|
currentInputTokens = 0,
|
||||||
): Promise<ChatMessage[]> {
|
): Promise<ChatMessage[]> {
|
||||||
try {
|
try {
|
||||||
let history = await buildDbHistory(sessionId, { excludeLastUser: true })
|
let history = await buildDbHistory(sessionId, { excludeLastUser: true })
|
||||||
@@ -213,14 +214,18 @@ export async function buildCompressedHistory(
|
|||||||
}
|
}
|
||||||
const cState = getOrCreateSession(sessionMap, sessionId)
|
const cState = getOrCreateSession(sessionMap, sessionId)
|
||||||
const assembledTokens = await calcAndUpdateUsage(sessionId, cState, emit)
|
const assembledTokens = await calcAndUpdateUsage(sessionId, cState, emit)
|
||||||
const estimateFullContextTokens = async (messages: ChatMessage[], fallback: number) => {
|
const currentRunInputTokens = typeof currentInputTokens === 'number' && Number.isFinite(currentInputTokens) && currentInputTokens > 0
|
||||||
|
? Math.floor(currentInputTokens)
|
||||||
|
: 0
|
||||||
|
const estimateLocalContextTokens = async (messages: ChatMessage[], messageTokens: number) => {
|
||||||
|
const localMessageTokens = Math.max(0, Math.floor(messageTokens))
|
||||||
try {
|
try {
|
||||||
const estimate = await contextTokenEstimator?.(messages)
|
const estimate = await contextTokenEstimator?.(messages, localMessageTokens)
|
||||||
if (typeof estimate === 'number' && Number.isFinite(estimate) && estimate > 0) return Math.floor(estimate)
|
if (typeof estimate === 'number' && Number.isFinite(estimate) && estimate > 0) return Math.floor(estimate)
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
logger.warn(err, '[context-compress] session=%s: full context token estimate failed; using message-only estimate', sessionId)
|
logger.warn(err, '[context-compress] session=%s: fixed context token estimate failed; using message-only estimate', sessionId)
|
||||||
}
|
}
|
||||||
return fallback
|
return localMessageTokens
|
||||||
}
|
}
|
||||||
const emitContextUsage = (contextTokens: number) => {
|
const emitContextUsage = (contextTokens: number) => {
|
||||||
cState.contextTokens = contextTokens
|
cState.contextTokens = contextTokens
|
||||||
@@ -236,10 +241,10 @@ export async function buildCompressedHistory(
|
|||||||
let totalTokens = messageOnlyTotalTokens
|
let totalTokens = messageOnlyTotalTokens
|
||||||
|
|
||||||
if (history.length === 0) {
|
if (history.length === 0) {
|
||||||
totalTokens = await estimateFullContextTokens([], 0)
|
totalTokens = await estimateLocalContextTokens([], Math.max(currentRunInputTokens, messageOnlyTotalTokens))
|
||||||
if (totalTokens > triggerTokens) {
|
if (totalTokens > triggerTokens) {
|
||||||
throw new ContextWindowTooSmallError(
|
throw new ContextWindowTooSmallError(
|
||||||
`Context window is too small: system prompt and tool schemas already use ~${totalTokens} tokens, exceeding compression threshold ${triggerTokens}. Increase model context length, raise compression.threshold, or disable some tools.`,
|
`Context window is too small: fixed prompt/tool overhead plus the current input uses ~${totalTokens} tokens, exceeding compression threshold ${triggerTokens}. Increase model context length, raise compression.threshold, shorten the input, or disable some tools.`,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
if (totalTokens > 0) emitContextUsage(totalTokens)
|
if (totalTokens > 0) emitContextUsage(totalTokens)
|
||||||
@@ -254,13 +259,15 @@ export async function buildCompressedHistory(
|
|||||||
sessionId, snapshot.lastMessageIndex, history.length)
|
sessionId, snapshot.lastMessageIndex, history.length)
|
||||||
const staleHistory = buildSnapshotHistory(snapshot, history, compressionConfig.compressor) || history
|
const staleHistory = buildSnapshotHistory(snapshot, history, compressionConfig.compressor) || history
|
||||||
const staleUsage = estimateUsageTokensFromMessages(staleHistory)
|
const staleUsage = estimateUsageTokensFromMessages(staleHistory)
|
||||||
totalTokens = await estimateFullContextTokens(staleHistory, staleUsage.inputTokens + staleUsage.outputTokens)
|
const staleMessageTokens = staleUsage.inputTokens + staleUsage.outputTokens
|
||||||
|
const staleRunMessageTokens = Math.max(staleMessageTokens + currentRunInputTokens, messageOnlyTotalTokens)
|
||||||
|
totalTokens = await estimateLocalContextTokens(staleHistory, staleRunMessageTokens)
|
||||||
emitContextUsage(totalTokens)
|
emitContextUsage(totalTokens)
|
||||||
logger.info({
|
logger.info({
|
||||||
sessionId,
|
sessionId,
|
||||||
profile,
|
profile,
|
||||||
messages: staleHistory.length,
|
messages: staleHistory.length,
|
||||||
messageOnlyTokens: staleUsage.inputTokens + staleUsage.outputTokens,
|
messageOnlyTokens: staleRunMessageTokens,
|
||||||
fullContextTokens: totalTokens,
|
fullContextTokens: totalTokens,
|
||||||
triggerTokens,
|
triggerTokens,
|
||||||
decision: totalTokens > triggerTokens ? 'compress' : 'skip',
|
decision: totalTokens > triggerTokens ? 'compress' : 'skip',
|
||||||
@@ -272,13 +279,15 @@ export async function buildCompressedHistory(
|
|||||||
const newMessages = history.slice(snapshot.lastMessageIndex + 1)
|
const newMessages = history.slice(snapshot.lastMessageIndex + 1)
|
||||||
const snapshotHistory = buildSnapshotHistory(snapshot, history, compressionConfig.compressor) || history
|
const snapshotHistory = buildSnapshotHistory(snapshot, history, compressionConfig.compressor) || history
|
||||||
const snapshotUsage = estimateUsageTokensFromMessages(snapshotHistory)
|
const snapshotUsage = estimateUsageTokensFromMessages(snapshotHistory)
|
||||||
totalTokens = await estimateFullContextTokens(snapshotHistory, snapshotUsage.inputTokens + snapshotUsage.outputTokens)
|
const snapshotMessageTokens = snapshotUsage.inputTokens + snapshotUsage.outputTokens
|
||||||
|
const snapshotRunMessageTokens = Math.max(snapshotMessageTokens + currentRunInputTokens, messageOnlyTotalTokens)
|
||||||
|
totalTokens = await estimateLocalContextTokens(snapshotHistory, snapshotRunMessageTokens)
|
||||||
emitContextUsage(totalTokens)
|
emitContextUsage(totalTokens)
|
||||||
logger.info({
|
logger.info({
|
||||||
sessionId,
|
sessionId,
|
||||||
profile,
|
profile,
|
||||||
messages: snapshotHistory.length,
|
messages: snapshotHistory.length,
|
||||||
messageOnlyTokens: snapshotUsage.inputTokens + snapshotUsage.outputTokens,
|
messageOnlyTokens: snapshotRunMessageTokens,
|
||||||
fullContextTokens: totalTokens,
|
fullContextTokens: totalTokens,
|
||||||
triggerTokens,
|
triggerTokens,
|
||||||
decision: totalTokens > triggerTokens ? 'compress' : 'skip',
|
decision: totalTokens > triggerTokens ? 'compress' : 'skip',
|
||||||
@@ -289,22 +298,25 @@ export async function buildCompressedHistory(
|
|||||||
if (totalTokens <= triggerTokens) {
|
if (totalTokens <= triggerTokens) {
|
||||||
history = snapshotHistory
|
history = snapshotHistory
|
||||||
} else {
|
} else {
|
||||||
history = await compressHistory(history, newMessages, sessionId, upstream, apiKey, cState, totalTokens, emit, sessionMap, modelContext, compressionConfig.compressor)
|
history = await compressHistory(history, newMessages, sessionId, upstream, apiKey, cState, totalTokens, emit, sessionMap, modelContext, compressionConfig.compressor, currentRunInputTokens)
|
||||||
}
|
}
|
||||||
} else if (snapshot && staleSnapshot) {
|
} else if (snapshot && staleSnapshot) {
|
||||||
if (totalTokens <= triggerTokens) {
|
if (totalTokens <= triggerTokens) {
|
||||||
history = buildSnapshotHistory(snapshot, history, compressionConfig.compressor) || history
|
history = buildSnapshotHistory(snapshot, history, compressionConfig.compressor) || history
|
||||||
} else {
|
} else {
|
||||||
history = await compressHistory(history, null, sessionId, upstream, apiKey, cState, totalTokens, emit, sessionMap, modelContext, compressionConfig.compressor)
|
history = await compressHistory(history, null, sessionId, upstream, apiKey, cState, totalTokens, emit, sessionMap, modelContext, compressionConfig.compressor, currentRunInputTokens)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
totalTokens = await estimateFullContextTokens(history, totalTokens)
|
const historyUsage = estimateUsageTokensFromMessages(history)
|
||||||
|
const historyMessageTokens = historyUsage.inputTokens + historyUsage.outputTokens
|
||||||
|
const runMessageTokens = Math.max(historyMessageTokens + currentRunInputTokens, messageOnlyTotalTokens)
|
||||||
|
totalTokens = await estimateLocalContextTokens(history, runMessageTokens)
|
||||||
emitContextUsage(totalTokens)
|
emitContextUsage(totalTokens)
|
||||||
logger.info({
|
logger.info({
|
||||||
sessionId,
|
sessionId,
|
||||||
profile,
|
profile,
|
||||||
messages: history.length,
|
messages: history.length,
|
||||||
messageOnlyTokens: messageOnlyTotalTokens,
|
messageOnlyTokens: runMessageTokens,
|
||||||
fullContextTokens: totalTokens,
|
fullContextTokens: totalTokens,
|
||||||
triggerTokens,
|
triggerTokens,
|
||||||
decision: totalTokens > triggerTokens ? 'compress' : 'skip',
|
decision: totalTokens > triggerTokens ? 'compress' : 'skip',
|
||||||
@@ -318,7 +330,7 @@ export async function buildCompressedHistory(
|
|||||||
if (totalTokens <= triggerTokens) {
|
if (totalTokens <= triggerTokens) {
|
||||||
logger.info('[context-compress] session=%s: %d messages, ~%d tokens — under threshold, skip', sessionId, history.length, totalTokens)
|
logger.info('[context-compress] session=%s: %d messages, ~%d tokens — under threshold, skip', sessionId, history.length, totalTokens)
|
||||||
} else {
|
} else {
|
||||||
history = await compressHistory(history, null, sessionId, upstream, apiKey, cState, totalTokens, emit, sessionMap, modelContext, compressionConfig.compressor)
|
history = await compressHistory(history, null, sessionId, upstream, apiKey, cState, totalTokens, emit, sessionMap, modelContext, compressionConfig.compressor, currentRunInputTokens)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -342,8 +354,12 @@ export async function compressHistory(
|
|||||||
sessionMap: Map<string, SessionState>,
|
sessionMap: Map<string, SessionState>,
|
||||||
modelContext: { model?: string | null; provider?: string | null } = {},
|
modelContext: { model?: string | null; provider?: string | null } = {},
|
||||||
compressionConfig?: Partial<CompressorConfig>,
|
compressionConfig?: Partial<CompressorConfig>,
|
||||||
|
currentInputTokens = 0,
|
||||||
): Promise<ChatMessage[]> {
|
): Promise<ChatMessage[]> {
|
||||||
const msgCount = newMessagesOnly ? newMessagesOnly.length : history.length
|
const msgCount = newMessagesOnly ? newMessagesOnly.length : history.length
|
||||||
|
const currentRunInputTokens = typeof currentInputTokens === 'number' && Number.isFinite(currentInputTokens) && currentInputTokens > 0
|
||||||
|
? Math.floor(currentInputTokens)
|
||||||
|
: 0
|
||||||
pushState(sessionMap, sessionId, 'compression.started', {
|
pushState(sessionMap, sessionId, 'compression.started', {
|
||||||
event: 'compression.started', message_count: msgCount, token_count: totalTokens,
|
event: 'compression.started', message_count: msgCount, token_count: totalTokens,
|
||||||
})
|
})
|
||||||
@@ -353,14 +369,22 @@ export async function compressHistory(
|
|||||||
|
|
||||||
try {
|
try {
|
||||||
const session = getSession(sessionId)
|
const session = getSession(sessionId)
|
||||||
|
const summarizerProfile = session?.profile || 'default'
|
||||||
const compressor = new ChatContextCompressor({ config: compressionConfig })
|
const compressor = new ChatContextCompressor({ config: compressionConfig })
|
||||||
const result = await compressor.compress(history, upstream, apiKey, sessionId, {
|
const result = await compressor.compress(history, upstream, apiKey, sessionId, {
|
||||||
profile: session?.profile,
|
profile: summarizerProfile,
|
||||||
model: modelContext.model || session?.model,
|
model: modelContext.model || session?.model,
|
||||||
provider: modelContext.provider || session?.provider,
|
provider: modelContext.provider || session?.provider,
|
||||||
|
workerKey: `${summarizerProfile}:compression:${sessionId}`,
|
||||||
})
|
})
|
||||||
const afterTokens = await calcAndUpdateUsage(sessionId, cState, emit)
|
const afterTokens = await calcAndUpdateUsage(sessionId, cState, emit)
|
||||||
const compressedAfterTokens = afterTokens.inputTokens + afterTokens.outputTokens
|
const compressedAfterTokens = afterTokens.inputTokens + afterTokens.outputTokens
|
||||||
|
const resultUsage = estimateUsageTokensFromMessages(result.messages)
|
||||||
|
const resultMessageTokens = resultUsage.inputTokens + resultUsage.outputTokens
|
||||||
|
const compressedRunMessageTokens = Math.max(
|
||||||
|
compressedAfterTokens,
|
||||||
|
resultMessageTokens + currentRunInputTokens,
|
||||||
|
)
|
||||||
const compressedMeta: any = {
|
const compressedMeta: any = {
|
||||||
event: 'compression.completed' as const,
|
event: 'compression.completed' as const,
|
||||||
compressed: result.meta.compressed,
|
compressed: result.meta.compressed,
|
||||||
@@ -368,15 +392,15 @@ export async function compressHistory(
|
|||||||
totalMessages: result.meta.totalMessages,
|
totalMessages: result.meta.totalMessages,
|
||||||
resultMessages: result.messages.length,
|
resultMessages: result.messages.length,
|
||||||
beforeTokens: totalTokens,
|
beforeTokens: totalTokens,
|
||||||
afterTokens: compressedAfterTokens,
|
afterTokens: compressedRunMessageTokens,
|
||||||
summaryTokens: result.meta.summaryTokenEstimate,
|
summaryTokens: result.meta.summaryTokenEstimate,
|
||||||
verbatimCount: result.meta.verbatimCount,
|
verbatimCount: result.meta.verbatimCount,
|
||||||
compressedStartIndex: result.meta.compressedStartIndex,
|
compressedStartIndex: result.meta.compressedStartIndex,
|
||||||
}
|
}
|
||||||
replaceState(sessionMap, sessionId, 'compression.completed', compressedMeta)
|
replaceState(sessionMap, sessionId, 'compression.completed', compressedMeta)
|
||||||
logger.info('[context-compress] AFTER session=%s: %d messages, ~%d tokens (was %d)',
|
logger.info('[context-compress] AFTER session=%s: %d messages, ~%d tokens (was %d)',
|
||||||
sessionId, result.messages.length, compressedAfterTokens, totalTokens)
|
sessionId, result.messages.length, compressedRunMessageTokens, totalTokens)
|
||||||
const compressedContextTokens = updateMessageContextTokenUsage(sessionId, cState, emit, compressedAfterTokens, afterTokens)
|
const compressedContextTokens = updateMessageContextTokenUsage(sessionId, cState, emit, compressedRunMessageTokens, afterTokens)
|
||||||
if (compressedContextTokens != null) {
|
if (compressedContextTokens != null) {
|
||||||
compressedMeta.contextTokens = compressedContextTokens
|
compressedMeta.contextTokens = compressedContextTokens
|
||||||
}
|
}
|
||||||
@@ -403,6 +427,7 @@ export async function compressHistory(
|
|||||||
resultMessages: msgCount,
|
resultMessages: msgCount,
|
||||||
beforeTokens: totalTokens,
|
beforeTokens: totalTokens,
|
||||||
afterTokens: totalTokens,
|
afterTokens: totalTokens,
|
||||||
|
contextTokens: totalTokens,
|
||||||
summaryTokens: 0,
|
summaryTokens: 0,
|
||||||
verbatimCount: msgCount,
|
verbatimCount: msgCount,
|
||||||
compressedStartIndex: -1,
|
compressedStartIndex: -1,
|
||||||
@@ -458,10 +483,12 @@ export async function forceCompressBridgeHistory(
|
|||||||
}, '[chat-run-socket] bridge forced compression started')
|
}, '[chat-run-socket] bridge forced compression started')
|
||||||
|
|
||||||
const compressor = new ChatContextCompressor({ config: compressionConfig.compressor })
|
const compressor = new ChatContextCompressor({ config: compressionConfig.compressor })
|
||||||
|
const summarizerProfile = session?.profile || profile || 'default'
|
||||||
const result = await compressor.compress(history, upstream, apiKey, sessionId, {
|
const result = await compressor.compress(history, upstream, apiKey, sessionId, {
|
||||||
profile: session?.profile || profile,
|
profile: summarizerProfile,
|
||||||
model: session?.model,
|
model: session?.model,
|
||||||
provider: session?.provider,
|
provider: session?.provider,
|
||||||
|
workerKey: `${summarizerProfile}:compression:${sessionId}`,
|
||||||
})
|
})
|
||||||
const compressedMessages = result.messages.map(m => {
|
const compressedMessages = result.messages.map(m => {
|
||||||
const msg: any = { role: m.role, content: m.content }
|
const msg: any = { role: m.role, content: m.content }
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ import { convertHistoryFormat } from './message-format'
|
|||||||
import { readSseFrames } from './sse-utils'
|
import { readSseFrames } from './sse-utils'
|
||||||
import { extractResponseText } from './response-utils'
|
import { extractResponseText } from './response-utils'
|
||||||
import { applyResponseStreamEvent, flushResponseRunToDb } from './response-stream'
|
import { applyResponseStreamEvent, flushResponseRunToDb } from './response-stream'
|
||||||
import { buildCompressedHistory, getOrCreateSession } from './compression'
|
import { buildCompressedHistory, buildDbHistory, buildSnapshotAwareHistory, getOrCreateSession } from './compression'
|
||||||
import { calcAndUpdateUsage, estimateUsageTokensFromMessages } from './usage'
|
import { calcAndUpdateUsage, estimateUsageTokensFromMessages } from './usage'
|
||||||
import { handleMessage } from './message-format'
|
import { handleMessage } from './message-format'
|
||||||
import { countTokens, SUMMARY_PREFIX } from '../../../lib/context-compressor'
|
import { countTokens, SUMMARY_PREFIX } from '../../../lib/context-compressor'
|
||||||
@@ -37,6 +37,7 @@ export async function loadSessionStateFromDb(sid: string, _sessionMap: Map<strin
|
|||||||
|
|
||||||
let inputTokens: number
|
let inputTokens: number
|
||||||
let outputTokens: number
|
let outputTokens: number
|
||||||
|
let contextTokens: number | undefined
|
||||||
const snapshot = getCompressionSnapshot(sid)
|
const snapshot = getCompressionSnapshot(sid)
|
||||||
if (snapshot && snapshot.lastMessageIndex >= 0 && snapshot.lastMessageIndex < messages.length) {
|
if (snapshot && snapshot.lastMessageIndex >= 0 && snapshot.lastMessageIndex < messages.length) {
|
||||||
const newMessages = messages.slice(snapshot.lastMessageIndex + 1)
|
const newMessages = messages.slice(snapshot.lastMessageIndex + 1)
|
||||||
@@ -49,6 +50,20 @@ export async function loadSessionStateFromDb(sid: string, _sessionMap: Map<strin
|
|||||||
inputTokens = usage.inputTokens
|
inputTokens = usage.inputTokens
|
||||||
outputTokens = usage.outputTokens
|
outputTokens = usage.outputTokens
|
||||||
}
|
}
|
||||||
|
try {
|
||||||
|
const session = getSession(sid)
|
||||||
|
const dbHistory = await buildDbHistory(sid, { excludeLastUser: false })
|
||||||
|
const snapshotHistory = await buildSnapshotAwareHistory(
|
||||||
|
sid,
|
||||||
|
session?.profile || 'default',
|
||||||
|
dbHistory,
|
||||||
|
{ model: session?.model, provider: session?.provider },
|
||||||
|
)
|
||||||
|
const contextUsage = estimateUsageTokensFromMessages(snapshotHistory)
|
||||||
|
contextTokens = contextUsage.inputTokens + contextUsage.outputTokens
|
||||||
|
} catch (err) {
|
||||||
|
logger.warn(err, '[chat-run-socket] failed to calculate snapshot-aware context tokens for session %s', sid)
|
||||||
|
}
|
||||||
|
|
||||||
logger.info('[chat-run-socket] loaded session %s from DB (%d messages)', sid, messages.length)
|
logger.info('[chat-run-socket] loaded session %s from DB (%d messages)', sid, messages.length)
|
||||||
return {
|
return {
|
||||||
@@ -57,6 +72,7 @@ export async function loadSessionStateFromDb(sid: string, _sessionMap: Map<strin
|
|||||||
events: [],
|
events: [],
|
||||||
inputTokens,
|
inputTokens,
|
||||||
outputTokens,
|
outputTokens,
|
||||||
|
contextTokens,
|
||||||
queue: [],
|
queue: [],
|
||||||
}
|
}
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
|
|||||||
@@ -16,7 +16,6 @@ import {
|
|||||||
contextTokensWithCachedOverhead,
|
contextTokensWithCachedOverhead,
|
||||||
estimateUsageTokensFromMessages,
|
estimateUsageTokensFromMessages,
|
||||||
getCachedBridgeContextOverhead,
|
getCachedBridgeContextOverhead,
|
||||||
updateContextTokenUsage,
|
|
||||||
updateMessageContextTokenUsage,
|
updateMessageContextTokenUsage,
|
||||||
} from './usage'
|
} from './usage'
|
||||||
import {
|
import {
|
||||||
@@ -123,6 +122,66 @@ function cacheBridgeContext(state: SessionState, data: Record<string, unknown> |
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function bridgeContextMatches(
|
||||||
|
state: SessionState,
|
||||||
|
expected: { profile: string; model?: string | null; provider?: string | null },
|
||||||
|
): boolean {
|
||||||
|
const context = state.bridgeContext
|
||||||
|
if (!context) return false
|
||||||
|
if (context.profile && context.profile !== expected.profile) return false
|
||||||
|
if (expected.model && context.model && context.model !== expected.model) return false
|
||||||
|
if (expected.provider && context.provider && context.provider !== expected.provider) return false
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
async function ensureBridgeFixedContext(args: {
|
||||||
|
sessionId: string
|
||||||
|
profile: string
|
||||||
|
model?: string | null
|
||||||
|
provider?: string | null
|
||||||
|
instructions: string
|
||||||
|
state: SessionState
|
||||||
|
bridge: AgentBridgeClient
|
||||||
|
refresh?: boolean
|
||||||
|
}): Promise<number | undefined> {
|
||||||
|
const cached = bridgeContextMatches(args.state, args)
|
||||||
|
? getCachedBridgeContextOverhead(args.state)
|
||||||
|
: undefined
|
||||||
|
if (!args.refresh && cached != null) return cached
|
||||||
|
|
||||||
|
try {
|
||||||
|
const estimate = await args.bridge.contextEstimate(
|
||||||
|
args.sessionId,
|
||||||
|
[],
|
||||||
|
args.instructions,
|
||||||
|
args.profile,
|
||||||
|
{ model: args.model ?? undefined, provider: args.provider ?? undefined },
|
||||||
|
)
|
||||||
|
cacheBridgeContext(args.state, estimate)
|
||||||
|
const fixedContextTokens = getCachedBridgeContextOverhead(args.state)
|
||||||
|
bridgeLogger.info({
|
||||||
|
sessionId: args.sessionId,
|
||||||
|
profile: args.profile,
|
||||||
|
model: args.model,
|
||||||
|
provider: args.provider,
|
||||||
|
toolCount: estimate.tool_count,
|
||||||
|
systemPromptChars: estimate.system_prompt_chars,
|
||||||
|
fixedContextTokens,
|
||||||
|
}, '[chat-run-socket] fixed context estimate')
|
||||||
|
return fixedContextTokens
|
||||||
|
} catch (err) {
|
||||||
|
bridgeLogger.warn({
|
||||||
|
err: err instanceof Error ? { message: err.message, name: err.name } : err,
|
||||||
|
sessionId: args.sessionId,
|
||||||
|
profile: args.profile,
|
||||||
|
model: args.model,
|
||||||
|
provider: args.provider,
|
||||||
|
cachedFixedContextTokens: cached,
|
||||||
|
}, '[chat-run-socket] fixed context estimate failed')
|
||||||
|
return cached
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
export async function handleBridgeRun(
|
export async function handleBridgeRun(
|
||||||
nsp: ReturnType<Server['of']>,
|
nsp: ReturnType<Server['of']>,
|
||||||
socket: Socket,
|
socket: Socket,
|
||||||
@@ -195,6 +254,9 @@ export async function handleBridgeRun(
|
|||||||
|
|
||||||
const displayInput = data.display_input === undefined ? input : data.display_input
|
const displayInput = data.display_input === undefined ? input : data.display_input
|
||||||
const inputStr = displayInput == null ? '' : contentBlocksToString(displayInput)
|
const inputStr = displayInput == null ? '' : contentBlocksToString(displayInput)
|
||||||
|
const actualInputStr = contentBlocksToString(input)
|
||||||
|
const currentInputUsage = estimateUsageTokensFromMessages([{ role: 'user', content: actualInputStr }])
|
||||||
|
const currentInputTokens = currentInputUsage.inputTokens
|
||||||
const shouldPersistUserMessage = !skipUserMessage && displayInput !== null
|
const shouldPersistUserMessage = !skipUserMessage && displayInput !== null
|
||||||
const displayRole = data.display_role === 'command' ? 'command' : 'user'
|
const displayRole = data.display_role === 'command' ? 'command' : 'user'
|
||||||
let messageId: number | string | undefined
|
let messageId: number | string | undefined
|
||||||
@@ -257,33 +319,32 @@ export async function handleBridgeRun(
|
|||||||
emit,
|
emit,
|
||||||
sessionMap,
|
sessionMap,
|
||||||
{ model: resolvedModel, provider: resolvedProvider },
|
{ model: resolvedModel, provider: resolvedProvider },
|
||||||
async (messages) => {
|
async (_messages, localMessageTokens) => {
|
||||||
const cachedOverhead = getCachedBridgeContextOverhead(state)
|
const fixedContextTokens = await ensureBridgeFixedContext({
|
||||||
if (cachedOverhead != null) {
|
sessionId: session_id,
|
||||||
const messageUsage = estimateUsageTokensFromMessages(messages)
|
|
||||||
return cachedOverhead + messageUsage.inputTokens + messageUsage.outputTokens
|
|
||||||
}
|
|
||||||
const estimate = await bridge.contextEstimate(
|
|
||||||
session_id,
|
|
||||||
messages,
|
|
||||||
fullInstructions,
|
|
||||||
profile,
|
profile,
|
||||||
{ model: resolvedModel, provider: resolvedProvider },
|
model: resolvedModel,
|
||||||
)
|
provider: resolvedProvider,
|
||||||
cacheBridgeContext(state, estimate)
|
instructions: fullInstructions,
|
||||||
|
state,
|
||||||
|
bridge,
|
||||||
|
refresh: true,
|
||||||
|
})
|
||||||
|
const contextTokens = fixedContextTokens == null
|
||||||
|
? localMessageTokens
|
||||||
|
: fixedContextTokens + localMessageTokens
|
||||||
bridgeLogger.info({
|
bridgeLogger.info({
|
||||||
sessionId: session_id,
|
sessionId: session_id,
|
||||||
profile,
|
profile,
|
||||||
model: resolvedModel,
|
model: resolvedModel,
|
||||||
provider: resolvedProvider,
|
provider: resolvedProvider,
|
||||||
messages: estimate.message_count,
|
fixedContextTokens,
|
||||||
toolCount: estimate.tool_count,
|
messageTokens: localMessageTokens,
|
||||||
systemPromptChars: estimate.system_prompt_chars,
|
contextTokens,
|
||||||
fixedContextTokens: estimate.fixed_context_tokens,
|
}, '[chat-run-socket] local context estimate')
|
||||||
fullContextTokens: estimate.token_count,
|
return contextTokens
|
||||||
}, '[chat-run-socket] full context estimate')
|
|
||||||
return estimate.token_count
|
|
||||||
},
|
},
|
||||||
|
currentInputTokens,
|
||||||
)
|
)
|
||||||
const bridgeHistory = history
|
const bridgeHistory = history
|
||||||
|
|
||||||
@@ -349,6 +410,8 @@ export async function handleBridgeRun(
|
|||||||
dequeueNextQueuedRun,
|
dequeueNextQueuedRun,
|
||||||
fullInstructions,
|
fullInstructions,
|
||||||
{ model: resolvedModel, provider: resolvedProvider },
|
{ model: resolvedModel, provider: resolvedProvider },
|
||||||
|
currentInputTokens,
|
||||||
|
shouldPersistUserMessage && displayRole === 'user',
|
||||||
data.model_groups,
|
data.model_groups,
|
||||||
)
|
)
|
||||||
if (chunk.done) break
|
if (chunk.done) break
|
||||||
@@ -417,7 +480,15 @@ async function refreshFinalContextUsage(args: {
|
|||||||
)
|
)
|
||||||
const finalMessageUsage = estimateUsageTokensFromMessages(finalHistory)
|
const finalMessageUsage = estimateUsageTokensFromMessages(finalHistory)
|
||||||
const finalMessageTokens = finalMessageUsage.inputTokens + finalMessageUsage.outputTokens
|
const finalMessageTokens = finalMessageUsage.inputTokens + finalMessageUsage.outputTokens
|
||||||
if (getCachedBridgeContextOverhead(args.state) != null) {
|
await ensureBridgeFixedContext({
|
||||||
|
sessionId: args.sessionId,
|
||||||
|
profile: args.profile,
|
||||||
|
model: args.model,
|
||||||
|
provider: args.provider,
|
||||||
|
instructions: args.instructions,
|
||||||
|
state: args.state,
|
||||||
|
bridge: args.bridge,
|
||||||
|
})
|
||||||
const contextTokens = updateMessageContextTokenUsage(
|
const contextTokens = updateMessageContextTokenUsage(
|
||||||
args.sessionId,
|
args.sessionId,
|
||||||
args.state,
|
args.state,
|
||||||
@@ -433,45 +504,44 @@ async function refreshFinalContextUsage(args: {
|
|||||||
messages: finalHistory.length,
|
messages: finalHistory.length,
|
||||||
fixedContextTokens: args.state.bridgeContext?.fixedContextTokens,
|
fixedContextTokens: args.state.bridgeContext?.fixedContextTokens,
|
||||||
messageTokens: finalMessageTokens,
|
messageTokens: finalMessageTokens,
|
||||||
fullContextTokens: contextTokens,
|
contextTokens,
|
||||||
}, '[chat-run-socket] final cached context estimate')
|
}, '[chat-run-socket] final local context estimate')
|
||||||
return contextTokens
|
|
||||||
}
|
|
||||||
const estimate = await args.bridge.contextEstimate(
|
|
||||||
args.sessionId,
|
|
||||||
finalHistory,
|
|
||||||
args.instructions,
|
|
||||||
args.profile,
|
|
||||||
{ model: args.model ?? undefined, provider: args.provider ?? undefined },
|
|
||||||
)
|
|
||||||
cacheBridgeContext(args.state, estimate)
|
|
||||||
const contextTokens = typeof estimate.token_count === 'number' && Number.isFinite(estimate.token_count) && estimate.token_count > 0
|
|
||||||
? Math.floor(estimate.token_count)
|
|
||||||
: undefined
|
|
||||||
if (contextTokens == null) return args.state.contextTokens
|
|
||||||
|
|
||||||
updateContextTokenUsage(args.sessionId, args.state, args.emit, contextTokens, args.usage)
|
|
||||||
bridgeLogger.info({
|
|
||||||
sessionId: args.sessionId,
|
|
||||||
profile: args.profile,
|
|
||||||
model: args.model,
|
|
||||||
provider: args.provider,
|
|
||||||
messages: estimate.message_count,
|
|
||||||
toolCount: estimate.tool_count,
|
|
||||||
systemPromptChars: estimate.system_prompt_chars,
|
|
||||||
fullContextTokens: contextTokens,
|
|
||||||
}, '[chat-run-socket] final full context estimate')
|
|
||||||
return contextTokens
|
return contextTokens
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
bridgeLogger.warn({
|
bridgeLogger.warn({
|
||||||
err: err instanceof Error ? { message: err.message, name: err.name } : err,
|
err: err instanceof Error ? { message: err.message, name: err.name } : err,
|
||||||
sessionId: args.sessionId,
|
sessionId: args.sessionId,
|
||||||
profile: args.profile,
|
profile: args.profile,
|
||||||
}, '[chat-run-socket] final full context estimate failed')
|
}, '[chat-run-socket] final local context estimate failed')
|
||||||
return args.state.contextTokens
|
return args.state.contextTokens
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async function estimateSnapshotAwareMessageTokens(args: {
|
||||||
|
sessionId: string
|
||||||
|
profile: string
|
||||||
|
model?: string | null
|
||||||
|
provider?: string | null
|
||||||
|
currentInputTokens?: number
|
||||||
|
currentInputIncludedInDb?: boolean
|
||||||
|
}): Promise<{ messageTokens: number; messages: number }> {
|
||||||
|
const dbHistory = await buildDbHistory(args.sessionId, { excludeLastUser: false })
|
||||||
|
const snapshotHistory = await buildSnapshotAwareHistory(
|
||||||
|
args.sessionId,
|
||||||
|
args.profile,
|
||||||
|
dbHistory,
|
||||||
|
{ model: args.model, provider: args.provider },
|
||||||
|
)
|
||||||
|
const usage = estimateUsageTokensFromMessages(snapshotHistory)
|
||||||
|
const extraInputTokens = args.currentInputIncludedInDb
|
||||||
|
? 0
|
||||||
|
: finiteToken(args.currentInputTokens) ?? 0
|
||||||
|
return {
|
||||||
|
messageTokens: usage.inputTokens + usage.outputTokens + extraInputTokens,
|
||||||
|
messages: snapshotHistory.length,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
async function applyBridgeChunkAsync(
|
async function applyBridgeChunkAsync(
|
||||||
nsp: ReturnType<Server['of']>,
|
nsp: ReturnType<Server['of']>,
|
||||||
socket: Socket,
|
socket: Socket,
|
||||||
@@ -486,6 +556,8 @@ async function applyBridgeChunkAsync(
|
|||||||
dequeueNextQueuedRun: (socket: Socket, sessionId: string, fallbackProfile?: string) => void,
|
dequeueNextQueuedRun: (socket: Socket, sessionId: string, fallbackProfile?: string) => void,
|
||||||
instructions: string,
|
instructions: string,
|
||||||
modelContext: { model?: string | null; provider?: string | null },
|
modelContext: { model?: string | null; provider?: string | null },
|
||||||
|
currentInputTokens = 0,
|
||||||
|
currentInputIncludedInDb = true,
|
||||||
modelGroups?: RunModelGroup[],
|
modelGroups?: RunModelGroup[],
|
||||||
): Promise<void> {
|
): Promise<void> {
|
||||||
if (state.activeRunMarker !== runMarker) {
|
if (state.activeRunMarker !== runMarker) {
|
||||||
@@ -505,11 +577,19 @@ async function applyBridgeChunkAsync(
|
|||||||
if (evType === 'bridge.context.ready') {
|
if (evType === 'bridge.context.ready') {
|
||||||
cacheBridgeContext(state, ev)
|
cacheBridgeContext(state, ev)
|
||||||
const usage = await calcAndUpdateUsage(sessionId, state, emit)
|
const usage = await calcAndUpdateUsage(sessionId, state, emit)
|
||||||
|
const snapshotAware = await estimateSnapshotAwareMessageTokens({
|
||||||
|
sessionId,
|
||||||
|
profile,
|
||||||
|
model: modelContext.model,
|
||||||
|
provider: modelContext.provider,
|
||||||
|
currentInputTokens,
|
||||||
|
currentInputIncludedInDb,
|
||||||
|
})
|
||||||
updateMessageContextTokenUsage(
|
updateMessageContextTokenUsage(
|
||||||
sessionId,
|
sessionId,
|
||||||
state,
|
state,
|
||||||
emit,
|
emit,
|
||||||
usage.inputTokens + usage.outputTokens,
|
snapshotAware.messageTokens,
|
||||||
usage,
|
usage,
|
||||||
)
|
)
|
||||||
} else if (evType === 'tool.started') {
|
} else if (evType === 'tool.started') {
|
||||||
@@ -646,17 +726,22 @@ async function applyBridgeChunkAsync(
|
|||||||
const bridgeHistory = await buildDbHistory(sessionId, { excludeLastUser: true })
|
const bridgeHistory = await buildDbHistory(sessionId, { excludeLastUser: true })
|
||||||
const bridgeUsage = estimateUsageTokensFromMessages(bridgeHistory)
|
const bridgeUsage = estimateUsageTokensFromMessages(bridgeHistory)
|
||||||
const messageOnlyTokens = bridgeUsage.inputTokens + bridgeUsage.outputTokens
|
const messageOnlyTokens = bridgeUsage.inputTokens + bridgeUsage.outputTokens
|
||||||
const tokenCount = typeof ev.approx_tokens === 'number' && Number.isFinite(ev.approx_tokens) && ev.approx_tokens > 0
|
const runInputTokens = typeof currentInputTokens === 'number' && Number.isFinite(currentInputTokens) && currentInputTokens > 0
|
||||||
? ev.approx_tokens
|
? Math.floor(currentInputTokens)
|
||||||
: messageOnlyTokens
|
: 0
|
||||||
|
const runMessageTokens = messageOnlyTokens + runInputTokens
|
||||||
|
const tokenCount = contextTokensWithCachedOverhead(state, runMessageTokens)
|
||||||
bridgeLogger.info({
|
bridgeLogger.info({
|
||||||
sessionId,
|
sessionId,
|
||||||
profile,
|
profile,
|
||||||
bridgeMessages: ev.message_count,
|
bridgeMessages: ev.message_count,
|
||||||
dbMessages: bridgeHistory.length,
|
dbMessages: bridgeHistory.length,
|
||||||
messageOnlyTokens,
|
messageOnlyTokens,
|
||||||
fullContextTokens: tokenCount,
|
currentInputTokens: runInputTokens,
|
||||||
source: typeof ev.approx_tokens === 'number' ? 'bridge' : 'message-only-fallback',
|
fixedContextTokens: state.bridgeContext?.fixedContextTokens,
|
||||||
|
contextTokens: tokenCount,
|
||||||
|
bridgeApproxTokens: ev.approx_tokens,
|
||||||
|
source: 'local',
|
||||||
}, '[chat-run-socket] bridge compression token estimate')
|
}, '[chat-run-socket] bridge compression token estimate')
|
||||||
const payload = {
|
const payload = {
|
||||||
event: 'compression.started',
|
event: 'compression.started',
|
||||||
@@ -674,7 +759,7 @@ async function applyBridgeChunkAsync(
|
|||||||
sessionId,
|
sessionId,
|
||||||
profile,
|
profile,
|
||||||
ev.messages as ChatMessage[],
|
ev.messages as ChatMessage[],
|
||||||
typeof ev.approx_tokens === 'number' ? ev.approx_tokens : undefined,
|
tokenCount,
|
||||||
)
|
)
|
||||||
state.bridgeCompressionResults = state.bridgeCompressionResults || {}
|
state.bridgeCompressionResults = state.bridgeCompressionResults || {}
|
||||||
state.bridgeCompressionResults[String(ev.request_id)] = compressed
|
state.bridgeCompressionResults[String(ev.request_id)] = compressed
|
||||||
@@ -689,11 +774,16 @@ async function applyBridgeChunkAsync(
|
|||||||
const compressionResult = ev.request_id
|
const compressionResult = ev.request_id
|
||||||
? state.bridgeCompressionResults?.[String(ev.request_id)]
|
? state.bridgeCompressionResults?.[String(ev.request_id)]
|
||||||
: undefined
|
: undefined
|
||||||
const bridgeAfterContextTokens = finiteToken(ev.result_approx_tokens)
|
|
||||||
const messageAfterTokens = finiteToken(compressionResult?.afterTokens)
|
const messageAfterTokens = finiteToken(compressionResult?.afterTokens)
|
||||||
const afterContextTokens = messageAfterTokens != null && getCachedBridgeContextOverhead(state) != null
|
const runInputTokens = typeof currentInputTokens === 'number' && Number.isFinite(currentInputTokens) && currentInputTokens > 0
|
||||||
? contextTokensWithCachedOverhead(state, messageAfterTokens)
|
? Math.floor(currentInputTokens)
|
||||||
: bridgeAfterContextTokens ?? messageAfterTokens
|
: 0
|
||||||
|
const messageAfterTokensWithInput = messageAfterTokens != null
|
||||||
|
? messageAfterTokens + runInputTokens
|
||||||
|
: undefined
|
||||||
|
const afterContextTokens = messageAfterTokensWithInput != null
|
||||||
|
? contextTokensWithCachedOverhead(state, messageAfterTokensWithInput)
|
||||||
|
: undefined
|
||||||
const payload = {
|
const payload = {
|
||||||
event: 'compression.completed',
|
event: 'compression.completed',
|
||||||
run_id: chunk.run_id,
|
run_id: chunk.run_id,
|
||||||
@@ -703,7 +793,7 @@ async function applyBridgeChunkAsync(
|
|||||||
totalMessages: compressionResult?.beforeMessages ?? ev.message_count,
|
totalMessages: compressionResult?.beforeMessages ?? ev.message_count,
|
||||||
resultMessages: compressionResult?.resultMessages ?? ev.result_messages,
|
resultMessages: compressionResult?.resultMessages ?? ev.result_messages,
|
||||||
beforeTokens: compressionResult?.beforeTokens ?? ev.approx_tokens,
|
beforeTokens: compressionResult?.beforeTokens ?? ev.approx_tokens,
|
||||||
afterTokens: messageAfterTokens ?? bridgeAfterContextTokens,
|
afterTokens: messageAfterTokensWithInput,
|
||||||
contextTokens: afterContextTokens,
|
contextTokens: afterContextTokens,
|
||||||
summaryTokens: compressionResult?.summaryTokens,
|
summaryTokens: compressionResult?.summaryTokens,
|
||||||
verbatimCount: compressionResult?.verbatimCount,
|
verbatimCount: compressionResult?.verbatimCount,
|
||||||
@@ -716,10 +806,8 @@ async function applyBridgeChunkAsync(
|
|||||||
replaceState(sessionMap, sessionId, 'compression.completed', payload)
|
replaceState(sessionMap, sessionId, 'compression.completed', payload)
|
||||||
emit('compression.completed', payload)
|
emit('compression.completed', payload)
|
||||||
const usage = await calcAndUpdateUsage(sessionId, state, emit)
|
const usage = await calcAndUpdateUsage(sessionId, state, emit)
|
||||||
if (messageAfterTokens != null && getCachedBridgeContextOverhead(state) != null) {
|
if (messageAfterTokensWithInput != null) {
|
||||||
updateMessageContextTokenUsage(sessionId, state, emit, messageAfterTokens, usage)
|
updateMessageContextTokenUsage(sessionId, state, emit, messageAfterTokensWithInput, usage)
|
||||||
} else {
|
|
||||||
updateContextTokenUsage(sessionId, state, emit, afterContextTokens, usage)
|
|
||||||
}
|
}
|
||||||
} else if (evType === 'bridge.compression.failed') {
|
} else if (evType === 'bridge.compression.failed') {
|
||||||
const payload = {
|
const payload = {
|
||||||
|
|||||||
@@ -0,0 +1,96 @@
|
|||||||
|
// @vitest-environment jsdom
|
||||||
|
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||||
|
import { createPinia, setActivePinia } from 'pinia'
|
||||||
|
|
||||||
|
const chatApi = vi.hoisted(() => ({
|
||||||
|
resumeSession: vi.fn(),
|
||||||
|
registerSessionHandlers: vi.fn(),
|
||||||
|
unregisterSessionHandlers: vi.fn(),
|
||||||
|
}))
|
||||||
|
|
||||||
|
vi.mock('@/api/hermes/chat', () => ({
|
||||||
|
startRunViaSocket: vi.fn(),
|
||||||
|
resumeSession: chatApi.resumeSession,
|
||||||
|
registerSessionHandlers: chatApi.registerSessionHandlers,
|
||||||
|
unregisterSessionHandlers: chatApi.unregisterSessionHandlers,
|
||||||
|
getChatRunSocket: vi.fn(() => ({ emit: vi.fn() })),
|
||||||
|
respondToolApproval: vi.fn(),
|
||||||
|
respondClarify: vi.fn(),
|
||||||
|
onPeerUserMessage: vi.fn(() => vi.fn()),
|
||||||
|
onSessionCommand: vi.fn(() => vi.fn()),
|
||||||
|
}))
|
||||||
|
|
||||||
|
vi.mock('@/api/client', () => ({
|
||||||
|
getActiveProfileName: () => 'default',
|
||||||
|
}))
|
||||||
|
|
||||||
|
vi.mock('@/api/hermes/sessions', () => ({
|
||||||
|
deleteSession: vi.fn(),
|
||||||
|
fetchSession: vi.fn(),
|
||||||
|
fetchSessions: vi.fn(),
|
||||||
|
setSessionModel: vi.fn(),
|
||||||
|
}))
|
||||||
|
|
||||||
|
vi.mock('@/api/hermes/download', () => ({
|
||||||
|
getDownloadUrl: (_path: string, name: string) => `/download/${name}`,
|
||||||
|
}))
|
||||||
|
|
||||||
|
vi.mock('@/utils/completion-sound', () => ({
|
||||||
|
primeCompletionSound: vi.fn(),
|
||||||
|
playCompletionSound: vi.fn(),
|
||||||
|
}))
|
||||||
|
|
||||||
|
import { useChatStore, type Session } from '@/stores/hermes/chat'
|
||||||
|
|
||||||
|
function makeSession(id: string): Session {
|
||||||
|
return {
|
||||||
|
id,
|
||||||
|
title: id,
|
||||||
|
messages: [],
|
||||||
|
createdAt: Date.now(),
|
||||||
|
updatedAt: Date.now(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
describe('chat store compression state', () => {
|
||||||
|
beforeEach(() => {
|
||||||
|
vi.resetAllMocks()
|
||||||
|
setActivePinia(createPinia())
|
||||||
|
chatApi.resumeSession.mockImplementation((sessionId: string, onResumed: (data: any) => void) => {
|
||||||
|
onResumed({
|
||||||
|
session_id: sessionId,
|
||||||
|
messages: [],
|
||||||
|
isWorking: sessionId === 'session-1',
|
||||||
|
events: [],
|
||||||
|
})
|
||||||
|
return {} as any
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
it('does not show a background session compression indicator in the active session', async () => {
|
||||||
|
const store = useChatStore()
|
||||||
|
store.sessions = [makeSession('session-1'), makeSession('session-2')]
|
||||||
|
|
||||||
|
await store.switchSession('session-1')
|
||||||
|
const handlers = chatApi.registerSessionHandlers.mock.calls.find(call => call[0] === 'session-1')?.[1]
|
||||||
|
expect(handlers).toBeTruthy()
|
||||||
|
|
||||||
|
await store.switchSession('session-2')
|
||||||
|
handlers.onCompressionStarted({
|
||||||
|
event: 'compression.started',
|
||||||
|
session_id: 'session-1',
|
||||||
|
message_count: 6,
|
||||||
|
token_count: 1234,
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(store.activeSessionId).toBe('session-2')
|
||||||
|
expect(store.compressionState).toBeNull()
|
||||||
|
|
||||||
|
await store.switchSession('session-1')
|
||||||
|
expect(store.compressionState).toEqual(expect.objectContaining({
|
||||||
|
compressing: true,
|
||||||
|
messageCount: 6,
|
||||||
|
beforeTokens: 1234,
|
||||||
|
}))
|
||||||
|
})
|
||||||
|
})
|
||||||
@@ -292,9 +292,11 @@ except RuntimeError as exc:
|
|||||||
assert "already running" in str(exc)
|
assert "already running" in str(exc)
|
||||||
|
|
||||||
class FakeWorker:
|
class FakeWorker:
|
||||||
def __init__(self, destroyed):
|
def __init__(self, destroyed, profile="default", key="default"):
|
||||||
self.running = True
|
self.running = True
|
||||||
self.destroyed = destroyed
|
self.destroyed = destroyed
|
||||||
|
self.profile = profile
|
||||||
|
self.key = key
|
||||||
self.requests = []
|
self.requests = []
|
||||||
self.stopped = False
|
self.stopped = False
|
||||||
|
|
||||||
@@ -310,28 +312,41 @@ broker = bridge.BridgeBroker("ipc:///tmp/unused.sock")
|
|||||||
profile_worker = FakeWorker(2)
|
profile_worker = FakeWorker(2)
|
||||||
broker._workers["default"] = profile_worker
|
broker._workers["default"] = profile_worker
|
||||||
broker._run_profile["run-session-a"] = "default"
|
broker._run_profile["run-session-a"] = "default"
|
||||||
|
broker._run_worker_key["run-session-a"] = "default"
|
||||||
broker._running_run_profile["run-session-a"] = "default"
|
broker._running_run_profile["run-session-a"] = "default"
|
||||||
|
broker._running_run_worker_key["run-session-a"] = "default"
|
||||||
broker._session_profile["session-a"] = "default"
|
broker._session_profile["session-a"] = "default"
|
||||||
|
broker._session_worker_key["session-a"] = "default"
|
||||||
broker._approval_profile["approval-a"] = "default"
|
broker._approval_profile["approval-a"] = "default"
|
||||||
|
broker._approval_worker_key["approval-a"] = "default"
|
||||||
broker._compression_profile["compression-a"] = "default"
|
broker._compression_profile["compression-a"] = "default"
|
||||||
|
broker._compression_worker_key["compression-a"] = "default"
|
||||||
|
|
||||||
destroy_profile_result = broker.handle({"action": "destroy_profile", "profile": "default"})
|
destroy_profile_result = broker.handle({"action": "destroy_profile", "profile": "default"})
|
||||||
assert destroy_profile_result == {"profile": "default", "destroyed": 2}
|
assert destroy_profile_result == {"profile": "default", "destroyed": 2}
|
||||||
assert profile_worker.stopped
|
assert profile_worker.stopped
|
||||||
assert "default" not in broker._workers
|
assert "default" not in broker._workers
|
||||||
assert broker._run_profile == {}
|
assert broker._run_profile == {}
|
||||||
|
assert broker._run_worker_key == {}
|
||||||
assert broker._running_run_profile == {}
|
assert broker._running_run_profile == {}
|
||||||
|
assert broker._running_run_worker_key == {}
|
||||||
assert broker._session_profile == {}
|
assert broker._session_profile == {}
|
||||||
|
assert broker._session_worker_key == {}
|
||||||
assert broker._approval_profile == {}
|
assert broker._approval_profile == {}
|
||||||
|
assert broker._approval_worker_key == {}
|
||||||
assert broker._compression_profile == {}
|
assert broker._compression_profile == {}
|
||||||
|
assert broker._compression_worker_key == {}
|
||||||
|
|
||||||
worker_a = FakeWorker(1)
|
worker_a = FakeWorker(1, "default", "a")
|
||||||
worker_b = FakeWorker(3)
|
worker_b = FakeWorker(3, "work", "b")
|
||||||
broker._workers["a"] = worker_a
|
broker._workers["a"] = worker_a
|
||||||
broker._workers["b"] = worker_b
|
broker._workers["b"] = worker_b
|
||||||
broker._run_profile["run-a"] = "a"
|
broker._run_profile["run-a"] = "default"
|
||||||
broker._running_run_profile["run-a"] = "a"
|
broker._run_worker_key["run-a"] = "a"
|
||||||
broker._session_profile["session-b"] = "b"
|
broker._running_run_profile["run-a"] = "default"
|
||||||
|
broker._running_run_worker_key["run-a"] = "a"
|
||||||
|
broker._session_profile["session-b"] = "work"
|
||||||
|
broker._session_worker_key["session-b"] = "b"
|
||||||
|
|
||||||
destroy_all_result = broker.handle({"action": "destroy_all"})
|
destroy_all_result = broker.handle({"action": "destroy_all"})
|
||||||
assert destroy_all_result == {"destroyed": 4}
|
assert destroy_all_result == {"destroyed": 4}
|
||||||
@@ -339,8 +354,11 @@ assert worker_a.stopped
|
|||||||
assert worker_b.stopped
|
assert worker_b.stopped
|
||||||
assert broker._workers == {}
|
assert broker._workers == {}
|
||||||
assert broker._run_profile == {}
|
assert broker._run_profile == {}
|
||||||
|
assert broker._run_worker_key == {}
|
||||||
assert broker._running_run_profile == {}
|
assert broker._running_run_profile == {}
|
||||||
|
assert broker._running_run_worker_key == {}
|
||||||
assert broker._session_profile == {}
|
assert broker._session_profile == {}
|
||||||
|
assert broker._session_worker_key == {}
|
||||||
`)
|
`)
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -372,6 +390,69 @@ assert resp["running_sessions_by_profile"] == {"default": 1}
|
|||||||
`)
|
`)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
it('routes worker-keyed broker requests without stopping the worker on session destroy', () => {
|
||||||
|
runPython(String.raw`
|
||||||
|
${harness}
|
||||||
|
|
||||||
|
class RoutedWorker:
|
||||||
|
running = True
|
||||||
|
pid = 12345
|
||||||
|
endpoint = "ipc:///tmp/worker.sock"
|
||||||
|
last_used_at = 12.5
|
||||||
|
|
||||||
|
def __init__(self, profile, key):
|
||||||
|
self.profile = profile
|
||||||
|
self.key = key
|
||||||
|
self.requests = []
|
||||||
|
self.stopped = False
|
||||||
|
|
||||||
|
def request(self, req):
|
||||||
|
self.requests.append(req)
|
||||||
|
action = req.get("action")
|
||||||
|
if action == "chat":
|
||||||
|
return {"ok": True, "run_id": "run-compress", "session_id": req["session_id"], "status": "running"}
|
||||||
|
if action == "get_output":
|
||||||
|
return {"ok": True, "run_id": req["run_id"], "session_id": "compress-temp", "status": "complete", "done": True}
|
||||||
|
if action == "destroy":
|
||||||
|
return {"ok": True, "session_id": req["session_id"], "destroyed": True}
|
||||||
|
raise AssertionError(f"unexpected action: {action}")
|
||||||
|
|
||||||
|
def stop(self):
|
||||||
|
self.stopped = True
|
||||||
|
|
||||||
|
broker = bridge.BridgeBroker("ipc:///tmp/unused.sock")
|
||||||
|
worker = RoutedWorker("default", "default:compression:session-a")
|
||||||
|
broker._workers[worker.key] = worker
|
||||||
|
|
||||||
|
chat_resp = broker.handle({
|
||||||
|
"action": "chat",
|
||||||
|
"session_id": "compress-temp",
|
||||||
|
"profile": "default",
|
||||||
|
"worker_key": worker.key,
|
||||||
|
"message": "summarize",
|
||||||
|
})
|
||||||
|
assert chat_resp["run_id"] == "run-compress"
|
||||||
|
assert worker.requests[-1]["profile"] == "default"
|
||||||
|
assert "worker_key" not in worker.requests[-1]
|
||||||
|
|
||||||
|
broker.handle({"action": "get_output", "run_id": "run-compress"})
|
||||||
|
assert worker.requests[-1]["action"] == "get_output"
|
||||||
|
|
||||||
|
destroy_resp = broker.handle({
|
||||||
|
"action": "destroy",
|
||||||
|
"session_id": "compress-temp",
|
||||||
|
"profile": "default",
|
||||||
|
"worker_key": worker.key,
|
||||||
|
})
|
||||||
|
assert destroy_resp["destroyed"] is True
|
||||||
|
assert worker.requests[-1]["action"] == "destroy"
|
||||||
|
assert not worker.stopped
|
||||||
|
assert worker.key in broker._workers
|
||||||
|
assert "compress-temp" not in broker._session_profile
|
||||||
|
assert "compress-temp" not in broker._session_worker_key
|
||||||
|
`)
|
||||||
|
})
|
||||||
|
|
||||||
it('restores approval env and clears handlers when a run fails', () => {
|
it('restores approval env and clears handlers when a run fails', () => {
|
||||||
runPython(String.raw`
|
runPython(String.raw`
|
||||||
${harness}
|
${harness}
|
||||||
@@ -480,7 +561,7 @@ original_getpid = bridge.os.getpid
|
|||||||
try:
|
try:
|
||||||
bridge.subprocess.Popen = fake_popen
|
bridge.subprocess.Popen = fake_popen
|
||||||
bridge.os.getpid = lambda: 4242
|
bridge.os.getpid = lambda: 4242
|
||||||
proc_worker = bridge.WorkerProcess("default", "ipc:///tmp/worker.sock", "/agent", "/home")
|
proc_worker = bridge.WorkerProcess("default:compression:session-a", "default", "ipc:///tmp/worker.sock", "/agent", "/home")
|
||||||
proc_worker._pipe_stderr = lambda: None
|
proc_worker._pipe_stderr = lambda: None
|
||||||
proc_worker._wait_ready = lambda: None
|
proc_worker._wait_ready = lambda: None
|
||||||
proc_worker.start()
|
proc_worker.start()
|
||||||
|
|||||||
@@ -153,6 +153,42 @@ describe('ChatContextCompressor', () => {
|
|||||||
expect(saveCompressionSnapshotMock).toHaveBeenCalledWith('s1', 'compressed summary', 6, 10)
|
expect(saveCompressionSnapshotMock).toHaveBeenCalledWith('s1', 'compressed summary', 6, 10)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
it('routes summarization through the provided worker key and destroys only the temporary agent session', async () => {
|
||||||
|
const { ChatContextCompressor } = await import('../../packages/server/src/lib/context-compressor')
|
||||||
|
const compressor = new ChatContextCompressor({
|
||||||
|
config: { headMessageCount: 0, tailMessageCount: 1, summaryBudget: 1000 },
|
||||||
|
})
|
||||||
|
const messages = [
|
||||||
|
{ role: 'user', content: 'old context' },
|
||||||
|
{ role: 'assistant', content: 'old response' },
|
||||||
|
{ role: 'user', content: 'tail' },
|
||||||
|
]
|
||||||
|
getCompressionSnapshotMock.mockReturnValue(null)
|
||||||
|
bridgeRequestMock.mockResolvedValue({
|
||||||
|
status: 'completed',
|
||||||
|
result: { final_response: 'compressed summary' },
|
||||||
|
})
|
||||||
|
|
||||||
|
await compressor.compress(messages, 'http://upstream', undefined, 's1', {
|
||||||
|
profile: 'default',
|
||||||
|
workerKey: 'default:compression:s1',
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(bridgeRequestMock).toHaveBeenCalledWith(expect.objectContaining({
|
||||||
|
action: 'chat',
|
||||||
|
profile: 'default',
|
||||||
|
worker_key: 'default:compression:s1',
|
||||||
|
wait: true,
|
||||||
|
}), expect.any(Object))
|
||||||
|
const compressSessionId = bridgeRequestMock.mock.calls[0][0].session_id
|
||||||
|
expect(String(compressSessionId)).toMatch(/^compress_/)
|
||||||
|
expect(bridgeDestroyMock).toHaveBeenCalledWith(
|
||||||
|
compressSessionId,
|
||||||
|
'default',
|
||||||
|
'default:compression:s1',
|
||||||
|
)
|
||||||
|
})
|
||||||
|
|
||||||
it('does not pre-prune tool results before sending them to the summarizer', async () => {
|
it('does not pre-prune tool results before sending them to the summarizer', async () => {
|
||||||
const { ChatContextCompressor } = await import('../../packages/server/src/lib/context-compressor')
|
const { ChatContextCompressor } = await import('../../packages/server/src/lib/context-compressor')
|
||||||
const compressor = new ChatContextCompressor({
|
const compressor = new ChatContextCompressor({
|
||||||
|
|||||||
@@ -127,8 +127,18 @@ describe('bridge run final context usage', () => {
|
|||||||
buildSnapshotAwareHistoryMock.mockImplementation(async (_sessionId: string, _profile: string, history: any[]) => history)
|
buildSnapshotAwareHistoryMock.mockImplementation(async (_sessionId: string, _profile: string, history: any[]) => history)
|
||||||
calcAndUpdateUsageMock.mockResolvedValue({ inputTokens: 11, outputTokens: 7 })
|
calcAndUpdateUsageMock.mockResolvedValue({ inputTokens: 11, outputTokens: 7 })
|
||||||
estimateUsageTokensFromMessagesMock.mockReturnValue({ inputTokens: 11, outputTokens: 7 })
|
estimateUsageTokensFromMessagesMock.mockReturnValue({ inputTokens: 11, outputTokens: 7 })
|
||||||
getCachedBridgeContextOverheadMock.mockReturnValue(undefined)
|
getCachedBridgeContextOverheadMock.mockImplementation((state: any) => {
|
||||||
contextTokensWithCachedOverheadMock.mockImplementation((_state: any, messageTokens: number) => messageTokens)
|
const fixed = state?.bridgeContext?.fixedContextTokens
|
||||||
|
return typeof fixed === 'number' ? fixed : undefined
|
||||||
|
})
|
||||||
|
contextTokensWithCachedOverheadMock.mockImplementation((state: any, messageTokens: number) => {
|
||||||
|
const fixed = state?.bridgeContext?.fixedContextTokens
|
||||||
|
return typeof fixed === 'number' ? fixed + messageTokens : messageTokens
|
||||||
|
})
|
||||||
|
updateMessageContextTokenUsageMock.mockImplementation((sid: string, state: any, emit: any, messageTokens: number, usage?: { inputTokens: number; outputTokens: number }) => {
|
||||||
|
const contextTokens = contextTokensWithCachedOverheadMock(state, messageTokens)
|
||||||
|
return updateContextTokenUsageMock(sid, state, emit, contextTokens, usage)
|
||||||
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
it('refreshes full context tokens when a bridge run completes', async () => {
|
it('refreshes full context tokens when a bridge run completes', async () => {
|
||||||
@@ -141,6 +151,7 @@ describe('bridge run final context usage', () => {
|
|||||||
chat: vi.fn().mockResolvedValue({ run_id: 'run-1', status: 'started' }),
|
chat: vi.fn().mockResolvedValue({ run_id: 'run-1', status: 'started' }),
|
||||||
contextEstimate: vi.fn().mockResolvedValue({
|
contextEstimate: vi.fn().mockResolvedValue({
|
||||||
token_count: 12345,
|
token_count: 12345,
|
||||||
|
fixed_context_tokens: 12327,
|
||||||
message_count: 2,
|
message_count: 2,
|
||||||
tool_count: 4,
|
tool_count: 4,
|
||||||
system_prompt_chars: 13,
|
system_prompt_chars: 13,
|
||||||
@@ -165,10 +176,7 @@ describe('bridge run final context usage', () => {
|
|||||||
|
|
||||||
expect(bridge.contextEstimate).toHaveBeenCalledWith(
|
expect(bridge.contextEstimate).toHaveBeenCalledWith(
|
||||||
'session-1',
|
'session-1',
|
||||||
[
|
[],
|
||||||
{ role: 'user', content: 'hello' },
|
|
||||||
{ role: 'assistant', content: 'done' },
|
|
||||||
],
|
|
||||||
expect.stringContaining('[Current Hermes profile: default]'),
|
expect.stringContaining('[Current Hermes profile: default]'),
|
||||||
'default',
|
'default',
|
||||||
{ model: 'gpt-test', provider: 'openai' },
|
{ model: 'gpt-test', provider: 'openai' },
|
||||||
@@ -326,14 +334,22 @@ describe('bridge run final context usage', () => {
|
|||||||
const nsp = makeNamespace(emit)
|
const nsp = makeNamespace(emit)
|
||||||
const socket = makeSocket()
|
const socket = makeSocket()
|
||||||
const state = makeState()
|
const state = makeState()
|
||||||
state.bridgeContext = { fixedContextTokens: 20_000 }
|
|
||||||
const sessionMap = new Map([['session-1', state]])
|
const sessionMap = new Map([['session-1', state]])
|
||||||
getCachedBridgeContextOverheadMock.mockReturnValue(20_000)
|
|
||||||
updateMessageContextTokenUsageMock.mockImplementation((sid: string, targetState: any, targetEmit: any, messageTokens: number, usage?: { inputTokens: number; outputTokens: number }) => updateContextTokenUsageMock(sid, targetState, targetEmit, 20_000 + messageTokens, usage))
|
|
||||||
const bridge = {
|
const bridge = {
|
||||||
chat: vi.fn().mockResolvedValue({ run_id: 'run-1', status: 'started' }),
|
chat: vi.fn().mockResolvedValue({ run_id: 'run-1', status: 'started' }),
|
||||||
contextEstimate: vi.fn(),
|
contextEstimate: vi.fn(),
|
||||||
streamOutput: vi.fn(async function* () {
|
streamOutput: vi.fn(async function* () {
|
||||||
|
yield {
|
||||||
|
run_id: 'run-1',
|
||||||
|
done: false,
|
||||||
|
status: 'running',
|
||||||
|
events: [{
|
||||||
|
event: 'bridge.context.ready',
|
||||||
|
fixed_context_tokens: 20_000,
|
||||||
|
system_prompt_tokens: 3_000,
|
||||||
|
tool_tokens: 17_000,
|
||||||
|
}],
|
||||||
|
}
|
||||||
yield { run_id: 'run-1', done: true, status: 'completed', output: 'done' }
|
yield { run_id: 'run-1', done: true, status: 'completed', output: 'done' }
|
||||||
}),
|
}),
|
||||||
} as any
|
} as any
|
||||||
@@ -365,6 +381,80 @@ describe('bridge run final context usage', () => {
|
|||||||
}))
|
}))
|
||||||
})
|
})
|
||||||
|
|
||||||
|
it('keeps bridge context ready updates on the snapshot-aware token baseline', async () => {
|
||||||
|
const emit = vi.fn()
|
||||||
|
const nsp = makeNamespace(emit)
|
||||||
|
const socket = makeSocket()
|
||||||
|
const state = makeState()
|
||||||
|
const sessionMap = new Map([['session-1', state]])
|
||||||
|
calcAndUpdateUsageMock.mockResolvedValue({ inputTokens: 28_000, outputTokens: 0 })
|
||||||
|
buildDbHistoryMock.mockResolvedValue([
|
||||||
|
{ role: 'user', content: 'very large old context' },
|
||||||
|
{ role: 'assistant', content: 'large old response' },
|
||||||
|
{ role: 'user', content: 'hello' },
|
||||||
|
])
|
||||||
|
buildSnapshotAwareHistoryMock.mockResolvedValue([
|
||||||
|
{ role: 'user', content: '[Previous context summary]\n\nsmall summary' },
|
||||||
|
{ role: 'user', content: 'hello' },
|
||||||
|
])
|
||||||
|
estimateUsageTokensFromMessagesMock.mockImplementation((messages: any[]) => {
|
||||||
|
if (messages?.[0]?.content?.includes('small summary')) {
|
||||||
|
return { inputTokens: 9_000, outputTokens: 0 }
|
||||||
|
}
|
||||||
|
return { inputTokens: 28_000, outputTokens: 0 }
|
||||||
|
})
|
||||||
|
const bridge = {
|
||||||
|
chat: vi.fn().mockResolvedValue({ run_id: 'run-1', status: 'started' }),
|
||||||
|
contextEstimate: vi.fn(),
|
||||||
|
streamOutput: vi.fn(async function* () {
|
||||||
|
yield {
|
||||||
|
run_id: 'run-1',
|
||||||
|
done: false,
|
||||||
|
status: 'running',
|
||||||
|
events: [{
|
||||||
|
event: 'bridge.context.ready',
|
||||||
|
fixed_context_tokens: 10_000,
|
||||||
|
system_prompt_tokens: 2_000,
|
||||||
|
tool_tokens: 8_000,
|
||||||
|
}],
|
||||||
|
}
|
||||||
|
yield { run_id: 'run-1', done: true, status: 'completed', output: 'done' }
|
||||||
|
}),
|
||||||
|
} as any
|
||||||
|
|
||||||
|
const { handleBridgeRun } = await import('../../packages/server/src/services/hermes/run-chat/handle-bridge-run')
|
||||||
|
await handleBridgeRun(
|
||||||
|
nsp,
|
||||||
|
socket,
|
||||||
|
{ input: 'hello', session_id: 'session-1' },
|
||||||
|
'default',
|
||||||
|
sessionMap,
|
||||||
|
bridge,
|
||||||
|
false,
|
||||||
|
vi.fn(),
|
||||||
|
vi.fn(),
|
||||||
|
)
|
||||||
|
|
||||||
|
expect(updateMessageContextTokenUsageMock).toHaveBeenCalledWith(
|
||||||
|
'session-1',
|
||||||
|
state,
|
||||||
|
expect.any(Function),
|
||||||
|
9_000,
|
||||||
|
{ inputTokens: 28_000, outputTokens: 0 },
|
||||||
|
)
|
||||||
|
expect(updateMessageContextTokenUsageMock).not.toHaveBeenCalledWith(
|
||||||
|
'session-1',
|
||||||
|
state,
|
||||||
|
expect.any(Function),
|
||||||
|
28_000,
|
||||||
|
{ inputTokens: 28_000, outputTokens: 0 },
|
||||||
|
)
|
||||||
|
expect(state.contextTokens).toBe(19_000)
|
||||||
|
expect(emit).toHaveBeenCalledWith('run.completed', expect.objectContaining({
|
||||||
|
contextTokens: 19_000,
|
||||||
|
}))
|
||||||
|
})
|
||||||
|
|
||||||
it('persists pending tool marker text before a bridge run completes', async () => {
|
it('persists pending tool marker text before a bridge run completes', async () => {
|
||||||
const emit = vi.fn()
|
const emit = vi.fn()
|
||||||
const nsp = makeNamespace(emit)
|
const nsp = makeNamespace(emit)
|
||||||
@@ -502,6 +592,7 @@ describe('bridge run final context usage', () => {
|
|||||||
chat: vi.fn().mockRejectedValue(new Error('bridge timeout')),
|
chat: vi.fn().mockRejectedValue(new Error('bridge timeout')),
|
||||||
contextEstimate: vi.fn().mockResolvedValue({
|
contextEstimate: vi.fn().mockResolvedValue({
|
||||||
token_count: 54321,
|
token_count: 54321,
|
||||||
|
fixed_context_tokens: 54303,
|
||||||
message_count: 1,
|
message_count: 1,
|
||||||
tool_count: 4,
|
tool_count: 4,
|
||||||
system_prompt_chars: 13,
|
system_prompt_chars: 13,
|
||||||
|
|||||||
@@ -175,7 +175,7 @@ describe('run chat compression trigger', () => {
|
|||||||
)
|
)
|
||||||
})
|
})
|
||||||
|
|
||||||
it('uses full context estimates for compression threshold decisions', async () => {
|
it('uses local context estimates for compression threshold decisions', async () => {
|
||||||
const messages = Array.from({ length: 10 }, (_, index) => ({
|
const messages = Array.from({ length: 10 }, (_, index) => ({
|
||||||
id: index + 1,
|
id: index + 1,
|
||||||
session_id: 'session-1',
|
session_id: 'session-1',
|
||||||
@@ -191,7 +191,7 @@ describe('run chat compression trigger', () => {
|
|||||||
getSessionDetailMock.mockReturnValue({ messages })
|
getSessionDetailMock.mockReturnValue({ messages })
|
||||||
calcAndUpdateUsageMock.mockResolvedValue({ inputTokens: 1_000, outputTokens: 0 })
|
calcAndUpdateUsageMock.mockResolvedValue({ inputTokens: 1_000, outputTokens: 0 })
|
||||||
compressorCompressMock.mockResolvedValue({
|
compressorCompressMock.mockResolvedValue({
|
||||||
messages: [{ role: 'user', content: 'compressed by full context estimate' }],
|
messages: [{ role: 'user', content: 'compressed by local context estimate' }],
|
||||||
meta: {
|
meta: {
|
||||||
compressed: true,
|
compressed: true,
|
||||||
llmCompressed: true,
|
llmCompressed: true,
|
||||||
@@ -215,7 +215,7 @@ describe('run chat compression trigger', () => {
|
|||||||
vi.fn(async () => 120_000),
|
vi.fn(async () => 120_000),
|
||||||
)
|
)
|
||||||
|
|
||||||
expect(history).toEqual([{ role: 'user', content: 'compressed by full context estimate' }])
|
expect(history).toEqual([{ role: 'user', content: 'compressed by local context estimate' }])
|
||||||
expect(compressorCompressMock).toHaveBeenCalledTimes(1)
|
expect(compressorCompressMock).toHaveBeenCalledTimes(1)
|
||||||
expect(updateMessageContextTokenUsageMock).toHaveBeenCalledWith(
|
expect(updateMessageContextTokenUsageMock).toHaveBeenCalledWith(
|
||||||
'session-1',
|
'session-1',
|
||||||
@@ -226,7 +226,7 @@ describe('run chat compression trigger', () => {
|
|||||||
)
|
)
|
||||||
})
|
})
|
||||||
|
|
||||||
it('emits full context token usage when the full estimate is under threshold', async () => {
|
it('emits local context token usage when the local estimate is under threshold', async () => {
|
||||||
const messages = Array.from({ length: 10 }, (_, index) => ({
|
const messages = Array.from({ length: 10 }, (_, index) => ({
|
||||||
id: index + 1,
|
id: index + 1,
|
||||||
session_id: 'session-1',
|
session_id: 'session-1',
|
||||||
@@ -257,7 +257,10 @@ describe('run chat compression trigger', () => {
|
|||||||
)
|
)
|
||||||
|
|
||||||
expect(history).toHaveLength(9)
|
expect(history).toHaveLength(9)
|
||||||
expect(contextTokenEstimator).toHaveBeenCalledWith(expect.arrayContaining([{ role: 'user', content: 'message 0' }]))
|
expect(contextTokenEstimator).toHaveBeenCalledWith(
|
||||||
|
expect.arrayContaining([{ role: 'user', content: 'message 0' }]),
|
||||||
|
1_900,
|
||||||
|
)
|
||||||
expect(emit).toHaveBeenCalledWith('usage.updated', expect.objectContaining({
|
expect(emit).toHaveBeenCalledWith('usage.updated', expect.objectContaining({
|
||||||
event: 'usage.updated',
|
event: 'usage.updated',
|
||||||
session_id: 'session-1',
|
session_id: 'session-1',
|
||||||
@@ -268,6 +271,108 @@ describe('run chat compression trigger', () => {
|
|||||||
expect(compressorCompressMock).not.toHaveBeenCalled()
|
expect(compressorCompressMock).not.toHaveBeenCalled()
|
||||||
})
|
})
|
||||||
|
|
||||||
|
it('includes current input tokens when estimating snapshot-aware context', async () => {
|
||||||
|
const messages = Array.from({ length: 10 }, (_, index) => ({
|
||||||
|
id: index + 1,
|
||||||
|
session_id: 'session-1',
|
||||||
|
role: index === 9 ? 'user' : index % 2 === 0 ? 'user' : 'assistant',
|
||||||
|
content: `message ${index}`,
|
||||||
|
timestamp: index + 1,
|
||||||
|
tool_call_id: null,
|
||||||
|
tool_calls: null,
|
||||||
|
tool_name: null,
|
||||||
|
finish_reason: null,
|
||||||
|
reasoning_content: null,
|
||||||
|
}))
|
||||||
|
getSessionDetailMock.mockReturnValue({ messages })
|
||||||
|
getCompressionSnapshotMock.mockReturnValue({
|
||||||
|
summary: 'previous summary',
|
||||||
|
lastMessageIndex: 4,
|
||||||
|
messageCountAtTime: 5,
|
||||||
|
})
|
||||||
|
calcAndUpdateUsageMock.mockResolvedValue({ inputTokens: 10, outputTokens: 0 })
|
||||||
|
estimateUsageTokensFromMessagesMock.mockReturnValue({ inputTokens: 1_000, outputTokens: 0 })
|
||||||
|
const emit = vi.fn()
|
||||||
|
const contextTokenEstimator = vi.fn(async (_messages, messageTokens: number) => 20_000 + messageTokens)
|
||||||
|
|
||||||
|
const { buildCompressedHistory } = await import('../../packages/server/src/services/hermes/run-chat/compression')
|
||||||
|
await buildCompressedHistory(
|
||||||
|
'session-1',
|
||||||
|
'default',
|
||||||
|
'http://upstream',
|
||||||
|
undefined,
|
||||||
|
emit,
|
||||||
|
new Map(),
|
||||||
|
{},
|
||||||
|
contextTokenEstimator,
|
||||||
|
700,
|
||||||
|
)
|
||||||
|
|
||||||
|
expect(contextTokenEstimator).toHaveBeenCalledWith(expect.any(Array), 1_700)
|
||||||
|
expect(emit).toHaveBeenCalledWith('usage.updated', expect.objectContaining({
|
||||||
|
contextTokens: 21_700,
|
||||||
|
}))
|
||||||
|
expect(compressorCompressMock).not.toHaveBeenCalled()
|
||||||
|
})
|
||||||
|
|
||||||
|
it('keeps current input tokens in the compression completed context total', async () => {
|
||||||
|
const messages = Array.from({ length: 10 }, (_, index) => ({
|
||||||
|
id: index + 1,
|
||||||
|
session_id: 'session-1',
|
||||||
|
role: index === 9 ? 'user' : index % 2 === 0 ? 'user' : 'assistant',
|
||||||
|
content: `message ${index}`,
|
||||||
|
timestamp: index + 1,
|
||||||
|
tool_call_id: null,
|
||||||
|
tool_calls: null,
|
||||||
|
tool_name: null,
|
||||||
|
finish_reason: null,
|
||||||
|
reasoning_content: null,
|
||||||
|
}))
|
||||||
|
getSessionDetailMock.mockReturnValue({ messages })
|
||||||
|
calcAndUpdateUsageMock.mockResolvedValue({ inputTokens: 100, outputTokens: 0 })
|
||||||
|
estimateUsageTokensFromMessagesMock.mockImplementation((items: any[]) => {
|
||||||
|
if (items?.[0]?.content === 'compressed result') return { inputTokens: 1_000, outputTokens: 0 }
|
||||||
|
return { inputTokens: 100, outputTokens: 0 }
|
||||||
|
})
|
||||||
|
compressorCompressMock.mockResolvedValue({
|
||||||
|
messages: [{ role: 'user', content: 'compressed result' }],
|
||||||
|
meta: {
|
||||||
|
compressed: true,
|
||||||
|
llmCompressed: true,
|
||||||
|
totalMessages: 9,
|
||||||
|
summaryTokenEstimate: 1,
|
||||||
|
verbatimCount: 0,
|
||||||
|
compressedStartIndex: 0,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
const emit = vi.fn()
|
||||||
|
|
||||||
|
const { buildCompressedHistory } = await import('../../packages/server/src/services/hermes/run-chat/compression')
|
||||||
|
await buildCompressedHistory(
|
||||||
|
'session-1',
|
||||||
|
'default',
|
||||||
|
'http://upstream',
|
||||||
|
undefined,
|
||||||
|
emit,
|
||||||
|
new Map(),
|
||||||
|
{},
|
||||||
|
vi.fn(async () => 120_000),
|
||||||
|
700,
|
||||||
|
)
|
||||||
|
|
||||||
|
expect(updateMessageContextTokenUsageMock).toHaveBeenCalledWith(
|
||||||
|
'session-1',
|
||||||
|
expect.any(Object),
|
||||||
|
emit,
|
||||||
|
1_700,
|
||||||
|
{ inputTokens: 100, outputTokens: 0 },
|
||||||
|
)
|
||||||
|
expect(emit).toHaveBeenCalledWith('compression.completed', expect.objectContaining({
|
||||||
|
afterTokens: 1_700,
|
||||||
|
contextTokens: 1_700,
|
||||||
|
}))
|
||||||
|
})
|
||||||
|
|
||||||
it('throws when fixed prompt and tool schemas exceed threshold before any history exists', async () => {
|
it('throws when fixed prompt and tool schemas exceed threshold before any history exists', async () => {
|
||||||
getSessionDetailMock.mockReturnValue({ messages: [] })
|
getSessionDetailMock.mockReturnValue({ messages: [] })
|
||||||
const emit = vi.fn()
|
const emit = vi.fn()
|
||||||
|
|||||||
@@ -0,0 +1,129 @@
|
|||||||
|
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||||
|
|
||||||
|
const getSessionMock = vi.fn()
|
||||||
|
const getSessionDetailPaginatedMock = vi.fn()
|
||||||
|
const getCompressionSnapshotMock = vi.fn()
|
||||||
|
const estimateUsageTokensFromMessagesMock = vi.fn()
|
||||||
|
const buildDbHistoryMock = vi.fn()
|
||||||
|
const buildSnapshotAwareHistoryMock = vi.fn()
|
||||||
|
|
||||||
|
vi.mock('../../packages/server/src/db/hermes/session-store', () => ({
|
||||||
|
getSession: getSessionMock,
|
||||||
|
createSession: vi.fn(),
|
||||||
|
addMessage: vi.fn(),
|
||||||
|
updateSessionStats: vi.fn(),
|
||||||
|
getSessionDetailPaginated: getSessionDetailPaginatedMock,
|
||||||
|
}))
|
||||||
|
|
||||||
|
vi.mock('../../packages/server/src/db/hermes/usage-store', () => ({
|
||||||
|
updateUsage: vi.fn(),
|
||||||
|
}))
|
||||||
|
|
||||||
|
vi.mock('../../packages/server/src/db/hermes/compression-snapshot', () => ({
|
||||||
|
getCompressionSnapshot: getCompressionSnapshotMock,
|
||||||
|
}))
|
||||||
|
|
||||||
|
vi.mock('../../packages/server/src/lib/context-compressor', () => ({
|
||||||
|
SUMMARY_PREFIX: '[Previous context summary]',
|
||||||
|
countTokens: vi.fn(() => 0),
|
||||||
|
}))
|
||||||
|
|
||||||
|
vi.mock('../../packages/server/src/services/logger', () => ({
|
||||||
|
logger: { info: vi.fn(), warn: vi.fn(), error: vi.fn(), debug: vi.fn() },
|
||||||
|
}))
|
||||||
|
|
||||||
|
vi.mock('../../packages/server/src/services/hermes/run-chat/compression', () => ({
|
||||||
|
buildCompressedHistory: vi.fn(),
|
||||||
|
buildDbHistory: buildDbHistoryMock,
|
||||||
|
buildSnapshotAwareHistory: buildSnapshotAwareHistoryMock,
|
||||||
|
getOrCreateSession: vi.fn(),
|
||||||
|
}))
|
||||||
|
|
||||||
|
vi.mock('../../packages/server/src/services/hermes/run-chat/usage', () => ({
|
||||||
|
calcAndUpdateUsage: vi.fn(),
|
||||||
|
estimateUsageTokensFromMessages: estimateUsageTokensFromMessagesMock,
|
||||||
|
}))
|
||||||
|
|
||||||
|
vi.mock('../../packages/server/src/services/hermes/run-chat/message-format', () => ({
|
||||||
|
convertHistoryFormat: vi.fn((messages: any[]) => messages),
|
||||||
|
handleMessage: vi.fn((messages: any[]) => messages),
|
||||||
|
}))
|
||||||
|
|
||||||
|
vi.mock('../../packages/server/src/services/hermes/run-chat/content-blocks', () => ({
|
||||||
|
contentBlocksToString: vi.fn((value: any) => String(value || '')),
|
||||||
|
extractTextForPreview: vi.fn((value: any) => String(value || '')),
|
||||||
|
isContentBlockArray: vi.fn(() => false),
|
||||||
|
convertContentBlocks: vi.fn(),
|
||||||
|
}))
|
||||||
|
|
||||||
|
vi.mock('../../packages/server/src/lib/llm-prompt', () => ({
|
||||||
|
getSystemPrompt: vi.fn(() => 'system prompt'),
|
||||||
|
}))
|
||||||
|
|
||||||
|
vi.mock('../../packages/server/src/services/hermes/run-chat/sse-utils', () => ({
|
||||||
|
readSseFrames: vi.fn(),
|
||||||
|
}))
|
||||||
|
|
||||||
|
vi.mock('../../packages/server/src/services/hermes/run-chat/response-utils', () => ({
|
||||||
|
extractResponseText: vi.fn(),
|
||||||
|
}))
|
||||||
|
|
||||||
|
vi.mock('../../packages/server/src/services/hermes/run-chat/response-stream', () => ({
|
||||||
|
applyResponseStreamEvent: vi.fn(),
|
||||||
|
flushResponseRunToDb: vi.fn(),
|
||||||
|
}))
|
||||||
|
|
||||||
|
describe('loadSessionStateFromDb', () => {
|
||||||
|
beforeEach(() => {
|
||||||
|
vi.clearAllMocks()
|
||||||
|
getSessionMock.mockReturnValue({
|
||||||
|
id: 'session-1',
|
||||||
|
profile: 'default',
|
||||||
|
model: 'gpt-test',
|
||||||
|
provider: 'openai',
|
||||||
|
})
|
||||||
|
getSessionDetailPaginatedMock.mockReturnValue({
|
||||||
|
messages: [
|
||||||
|
{ role: 'user', content: 'old large context' },
|
||||||
|
{ role: 'assistant', content: 'old large answer' },
|
||||||
|
{ role: 'user', content: 'new tail' },
|
||||||
|
],
|
||||||
|
})
|
||||||
|
getCompressionSnapshotMock.mockReturnValue({
|
||||||
|
summary: 'small summary',
|
||||||
|
lastMessageIndex: 0,
|
||||||
|
messageCountAtTime: 1,
|
||||||
|
})
|
||||||
|
buildDbHistoryMock.mockResolvedValue([
|
||||||
|
{ role: 'user', content: 'old large context' },
|
||||||
|
{ role: 'assistant', content: 'old large answer' },
|
||||||
|
{ role: 'user', content: 'new tail' },
|
||||||
|
])
|
||||||
|
buildSnapshotAwareHistoryMock.mockResolvedValue([
|
||||||
|
{ role: 'user', content: '[Previous context summary]\n\nsmall summary' },
|
||||||
|
{ role: 'user', content: 'new tail' },
|
||||||
|
])
|
||||||
|
estimateUsageTokensFromMessagesMock.mockImplementation((messages: any[]) => {
|
||||||
|
if (messages?.[0]?.content?.includes('small summary')) {
|
||||||
|
return { inputTokens: 9_000, outputTokens: 0 }
|
||||||
|
}
|
||||||
|
return { inputTokens: 28_000, outputTokens: 0 }
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
it('hydrates contextTokens from the same snapshot-aware history used for bridge runs', async () => {
|
||||||
|
const { loadSessionStateFromDb } = await import('../../packages/server/src/services/hermes/run-chat/handle-api-run')
|
||||||
|
|
||||||
|
const state = await loadSessionStateFromDb('session-1', new Map())
|
||||||
|
|
||||||
|
expect(buildSnapshotAwareHistoryMock).toHaveBeenCalledWith(
|
||||||
|
'session-1',
|
||||||
|
'default',
|
||||||
|
expect.any(Array),
|
||||||
|
{ model: 'gpt-test', provider: 'openai' },
|
||||||
|
)
|
||||||
|
expect(state.inputTokens).toBe(28_000)
|
||||||
|
expect(state.outputTokens).toBe(0)
|
||||||
|
expect(state.contextTokens).toBe(9_000)
|
||||||
|
})
|
||||||
|
})
|
||||||
Reference in New Issue
Block a user