fix context token resume (#1039)

This commit is contained in:
ekko
2026-05-26 16:32:07 +08:00
committed by GitHub
parent e686f0277a
commit ad1cab277a
13 changed files with 959 additions and 203 deletions
@@ -292,9 +292,11 @@ except RuntimeError as exc:
assert "already running" in str(exc)
class FakeWorker:
def __init__(self, destroyed):
def __init__(self, destroyed, profile="default", key="default"):
self.running = True
self.destroyed = destroyed
self.profile = profile
self.key = key
self.requests = []
self.stopped = False
@@ -310,28 +312,41 @@ broker = bridge.BridgeBroker("ipc:///tmp/unused.sock")
profile_worker = FakeWorker(2)
broker._workers["default"] = profile_worker
broker._run_profile["run-session-a"] = "default"
broker._run_worker_key["run-session-a"] = "default"
broker._running_run_profile["run-session-a"] = "default"
broker._running_run_worker_key["run-session-a"] = "default"
broker._session_profile["session-a"] = "default"
broker._session_worker_key["session-a"] = "default"
broker._approval_profile["approval-a"] = "default"
broker._approval_worker_key["approval-a"] = "default"
broker._compression_profile["compression-a"] = "default"
broker._compression_worker_key["compression-a"] = "default"
destroy_profile_result = broker.handle({"action": "destroy_profile", "profile": "default"})
assert destroy_profile_result == {"profile": "default", "destroyed": 2}
assert profile_worker.stopped
assert "default" not in broker._workers
assert broker._run_profile == {}
assert broker._run_worker_key == {}
assert broker._running_run_profile == {}
assert broker._running_run_worker_key == {}
assert broker._session_profile == {}
assert broker._session_worker_key == {}
assert broker._approval_profile == {}
assert broker._approval_worker_key == {}
assert broker._compression_profile == {}
assert broker._compression_worker_key == {}
worker_a = FakeWorker(1)
worker_b = FakeWorker(3)
worker_a = FakeWorker(1, "default", "a")
worker_b = FakeWorker(3, "work", "b")
broker._workers["a"] = worker_a
broker._workers["b"] = worker_b
broker._run_profile["run-a"] = "a"
broker._running_run_profile["run-a"] = "a"
broker._session_profile["session-b"] = "b"
broker._run_profile["run-a"] = "default"
broker._run_worker_key["run-a"] = "a"
broker._running_run_profile["run-a"] = "default"
broker._running_run_worker_key["run-a"] = "a"
broker._session_profile["session-b"] = "work"
broker._session_worker_key["session-b"] = "b"
destroy_all_result = broker.handle({"action": "destroy_all"})
assert destroy_all_result == {"destroyed": 4}
@@ -339,8 +354,11 @@ assert worker_a.stopped
assert worker_b.stopped
assert broker._workers == {}
assert broker._run_profile == {}
assert broker._run_worker_key == {}
assert broker._running_run_profile == {}
assert broker._running_run_worker_key == {}
assert broker._session_profile == {}
assert broker._session_worker_key == {}
`)
})
@@ -372,6 +390,69 @@ assert resp["running_sessions_by_profile"] == {"default": 1}
`)
})
it('routes worker-keyed broker requests without stopping the worker on session destroy', () => {
runPython(String.raw`
${harness}
class RoutedWorker:
running = True
pid = 12345
endpoint = "ipc:///tmp/worker.sock"
last_used_at = 12.5
def __init__(self, profile, key):
self.profile = profile
self.key = key
self.requests = []
self.stopped = False
def request(self, req):
self.requests.append(req)
action = req.get("action")
if action == "chat":
return {"ok": True, "run_id": "run-compress", "session_id": req["session_id"], "status": "running"}
if action == "get_output":
return {"ok": True, "run_id": req["run_id"], "session_id": "compress-temp", "status": "complete", "done": True}
if action == "destroy":
return {"ok": True, "session_id": req["session_id"], "destroyed": True}
raise AssertionError(f"unexpected action: {action}")
def stop(self):
self.stopped = True
broker = bridge.BridgeBroker("ipc:///tmp/unused.sock")
worker = RoutedWorker("default", "default:compression:session-a")
broker._workers[worker.key] = worker
chat_resp = broker.handle({
"action": "chat",
"session_id": "compress-temp",
"profile": "default",
"worker_key": worker.key,
"message": "summarize",
})
assert chat_resp["run_id"] == "run-compress"
assert worker.requests[-1]["profile"] == "default"
assert "worker_key" not in worker.requests[-1]
broker.handle({"action": "get_output", "run_id": "run-compress"})
assert worker.requests[-1]["action"] == "get_output"
destroy_resp = broker.handle({
"action": "destroy",
"session_id": "compress-temp",
"profile": "default",
"worker_key": worker.key,
})
assert destroy_resp["destroyed"] is True
assert worker.requests[-1]["action"] == "destroy"
assert not worker.stopped
assert worker.key in broker._workers
assert "compress-temp" not in broker._session_profile
assert "compress-temp" not in broker._session_worker_key
`)
})
it('restores approval env and clears handlers when a run fails', () => {
runPython(String.raw`
${harness}
@@ -480,7 +561,7 @@ original_getpid = bridge.os.getpid
try:
bridge.subprocess.Popen = fake_popen
bridge.os.getpid = lambda: 4242
proc_worker = bridge.WorkerProcess("default", "ipc:///tmp/worker.sock", "/agent", "/home")
proc_worker = bridge.WorkerProcess("default:compression:session-a", "default", "ipc:///tmp/worker.sock", "/agent", "/home")
proc_worker._pipe_stderr = lambda: None
proc_worker._wait_ready = lambda: None
proc_worker.start()
+36
View File
@@ -153,6 +153,42 @@ describe('ChatContextCompressor', () => {
expect(saveCompressionSnapshotMock).toHaveBeenCalledWith('s1', 'compressed summary', 6, 10)
})
it('routes summarization through the provided worker key and destroys only the temporary agent session', async () => {
const { ChatContextCompressor } = await import('../../packages/server/src/lib/context-compressor')
const compressor = new ChatContextCompressor({
config: { headMessageCount: 0, tailMessageCount: 1, summaryBudget: 1000 },
})
const messages = [
{ role: 'user', content: 'old context' },
{ role: 'assistant', content: 'old response' },
{ role: 'user', content: 'tail' },
]
getCompressionSnapshotMock.mockReturnValue(null)
bridgeRequestMock.mockResolvedValue({
status: 'completed',
result: { final_response: 'compressed summary' },
})
await compressor.compress(messages, 'http://upstream', undefined, 's1', {
profile: 'default',
workerKey: 'default:compression:s1',
})
expect(bridgeRequestMock).toHaveBeenCalledWith(expect.objectContaining({
action: 'chat',
profile: 'default',
worker_key: 'default:compression:s1',
wait: true,
}), expect.any(Object))
const compressSessionId = bridgeRequestMock.mock.calls[0][0].session_id
expect(String(compressSessionId)).toMatch(/^compress_/)
expect(bridgeDestroyMock).toHaveBeenCalledWith(
compressSessionId,
'default',
'default:compression:s1',
)
})
it('does not pre-prune tool results before sending them to the summarizer', async () => {
const { ChatContextCompressor } = await import('../../packages/server/src/lib/context-compressor')
const compressor = new ChatContextCompressor({
@@ -127,8 +127,18 @@ describe('bridge run final context usage', () => {
buildSnapshotAwareHistoryMock.mockImplementation(async (_sessionId: string, _profile: string, history: any[]) => history)
calcAndUpdateUsageMock.mockResolvedValue({ inputTokens: 11, outputTokens: 7 })
estimateUsageTokensFromMessagesMock.mockReturnValue({ inputTokens: 11, outputTokens: 7 })
getCachedBridgeContextOverheadMock.mockReturnValue(undefined)
contextTokensWithCachedOverheadMock.mockImplementation((_state: any, messageTokens: number) => messageTokens)
getCachedBridgeContextOverheadMock.mockImplementation((state: any) => {
const fixed = state?.bridgeContext?.fixedContextTokens
return typeof fixed === 'number' ? fixed : undefined
})
contextTokensWithCachedOverheadMock.mockImplementation((state: any, messageTokens: number) => {
const fixed = state?.bridgeContext?.fixedContextTokens
return typeof fixed === 'number' ? fixed + messageTokens : messageTokens
})
updateMessageContextTokenUsageMock.mockImplementation((sid: string, state: any, emit: any, messageTokens: number, usage?: { inputTokens: number; outputTokens: number }) => {
const contextTokens = contextTokensWithCachedOverheadMock(state, messageTokens)
return updateContextTokenUsageMock(sid, state, emit, contextTokens, usage)
})
})
it('refreshes full context tokens when a bridge run completes', async () => {
@@ -141,6 +151,7 @@ describe('bridge run final context usage', () => {
chat: vi.fn().mockResolvedValue({ run_id: 'run-1', status: 'started' }),
contextEstimate: vi.fn().mockResolvedValue({
token_count: 12345,
fixed_context_tokens: 12327,
message_count: 2,
tool_count: 4,
system_prompt_chars: 13,
@@ -165,10 +176,7 @@ describe('bridge run final context usage', () => {
expect(bridge.contextEstimate).toHaveBeenCalledWith(
'session-1',
[
{ role: 'user', content: 'hello' },
{ role: 'assistant', content: 'done' },
],
[],
expect.stringContaining('[Current Hermes profile: default]'),
'default',
{ model: 'gpt-test', provider: 'openai' },
@@ -326,14 +334,22 @@ describe('bridge run final context usage', () => {
const nsp = makeNamespace(emit)
const socket = makeSocket()
const state = makeState()
state.bridgeContext = { fixedContextTokens: 20_000 }
const sessionMap = new Map([['session-1', state]])
getCachedBridgeContextOverheadMock.mockReturnValue(20_000)
updateMessageContextTokenUsageMock.mockImplementation((sid: string, targetState: any, targetEmit: any, messageTokens: number, usage?: { inputTokens: number; outputTokens: number }) => updateContextTokenUsageMock(sid, targetState, targetEmit, 20_000 + messageTokens, usage))
const bridge = {
chat: vi.fn().mockResolvedValue({ run_id: 'run-1', status: 'started' }),
contextEstimate: vi.fn(),
streamOutput: vi.fn(async function* () {
yield {
run_id: 'run-1',
done: false,
status: 'running',
events: [{
event: 'bridge.context.ready',
fixed_context_tokens: 20_000,
system_prompt_tokens: 3_000,
tool_tokens: 17_000,
}],
}
yield { run_id: 'run-1', done: true, status: 'completed', output: 'done' }
}),
} as any
@@ -365,6 +381,80 @@ describe('bridge run final context usage', () => {
}))
})
it('keeps bridge context ready updates on the snapshot-aware token baseline', async () => {
const emit = vi.fn()
const nsp = makeNamespace(emit)
const socket = makeSocket()
const state = makeState()
const sessionMap = new Map([['session-1', state]])
calcAndUpdateUsageMock.mockResolvedValue({ inputTokens: 28_000, outputTokens: 0 })
buildDbHistoryMock.mockResolvedValue([
{ role: 'user', content: 'very large old context' },
{ role: 'assistant', content: 'large old response' },
{ role: 'user', content: 'hello' },
])
buildSnapshotAwareHistoryMock.mockResolvedValue([
{ role: 'user', content: '[Previous context summary]\n\nsmall summary' },
{ role: 'user', content: 'hello' },
])
estimateUsageTokensFromMessagesMock.mockImplementation((messages: any[]) => {
if (messages?.[0]?.content?.includes('small summary')) {
return { inputTokens: 9_000, outputTokens: 0 }
}
return { inputTokens: 28_000, outputTokens: 0 }
})
const bridge = {
chat: vi.fn().mockResolvedValue({ run_id: 'run-1', status: 'started' }),
contextEstimate: vi.fn(),
streamOutput: vi.fn(async function* () {
yield {
run_id: 'run-1',
done: false,
status: 'running',
events: [{
event: 'bridge.context.ready',
fixed_context_tokens: 10_000,
system_prompt_tokens: 2_000,
tool_tokens: 8_000,
}],
}
yield { run_id: 'run-1', done: true, status: 'completed', output: 'done' }
}),
} as any
const { handleBridgeRun } = await import('../../packages/server/src/services/hermes/run-chat/handle-bridge-run')
await handleBridgeRun(
nsp,
socket,
{ input: 'hello', session_id: 'session-1' },
'default',
sessionMap,
bridge,
false,
vi.fn(),
vi.fn(),
)
expect(updateMessageContextTokenUsageMock).toHaveBeenCalledWith(
'session-1',
state,
expect.any(Function),
9_000,
{ inputTokens: 28_000, outputTokens: 0 },
)
expect(updateMessageContextTokenUsageMock).not.toHaveBeenCalledWith(
'session-1',
state,
expect.any(Function),
28_000,
{ inputTokens: 28_000, outputTokens: 0 },
)
expect(state.contextTokens).toBe(19_000)
expect(emit).toHaveBeenCalledWith('run.completed', expect.objectContaining({
contextTokens: 19_000,
}))
})
it('persists pending tool marker text before a bridge run completes', async () => {
const emit = vi.fn()
const nsp = makeNamespace(emit)
@@ -502,6 +592,7 @@ describe('bridge run final context usage', () => {
chat: vi.fn().mockRejectedValue(new Error('bridge timeout')),
contextEstimate: vi.fn().mockResolvedValue({
token_count: 54321,
fixed_context_tokens: 54303,
message_count: 1,
tool_count: 4,
system_prompt_chars: 13,
+110 -5
View File
@@ -175,7 +175,7 @@ describe('run chat compression trigger', () => {
)
})
it('uses full context estimates for compression threshold decisions', async () => {
it('uses local context estimates for compression threshold decisions', async () => {
const messages = Array.from({ length: 10 }, (_, index) => ({
id: index + 1,
session_id: 'session-1',
@@ -191,7 +191,7 @@ describe('run chat compression trigger', () => {
getSessionDetailMock.mockReturnValue({ messages })
calcAndUpdateUsageMock.mockResolvedValue({ inputTokens: 1_000, outputTokens: 0 })
compressorCompressMock.mockResolvedValue({
messages: [{ role: 'user', content: 'compressed by full context estimate' }],
messages: [{ role: 'user', content: 'compressed by local context estimate' }],
meta: {
compressed: true,
llmCompressed: true,
@@ -215,7 +215,7 @@ describe('run chat compression trigger', () => {
vi.fn(async () => 120_000),
)
expect(history).toEqual([{ role: 'user', content: 'compressed by full context estimate' }])
expect(history).toEqual([{ role: 'user', content: 'compressed by local context estimate' }])
expect(compressorCompressMock).toHaveBeenCalledTimes(1)
expect(updateMessageContextTokenUsageMock).toHaveBeenCalledWith(
'session-1',
@@ -226,7 +226,7 @@ describe('run chat compression trigger', () => {
)
})
it('emits full context token usage when the full estimate is under threshold', async () => {
it('emits local context token usage when the local estimate is under threshold', async () => {
const messages = Array.from({ length: 10 }, (_, index) => ({
id: index + 1,
session_id: 'session-1',
@@ -257,7 +257,10 @@ describe('run chat compression trigger', () => {
)
expect(history).toHaveLength(9)
expect(contextTokenEstimator).toHaveBeenCalledWith(expect.arrayContaining([{ role: 'user', content: 'message 0' }]))
expect(contextTokenEstimator).toHaveBeenCalledWith(
expect.arrayContaining([{ role: 'user', content: 'message 0' }]),
1_900,
)
expect(emit).toHaveBeenCalledWith('usage.updated', expect.objectContaining({
event: 'usage.updated',
session_id: 'session-1',
@@ -268,6 +271,108 @@ describe('run chat compression trigger', () => {
expect(compressorCompressMock).not.toHaveBeenCalled()
})
it('includes current input tokens when estimating snapshot-aware context', async () => {
const messages = Array.from({ length: 10 }, (_, index) => ({
id: index + 1,
session_id: 'session-1',
role: index === 9 ? 'user' : index % 2 === 0 ? 'user' : 'assistant',
content: `message ${index}`,
timestamp: index + 1,
tool_call_id: null,
tool_calls: null,
tool_name: null,
finish_reason: null,
reasoning_content: null,
}))
getSessionDetailMock.mockReturnValue({ messages })
getCompressionSnapshotMock.mockReturnValue({
summary: 'previous summary',
lastMessageIndex: 4,
messageCountAtTime: 5,
})
calcAndUpdateUsageMock.mockResolvedValue({ inputTokens: 10, outputTokens: 0 })
estimateUsageTokensFromMessagesMock.mockReturnValue({ inputTokens: 1_000, outputTokens: 0 })
const emit = vi.fn()
const contextTokenEstimator = vi.fn(async (_messages, messageTokens: number) => 20_000 + messageTokens)
const { buildCompressedHistory } = await import('../../packages/server/src/services/hermes/run-chat/compression')
await buildCompressedHistory(
'session-1',
'default',
'http://upstream',
undefined,
emit,
new Map(),
{},
contextTokenEstimator,
700,
)
expect(contextTokenEstimator).toHaveBeenCalledWith(expect.any(Array), 1_700)
expect(emit).toHaveBeenCalledWith('usage.updated', expect.objectContaining({
contextTokens: 21_700,
}))
expect(compressorCompressMock).not.toHaveBeenCalled()
})
it('keeps current input tokens in the compression completed context total', async () => {
const messages = Array.from({ length: 10 }, (_, index) => ({
id: index + 1,
session_id: 'session-1',
role: index === 9 ? 'user' : index % 2 === 0 ? 'user' : 'assistant',
content: `message ${index}`,
timestamp: index + 1,
tool_call_id: null,
tool_calls: null,
tool_name: null,
finish_reason: null,
reasoning_content: null,
}))
getSessionDetailMock.mockReturnValue({ messages })
calcAndUpdateUsageMock.mockResolvedValue({ inputTokens: 100, outputTokens: 0 })
estimateUsageTokensFromMessagesMock.mockImplementation((items: any[]) => {
if (items?.[0]?.content === 'compressed result') return { inputTokens: 1_000, outputTokens: 0 }
return { inputTokens: 100, outputTokens: 0 }
})
compressorCompressMock.mockResolvedValue({
messages: [{ role: 'user', content: 'compressed result' }],
meta: {
compressed: true,
llmCompressed: true,
totalMessages: 9,
summaryTokenEstimate: 1,
verbatimCount: 0,
compressedStartIndex: 0,
},
})
const emit = vi.fn()
const { buildCompressedHistory } = await import('../../packages/server/src/services/hermes/run-chat/compression')
await buildCompressedHistory(
'session-1',
'default',
'http://upstream',
undefined,
emit,
new Map(),
{},
vi.fn(async () => 120_000),
700,
)
expect(updateMessageContextTokenUsageMock).toHaveBeenCalledWith(
'session-1',
expect.any(Object),
emit,
1_700,
{ inputTokens: 100, outputTokens: 0 },
)
expect(emit).toHaveBeenCalledWith('compression.completed', expect.objectContaining({
afterTokens: 1_700,
contextTokens: 1_700,
}))
})
it('throws when fixed prompt and tool schemas exceed threshold before any history exists', async () => {
getSessionDetailMock.mockReturnValue({ messages: [] })
const emit = vi.fn()
+129
View File
@@ -0,0 +1,129 @@
import { beforeEach, describe, expect, it, vi } from 'vitest'
const getSessionMock = vi.fn()
const getSessionDetailPaginatedMock = vi.fn()
const getCompressionSnapshotMock = vi.fn()
const estimateUsageTokensFromMessagesMock = vi.fn()
const buildDbHistoryMock = vi.fn()
const buildSnapshotAwareHistoryMock = vi.fn()
vi.mock('../../packages/server/src/db/hermes/session-store', () => ({
getSession: getSessionMock,
createSession: vi.fn(),
addMessage: vi.fn(),
updateSessionStats: vi.fn(),
getSessionDetailPaginated: getSessionDetailPaginatedMock,
}))
vi.mock('../../packages/server/src/db/hermes/usage-store', () => ({
updateUsage: vi.fn(),
}))
vi.mock('../../packages/server/src/db/hermes/compression-snapshot', () => ({
getCompressionSnapshot: getCompressionSnapshotMock,
}))
vi.mock('../../packages/server/src/lib/context-compressor', () => ({
SUMMARY_PREFIX: '[Previous context summary]',
countTokens: vi.fn(() => 0),
}))
vi.mock('../../packages/server/src/services/logger', () => ({
logger: { info: vi.fn(), warn: vi.fn(), error: vi.fn(), debug: vi.fn() },
}))
vi.mock('../../packages/server/src/services/hermes/run-chat/compression', () => ({
buildCompressedHistory: vi.fn(),
buildDbHistory: buildDbHistoryMock,
buildSnapshotAwareHistory: buildSnapshotAwareHistoryMock,
getOrCreateSession: vi.fn(),
}))
vi.mock('../../packages/server/src/services/hermes/run-chat/usage', () => ({
calcAndUpdateUsage: vi.fn(),
estimateUsageTokensFromMessages: estimateUsageTokensFromMessagesMock,
}))
vi.mock('../../packages/server/src/services/hermes/run-chat/message-format', () => ({
convertHistoryFormat: vi.fn((messages: any[]) => messages),
handleMessage: vi.fn((messages: any[]) => messages),
}))
vi.mock('../../packages/server/src/services/hermes/run-chat/content-blocks', () => ({
contentBlocksToString: vi.fn((value: any) => String(value || '')),
extractTextForPreview: vi.fn((value: any) => String(value || '')),
isContentBlockArray: vi.fn(() => false),
convertContentBlocks: vi.fn(),
}))
vi.mock('../../packages/server/src/lib/llm-prompt', () => ({
getSystemPrompt: vi.fn(() => 'system prompt'),
}))
vi.mock('../../packages/server/src/services/hermes/run-chat/sse-utils', () => ({
readSseFrames: vi.fn(),
}))
vi.mock('../../packages/server/src/services/hermes/run-chat/response-utils', () => ({
extractResponseText: vi.fn(),
}))
vi.mock('../../packages/server/src/services/hermes/run-chat/response-stream', () => ({
applyResponseStreamEvent: vi.fn(),
flushResponseRunToDb: vi.fn(),
}))
describe('loadSessionStateFromDb', () => {
beforeEach(() => {
vi.clearAllMocks()
getSessionMock.mockReturnValue({
id: 'session-1',
profile: 'default',
model: 'gpt-test',
provider: 'openai',
})
getSessionDetailPaginatedMock.mockReturnValue({
messages: [
{ role: 'user', content: 'old large context' },
{ role: 'assistant', content: 'old large answer' },
{ role: 'user', content: 'new tail' },
],
})
getCompressionSnapshotMock.mockReturnValue({
summary: 'small summary',
lastMessageIndex: 0,
messageCountAtTime: 1,
})
buildDbHistoryMock.mockResolvedValue([
{ role: 'user', content: 'old large context' },
{ role: 'assistant', content: 'old large answer' },
{ role: 'user', content: 'new tail' },
])
buildSnapshotAwareHistoryMock.mockResolvedValue([
{ role: 'user', content: '[Previous context summary]\n\nsmall summary' },
{ role: 'user', content: 'new tail' },
])
estimateUsageTokensFromMessagesMock.mockImplementation((messages: any[]) => {
if (messages?.[0]?.content?.includes('small summary')) {
return { inputTokens: 9_000, outputTokens: 0 }
}
return { inputTokens: 28_000, outputTokens: 0 }
})
})
it('hydrates contextTokens from the same snapshot-aware history used for bridge runs', async () => {
const { loadSessionStateFromDb } = await import('../../packages/server/src/services/hermes/run-chat/handle-api-run')
const state = await loadSessionStateFromDb('session-1', new Map())
expect(buildSnapshotAwareHistoryMock).toHaveBeenCalledWith(
'session-1',
'default',
expect.any(Array),
{ model: 'gpt-test', provider: 'openai' },
)
expect(state.inputTokens).toBe(28_000)
expect(state.outputTokens).toBe(0)
expect(state.contextTokens).toBe(9_000)
})
})