fix context token resume (#1039)
This commit is contained in:
@@ -79,6 +79,7 @@ export interface SummarizerOptions {
|
||||
profile?: string
|
||||
model?: string | null
|
||||
provider?: string | null
|
||||
workerKey?: string
|
||||
}
|
||||
|
||||
// ─── Token counting ─────────────────────────────────────
|
||||
@@ -454,6 +455,7 @@ export async function callSummarizer(
|
||||
|
||||
const bridge = new AgentBridgeClient({ timeoutMs: timeoutMs + 15_000 })
|
||||
const sessionId = `compress_${Date.now().toString(36)}_${randomUUID().replace(/-/g, '').slice(0, 12)}`
|
||||
const workerKey = options.workerKey || `${profile}:compression:${sessionId}`
|
||||
|
||||
try {
|
||||
const result = await bridge.request<AgentBridgeRunResult>({
|
||||
@@ -462,6 +464,7 @@ export async function callSummarizer(
|
||||
message: prompt,
|
||||
conversation_history: convHistory,
|
||||
profile,
|
||||
worker_key: workerKey,
|
||||
source: 'api_server',
|
||||
wait: true,
|
||||
timeout: Math.ceil(timeoutMs / 1000),
|
||||
@@ -482,7 +485,7 @@ export async function callSummarizer(
|
||||
if (!output) throw new Error('Empty summarization response')
|
||||
return output
|
||||
} finally {
|
||||
await bridge.destroy(sessionId, profile).catch(() => undefined)
|
||||
await bridge.destroy(sessionId, profile, workerKey).catch(() => undefined)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -166,7 +166,7 @@ export class AgentBridgeClient {
|
||||
private summarizePayload(payload: Record<string, unknown>): Record<string, unknown> {
|
||||
const action = String(payload.action || '')
|
||||
const summary: Record<string, unknown> = { action }
|
||||
for (const key of ['session_id', 'run_id', 'request_id', 'approval_id', 'profile']) {
|
||||
for (const key of ['session_id', 'run_id', 'request_id', 'approval_id', 'profile', 'worker_key']) {
|
||||
if (payload[key] != null) summary[key] = payload[key]
|
||||
}
|
||||
if (Array.isArray(payload.conversation_history)) summary.conversation_history_count = payload.conversation_history.length
|
||||
@@ -569,11 +569,12 @@ export class AgentBridgeClient {
|
||||
})
|
||||
}
|
||||
|
||||
destroy(sessionId: string, profile?: string): Promise<AgentBridgeResponse> {
|
||||
destroy(sessionId: string, profile?: string, workerKey?: string): Promise<AgentBridgeResponse> {
|
||||
return this.request({
|
||||
action: 'destroy',
|
||||
session_id: sessionId,
|
||||
...(profile ? { profile } : {}),
|
||||
...(workerKey ? { worker_key: workerKey } : {}),
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -2200,7 +2200,8 @@ class WorkerProcess:
|
||||
STARTUP_TIMEOUT_SECONDS = 120
|
||||
REQUEST_TIMEOUT_SECONDS = 120
|
||||
|
||||
def __init__(self, profile: str, endpoint: str, agent_root: str | None, hermes_home: str | None) -> None:
|
||||
def __init__(self, key: str, profile: str, endpoint: str, agent_root: str | None, hermes_home: str | None) -> None:
|
||||
self.key = key or profile or "default"
|
||||
self.profile = profile or "default"
|
||||
self.endpoint = endpoint
|
||||
self.agent_root = agent_root
|
||||
@@ -2263,14 +2264,14 @@ class WorkerProcess:
|
||||
for line in proc.stderr:
|
||||
text = line.rstrip()
|
||||
if text:
|
||||
print(f"[hermes-bridge-worker:{self.profile}] {text}", file=sys.stderr, flush=True)
|
||||
print(f"[hermes-bridge-worker:{self.key}] {text}", file=sys.stderr, flush=True)
|
||||
|
||||
threading.Thread(target=run, daemon=True, name=f"hermes-bridge-worker-stderr-{self.profile}").start()
|
||||
threading.Thread(target=run, daemon=True, name=f"hermes-bridge-worker-stderr-{self.key}").start()
|
||||
|
||||
def _wait_ready(self) -> None:
|
||||
proc = self.process
|
||||
if proc is None or proc.stdout is None:
|
||||
raise RuntimeError(f"profile worker {self.profile} did not start")
|
||||
raise RuntimeError(f"profile worker {self.key} did not start")
|
||||
lines: queue.Queue[str | None] = queue.Queue()
|
||||
ready_event = threading.Event()
|
||||
|
||||
@@ -2281,17 +2282,17 @@ class WorkerProcess:
|
||||
if ready_event.is_set():
|
||||
text = line.rstrip()
|
||||
if text:
|
||||
print(f"[hermes-bridge-worker:{self.profile}] {text}", file=sys.stderr, flush=True)
|
||||
print(f"[hermes-bridge-worker:{self.key}] {text}", file=sys.stderr, flush=True)
|
||||
else:
|
||||
lines.put(line)
|
||||
finally:
|
||||
lines.put(None)
|
||||
|
||||
threading.Thread(target=read_stdout, daemon=True, name=f"hermes-bridge-worker-stdout-{self.profile}").start()
|
||||
threading.Thread(target=read_stdout, daemon=True, name=f"hermes-bridge-worker-stdout-{self.key}").start()
|
||||
deadline = time.time() + self.STARTUP_TIMEOUT_SECONDS
|
||||
while time.time() < deadline:
|
||||
if proc.poll() is not None:
|
||||
raise RuntimeError(f"profile worker {self.profile} exited before ready")
|
||||
raise RuntimeError(f"profile worker {self.key} exited before ready")
|
||||
try:
|
||||
line = lines.get(timeout=0.1)
|
||||
except queue.Empty:
|
||||
@@ -2301,7 +2302,7 @@ class WorkerProcess:
|
||||
continue
|
||||
text = line.strip()
|
||||
if text:
|
||||
print(f"[hermes-bridge-worker:{self.profile}] {text}", file=sys.stderr, flush=True)
|
||||
print(f"[hermes-bridge-worker:{self.key}] {text}", file=sys.stderr, flush=True)
|
||||
try:
|
||||
data = json.loads(text)
|
||||
if data.get("event") == "ready":
|
||||
@@ -2310,7 +2311,7 @@ class WorkerProcess:
|
||||
except Exception:
|
||||
pass
|
||||
self.stop()
|
||||
raise RuntimeError(f"profile worker {self.profile} did not become ready within {self.STARTUP_TIMEOUT_SECONDS}s")
|
||||
raise RuntimeError(f"profile worker {self.key} did not become ready within {self.STARTUP_TIMEOUT_SECONDS}s")
|
||||
|
||||
def stop(self) -> None:
|
||||
with self._lock:
|
||||
@@ -2337,8 +2338,8 @@ class WorkerProcess:
|
||||
return _send_bridge_request(self.endpoint, req, self.REQUEST_TIMEOUT_SECONDS)
|
||||
|
||||
|
||||
def _worker_endpoint(profile: str) -> str:
|
||||
safe = hashlib.sha256(profile.encode("utf-8")).hexdigest()[:16]
|
||||
def _worker_endpoint(key: str) -> str:
|
||||
safe = hashlib.sha256(key.encode("utf-8")).hexdigest()[:16]
|
||||
if os.name == "nt":
|
||||
port_base = int(os.environ.get("HERMES_AGENT_BRIDGE_WORKER_PORT_BASE", "18780"))
|
||||
return f"tcp://127.0.0.1:{port_base + int(safe[:4], 16) % 1000}"
|
||||
@@ -2533,11 +2534,17 @@ class BridgeBroker:
|
||||
self.hermes_home = hermes_home
|
||||
self._workers: dict[str, WorkerProcess] = {}
|
||||
self._run_profile: dict[str, str] = {}
|
||||
self._run_worker_key: dict[str, str] = {}
|
||||
self._running_run_profile: dict[str, str] = {}
|
||||
self._running_run_worker_key: dict[str, str] = {}
|
||||
self._session_profile: dict[str, str] = {}
|
||||
self._session_worker_key: dict[str, str] = {}
|
||||
self._approval_profile: dict[str, str] = {}
|
||||
self._approval_worker_key: dict[str, str] = {}
|
||||
self._clarify_profile: dict[str, str] = {}
|
||||
self._clarify_worker_key: dict[str, str] = {}
|
||||
self._compression_profile: dict[str, str] = {}
|
||||
self._compression_worker_key: dict[str, str] = {}
|
||||
self._lock = threading.RLock()
|
||||
self._stop = threading.Event()
|
||||
self._last_gc = time.time()
|
||||
@@ -2546,58 +2553,73 @@ class BridgeBroker:
|
||||
profile = str(value or "").strip()
|
||||
return profile or "default"
|
||||
|
||||
def _worker_for_profile(self, profile: str) -> WorkerProcess:
|
||||
def _normalize_worker_key(self, profile: str, value: Any = None) -> str:
|
||||
worker_key = str(value or "").strip()
|
||||
return worker_key or profile
|
||||
|
||||
def _worker_for_profile(self, profile: str, worker_key: str | None = None) -> WorkerProcess:
|
||||
profile = self._normalize_profile(profile)
|
||||
key = self._normalize_worker_key(profile, worker_key)
|
||||
with self._lock:
|
||||
worker = self._workers.get(profile)
|
||||
worker = self._workers.get(key)
|
||||
if worker is None:
|
||||
worker = WorkerProcess(profile, _worker_endpoint(profile), self.agent_root, self.hermes_home)
|
||||
self._workers[profile] = worker
|
||||
worker = WorkerProcess(key, profile, _worker_endpoint(key), self.agent_root, self.hermes_home)
|
||||
self._workers[key] = worker
|
||||
return worker
|
||||
|
||||
def _profile_for_run(self, run_id: str) -> str:
|
||||
def _route_for_run(self, run_id: str) -> tuple[str, str | None]:
|
||||
with self._lock:
|
||||
profile = self._run_profile.get(run_id)
|
||||
worker_key = self._run_worker_key.get(run_id)
|
||||
if not profile:
|
||||
raise KeyError(f"unknown run: {run_id}")
|
||||
return profile
|
||||
return profile, worker_key
|
||||
|
||||
def _profile_for_session(self, session_id: str, fallback_profile: Any = None) -> str:
|
||||
def _route_for_session(self, session_id: str, fallback_profile: Any = None, worker_key: Any = None) -> tuple[str, str | None]:
|
||||
with self._lock:
|
||||
profile = self._session_profile.get(session_id)
|
||||
stored_worker_key = self._session_worker_key.get(session_id)
|
||||
if not profile:
|
||||
fallback = self._normalize_profile(fallback_profile)
|
||||
if fallback_profile is not None and fallback:
|
||||
return fallback
|
||||
return fallback, self._normalize_worker_key(fallback, worker_key)
|
||||
raise KeyError(f"unknown session: {session_id}")
|
||||
return profile
|
||||
return profile, self._normalize_worker_key(profile, worker_key) if worker_key is not None else stored_worker_key
|
||||
|
||||
def _record_response_routes(self, profile: str, resp: dict[str, Any]) -> None:
|
||||
def _record_response_routes(self, profile: str, worker_key: str, resp: dict[str, Any]) -> None:
|
||||
run_id = str(resp.get("run_id") or "")
|
||||
session_id = str(resp.get("session_id") or "")
|
||||
with self._lock:
|
||||
if run_id:
|
||||
self._run_profile[run_id] = profile
|
||||
self._run_worker_key[run_id] = worker_key
|
||||
if resp.get("status") == "running":
|
||||
self._running_run_profile[run_id] = profile
|
||||
self._running_run_worker_key[run_id] = worker_key
|
||||
else:
|
||||
self._running_run_profile.pop(run_id, None)
|
||||
self._running_run_worker_key.pop(run_id, None)
|
||||
if session_id:
|
||||
self._session_profile[session_id] = profile
|
||||
self._session_worker_key[session_id] = worker_key
|
||||
for event in resp.get("events") or []:
|
||||
if not isinstance(event, dict):
|
||||
continue
|
||||
approval_id = str(event.get("approval_id") or "")
|
||||
if approval_id:
|
||||
self._approval_profile[approval_id] = profile
|
||||
self._approval_worker_key[approval_id] = worker_key
|
||||
clarify_id = str(event.get("clarify_id") or "")
|
||||
if clarify_id:
|
||||
self._clarify_profile[clarify_id] = profile
|
||||
self._clarify_worker_key[clarify_id] = worker_key
|
||||
request_id = str(event.get("request_id") or "")
|
||||
if event.get("event") == "bridge.compression.requested" and request_id:
|
||||
self._compression_profile[request_id] = profile
|
||||
self._compression_worker_key[request_id] = worker_key
|
||||
if event.get("event") in {"bridge.compression.completed", "bridge.compression.failed"} and request_id:
|
||||
self._compression_profile.pop(request_id, None)
|
||||
self._compression_worker_key.pop(request_id, None)
|
||||
|
||||
def stop(self) -> None:
|
||||
self._stop.set()
|
||||
@@ -2605,20 +2627,29 @@ class BridgeBroker:
|
||||
workers = list(self._workers.values())
|
||||
self._workers.clear()
|
||||
self._run_profile.clear()
|
||||
self._run_worker_key.clear()
|
||||
self._running_run_profile.clear()
|
||||
self._running_run_worker_key.clear()
|
||||
self._session_profile.clear()
|
||||
self._session_worker_key.clear()
|
||||
self._approval_profile.clear()
|
||||
self._approval_worker_key.clear()
|
||||
self._clarify_profile.clear()
|
||||
self._clarify_worker_key.clear()
|
||||
self._compression_profile.clear()
|
||||
self._compression_worker_key.clear()
|
||||
for worker in workers:
|
||||
worker.stop()
|
||||
|
||||
def _forward(self, profile: str, req: dict[str, Any]) -> dict[str, Any]:
|
||||
worker = self._worker_for_profile(profile)
|
||||
def _forward(self, profile: str, req: dict[str, Any], worker_key: str | None = None) -> dict[str, Any]:
|
||||
profile = self._normalize_profile(profile)
|
||||
key = self._normalize_worker_key(profile, worker_key)
|
||||
worker = self._worker_for_profile(profile, key)
|
||||
forwarded = dict(req)
|
||||
forwarded["profile"] = profile
|
||||
forwarded.pop("worker_key", None)
|
||||
resp = worker.request(forwarded)
|
||||
self._record_response_routes(profile, resp)
|
||||
self._record_response_routes(profile, key, resp)
|
||||
return resp
|
||||
|
||||
def handle(self, req: dict[str, Any]) -> dict[str, Any]:
|
||||
@@ -2629,15 +2660,16 @@ class BridgeBroker:
|
||||
if action == "ping":
|
||||
with self._lock:
|
||||
worker_details = {
|
||||
profile: {
|
||||
key: {
|
||||
"running": worker.running,
|
||||
"pid": worker.pid,
|
||||
"endpoint": worker.endpoint,
|
||||
"profile": getattr(worker, "profile", key),
|
||||
"last_used_at": worker.last_used_at,
|
||||
}
|
||||
for profile, worker in self._workers.items()
|
||||
for key, worker in self._workers.items()
|
||||
}
|
||||
workers = {profile: details["running"] for profile, details in worker_details.items()}
|
||||
workers = {key: details["running"] for key, details in worker_details.items()}
|
||||
sessions_by_profile: dict[str, int] = {}
|
||||
for profile in self._session_profile.values():
|
||||
sessions_by_profile[profile] = sessions_by_profile.get(profile, 0) + 1
|
||||
@@ -2664,29 +2696,32 @@ class BridgeBroker:
|
||||
|
||||
if action == "worker_ping":
|
||||
profile = self._normalize_profile(req.get("profile"))
|
||||
resp = self._forward(profile, {"action": "ping"})
|
||||
worker_key = self._normalize_worker_key(profile, req.get("worker_key"))
|
||||
resp = self._forward(profile, {"action": "ping"}, worker_key)
|
||||
resp["worker_profile"] = profile
|
||||
resp["worker_key"] = worker_key
|
||||
return resp
|
||||
|
||||
if action == "chat":
|
||||
profile = self._normalize_profile(req.get("profile"))
|
||||
return self._forward(profile, req)
|
||||
return self._forward(profile, req, self._normalize_worker_key(profile, req.get("worker_key")))
|
||||
|
||||
if action == "context_estimate":
|
||||
profile = self._normalize_profile(req.get("profile"))
|
||||
return self._forward(profile, req)
|
||||
return self._forward(profile, req, self._normalize_worker_key(profile, req.get("worker_key")))
|
||||
|
||||
if action in {"get_result", "get_output"}:
|
||||
profile = self._profile_for_run(str(req.get("run_id") or ""))
|
||||
return self._forward(profile, req)
|
||||
profile, worker_key = self._route_for_run(str(req.get("run_id") or ""))
|
||||
return self._forward(profile, req, worker_key)
|
||||
|
||||
if action in {"interrupt", "steer", "command", "goal_evaluate", "goal_pause", "status", "get_history", "destroy"}:
|
||||
session_id = str(req.get("session_id") or "")
|
||||
profile = self._profile_for_session(session_id, req.get("profile"))
|
||||
resp = self._forward(profile, req)
|
||||
profile, worker_key = self._route_for_session(session_id, req.get("profile"), req.get("worker_key") if "worker_key" in req else None)
|
||||
resp = self._forward(profile, req, worker_key)
|
||||
if action == "destroy":
|
||||
with self._lock:
|
||||
self._session_profile.pop(session_id, None)
|
||||
self._session_worker_key.pop(session_id, None)
|
||||
return resp
|
||||
|
||||
if action == "approval_respond":
|
||||
@@ -2695,9 +2730,10 @@ class BridgeBroker:
|
||||
raise ValueError("approval_id is required")
|
||||
with self._lock:
|
||||
profile = self._approval_profile.get(approval_id)
|
||||
worker_key = self._approval_worker_key.get(approval_id)
|
||||
if not profile:
|
||||
raise KeyError(f"unknown approval request: {approval_id}")
|
||||
return self._forward(profile, req)
|
||||
return self._forward(profile, req, worker_key)
|
||||
|
||||
if action == "clarify_respond":
|
||||
clarify_id = str(req.get("clarify_id") or "").strip()
|
||||
@@ -2705,9 +2741,10 @@ class BridgeBroker:
|
||||
raise ValueError("clarify_id is required")
|
||||
with self._lock:
|
||||
profile = self._clarify_profile.get(clarify_id)
|
||||
worker_key = self._clarify_worker_key.get(clarify_id)
|
||||
if not profile:
|
||||
raise KeyError(f"unknown clarify request: {clarify_id}")
|
||||
return self._forward(profile, req)
|
||||
return self._forward(profile, req, worker_key)
|
||||
|
||||
if action == "compression_respond":
|
||||
request_id = str(req.get("request_id") or "").strip()
|
||||
@@ -2715,20 +2752,27 @@ class BridgeBroker:
|
||||
raise ValueError("request_id is required")
|
||||
with self._lock:
|
||||
profile = self._compression_profile.get(request_id)
|
||||
worker_key = self._compression_worker_key.get(request_id)
|
||||
if not profile:
|
||||
raise KeyError(f"unknown compression request: {request_id}")
|
||||
return self._forward(profile, req)
|
||||
return self._forward(profile, req, worker_key)
|
||||
|
||||
if action == "destroy_all":
|
||||
with self._lock:
|
||||
workers = list(self._workers.values())
|
||||
self._workers.clear()
|
||||
self._run_profile.clear()
|
||||
self._run_worker_key.clear()
|
||||
self._running_run_profile.clear()
|
||||
self._running_run_worker_key.clear()
|
||||
self._session_profile.clear()
|
||||
self._session_worker_key.clear()
|
||||
self._approval_profile.clear()
|
||||
self._approval_worker_key.clear()
|
||||
self._clarify_profile.clear()
|
||||
self._clarify_worker_key.clear()
|
||||
self._compression_profile.clear()
|
||||
self._compression_worker_key.clear()
|
||||
destroyed = 0
|
||||
for worker in workers:
|
||||
try:
|
||||
@@ -2744,40 +2788,56 @@ class BridgeBroker:
|
||||
if action == "destroy_profile":
|
||||
profile = self._normalize_profile(req.get("profile"))
|
||||
with self._lock:
|
||||
worker = self._workers.pop(profile, None)
|
||||
workers = [
|
||||
worker
|
||||
for key, worker in list(self._workers.items())
|
||||
if getattr(worker, "profile", key) == profile
|
||||
]
|
||||
for worker in workers:
|
||||
self._workers.pop(worker.key, None)
|
||||
self._run_profile = {key: value for key, value in self._run_profile.items() if value != profile}
|
||||
self._run_worker_key = {key: value for key, value in self._run_worker_key.items() if key in self._run_profile}
|
||||
self._running_run_profile = {key: value for key, value in self._running_run_profile.items() if value != profile}
|
||||
self._running_run_worker_key = {key: value for key, value in self._running_run_worker_key.items() if key in self._running_run_profile}
|
||||
self._session_profile = {key: value for key, value in self._session_profile.items() if value != profile}
|
||||
self._session_worker_key = {key: value for key, value in self._session_worker_key.items() if key in self._session_profile}
|
||||
self._approval_profile = {key: value for key, value in self._approval_profile.items() if value != profile}
|
||||
self._approval_worker_key = {key: value for key, value in self._approval_worker_key.items() if key in self._approval_profile}
|
||||
self._clarify_profile = {key: value for key, value in self._clarify_profile.items() if value != profile}
|
||||
self._clarify_worker_key = {key: value for key, value in self._clarify_worker_key.items() if key in self._clarify_profile}
|
||||
self._compression_profile = {key: value for key, value in self._compression_profile.items() if value != profile}
|
||||
self._compression_worker_key = {key: value for key, value in self._compression_worker_key.items() if key in self._compression_profile}
|
||||
|
||||
if worker is None or not worker.running:
|
||||
if worker is not None:
|
||||
worker.stop()
|
||||
if not workers:
|
||||
return {"profile": profile, "destroyed": 0}
|
||||
|
||||
try:
|
||||
resp = worker.request({"action": "destroy_all"})
|
||||
destroyed = int(resp.get("destroyed") or 0)
|
||||
except Exception:
|
||||
destroyed = 0
|
||||
finally:
|
||||
worker.stop()
|
||||
destroyed = 0
|
||||
for worker in workers:
|
||||
if not worker.running:
|
||||
worker.stop()
|
||||
continue
|
||||
try:
|
||||
resp = worker.request({"action": "destroy_all"})
|
||||
destroyed += int(resp.get("destroyed") or 0)
|
||||
except Exception:
|
||||
pass
|
||||
finally:
|
||||
worker.stop()
|
||||
return {"profile": profile, "destroyed": destroyed}
|
||||
|
||||
if action == "list":
|
||||
sessions: list[Any] = []
|
||||
with self._lock:
|
||||
workers = list(self._workers.items())
|
||||
for profile, worker in workers:
|
||||
for key, worker in workers:
|
||||
if not worker.running:
|
||||
continue
|
||||
try:
|
||||
resp = worker.request({"action": "list"})
|
||||
for session in resp.get("sessions") or []:
|
||||
if isinstance(session, dict):
|
||||
session.setdefault("profile", profile)
|
||||
session.setdefault("profile", getattr(worker, "profile", key))
|
||||
session.setdefault("worker_key", key)
|
||||
sessions.append(session)
|
||||
except Exception:
|
||||
pass
|
||||
@@ -2826,12 +2886,12 @@ class BridgeBroker:
|
||||
self._last_gc = now
|
||||
with self._lock:
|
||||
idle = [
|
||||
profile for profile, worker in self._workers.items()
|
||||
key for key, worker in self._workers.items()
|
||||
if worker.running and now - worker.last_used_at > self.IDLE_TIMEOUT_SECONDS
|
||||
]
|
||||
for profile in idle:
|
||||
for key in idle:
|
||||
with self._lock:
|
||||
worker = self._workers.pop(profile, None)
|
||||
worker = self._workers.pop(key, None)
|
||||
if worker:
|
||||
worker.stop()
|
||||
|
||||
|
||||
@@ -195,7 +195,8 @@ export async function buildCompressedHistory(
|
||||
emit: (event: string, payload: any) => void,
|
||||
sessionMap: Map<string, SessionState>,
|
||||
modelContext: { model?: string | null; provider?: string | null } = {},
|
||||
contextTokenEstimator?: (messages: ChatMessage[]) => Promise<number | null | undefined>,
|
||||
contextTokenEstimator?: (messages: ChatMessage[], messageTokens: number) => Promise<number | null | undefined>,
|
||||
currentInputTokens = 0,
|
||||
): Promise<ChatMessage[]> {
|
||||
try {
|
||||
let history = await buildDbHistory(sessionId, { excludeLastUser: true })
|
||||
@@ -213,14 +214,18 @@ export async function buildCompressedHistory(
|
||||
}
|
||||
const cState = getOrCreateSession(sessionMap, sessionId)
|
||||
const assembledTokens = await calcAndUpdateUsage(sessionId, cState, emit)
|
||||
const estimateFullContextTokens = async (messages: ChatMessage[], fallback: number) => {
|
||||
const currentRunInputTokens = typeof currentInputTokens === 'number' && Number.isFinite(currentInputTokens) && currentInputTokens > 0
|
||||
? Math.floor(currentInputTokens)
|
||||
: 0
|
||||
const estimateLocalContextTokens = async (messages: ChatMessage[], messageTokens: number) => {
|
||||
const localMessageTokens = Math.max(0, Math.floor(messageTokens))
|
||||
try {
|
||||
const estimate = await contextTokenEstimator?.(messages)
|
||||
const estimate = await contextTokenEstimator?.(messages, localMessageTokens)
|
||||
if (typeof estimate === 'number' && Number.isFinite(estimate) && estimate > 0) return Math.floor(estimate)
|
||||
} catch (err) {
|
||||
logger.warn(err, '[context-compress] session=%s: full context token estimate failed; using message-only estimate', sessionId)
|
||||
logger.warn(err, '[context-compress] session=%s: fixed context token estimate failed; using message-only estimate', sessionId)
|
||||
}
|
||||
return fallback
|
||||
return localMessageTokens
|
||||
}
|
||||
const emitContextUsage = (contextTokens: number) => {
|
||||
cState.contextTokens = contextTokens
|
||||
@@ -236,10 +241,10 @@ export async function buildCompressedHistory(
|
||||
let totalTokens = messageOnlyTotalTokens
|
||||
|
||||
if (history.length === 0) {
|
||||
totalTokens = await estimateFullContextTokens([], 0)
|
||||
totalTokens = await estimateLocalContextTokens([], Math.max(currentRunInputTokens, messageOnlyTotalTokens))
|
||||
if (totalTokens > triggerTokens) {
|
||||
throw new ContextWindowTooSmallError(
|
||||
`Context window is too small: system prompt and tool schemas already use ~${totalTokens} tokens, exceeding compression threshold ${triggerTokens}. Increase model context length, raise compression.threshold, or disable some tools.`,
|
||||
`Context window is too small: fixed prompt/tool overhead plus the current input uses ~${totalTokens} tokens, exceeding compression threshold ${triggerTokens}. Increase model context length, raise compression.threshold, shorten the input, or disable some tools.`,
|
||||
)
|
||||
}
|
||||
if (totalTokens > 0) emitContextUsage(totalTokens)
|
||||
@@ -254,13 +259,15 @@ export async function buildCompressedHistory(
|
||||
sessionId, snapshot.lastMessageIndex, history.length)
|
||||
const staleHistory = buildSnapshotHistory(snapshot, history, compressionConfig.compressor) || history
|
||||
const staleUsage = estimateUsageTokensFromMessages(staleHistory)
|
||||
totalTokens = await estimateFullContextTokens(staleHistory, staleUsage.inputTokens + staleUsage.outputTokens)
|
||||
const staleMessageTokens = staleUsage.inputTokens + staleUsage.outputTokens
|
||||
const staleRunMessageTokens = Math.max(staleMessageTokens + currentRunInputTokens, messageOnlyTotalTokens)
|
||||
totalTokens = await estimateLocalContextTokens(staleHistory, staleRunMessageTokens)
|
||||
emitContextUsage(totalTokens)
|
||||
logger.info({
|
||||
sessionId,
|
||||
profile,
|
||||
messages: staleHistory.length,
|
||||
messageOnlyTokens: staleUsage.inputTokens + staleUsage.outputTokens,
|
||||
messageOnlyTokens: staleRunMessageTokens,
|
||||
fullContextTokens: totalTokens,
|
||||
triggerTokens,
|
||||
decision: totalTokens > triggerTokens ? 'compress' : 'skip',
|
||||
@@ -272,13 +279,15 @@ export async function buildCompressedHistory(
|
||||
const newMessages = history.slice(snapshot.lastMessageIndex + 1)
|
||||
const snapshotHistory = buildSnapshotHistory(snapshot, history, compressionConfig.compressor) || history
|
||||
const snapshotUsage = estimateUsageTokensFromMessages(snapshotHistory)
|
||||
totalTokens = await estimateFullContextTokens(snapshotHistory, snapshotUsage.inputTokens + snapshotUsage.outputTokens)
|
||||
const snapshotMessageTokens = snapshotUsage.inputTokens + snapshotUsage.outputTokens
|
||||
const snapshotRunMessageTokens = Math.max(snapshotMessageTokens + currentRunInputTokens, messageOnlyTotalTokens)
|
||||
totalTokens = await estimateLocalContextTokens(snapshotHistory, snapshotRunMessageTokens)
|
||||
emitContextUsage(totalTokens)
|
||||
logger.info({
|
||||
sessionId,
|
||||
profile,
|
||||
messages: snapshotHistory.length,
|
||||
messageOnlyTokens: snapshotUsage.inputTokens + snapshotUsage.outputTokens,
|
||||
messageOnlyTokens: snapshotRunMessageTokens,
|
||||
fullContextTokens: totalTokens,
|
||||
triggerTokens,
|
||||
decision: totalTokens > triggerTokens ? 'compress' : 'skip',
|
||||
@@ -289,22 +298,25 @@ export async function buildCompressedHistory(
|
||||
if (totalTokens <= triggerTokens) {
|
||||
history = snapshotHistory
|
||||
} else {
|
||||
history = await compressHistory(history, newMessages, sessionId, upstream, apiKey, cState, totalTokens, emit, sessionMap, modelContext, compressionConfig.compressor)
|
||||
history = await compressHistory(history, newMessages, sessionId, upstream, apiKey, cState, totalTokens, emit, sessionMap, modelContext, compressionConfig.compressor, currentRunInputTokens)
|
||||
}
|
||||
} else if (snapshot && staleSnapshot) {
|
||||
if (totalTokens <= triggerTokens) {
|
||||
history = buildSnapshotHistory(snapshot, history, compressionConfig.compressor) || history
|
||||
} else {
|
||||
history = await compressHistory(history, null, sessionId, upstream, apiKey, cState, totalTokens, emit, sessionMap, modelContext, compressionConfig.compressor)
|
||||
history = await compressHistory(history, null, sessionId, upstream, apiKey, cState, totalTokens, emit, sessionMap, modelContext, compressionConfig.compressor, currentRunInputTokens)
|
||||
}
|
||||
} else {
|
||||
totalTokens = await estimateFullContextTokens(history, totalTokens)
|
||||
const historyUsage = estimateUsageTokensFromMessages(history)
|
||||
const historyMessageTokens = historyUsage.inputTokens + historyUsage.outputTokens
|
||||
const runMessageTokens = Math.max(historyMessageTokens + currentRunInputTokens, messageOnlyTotalTokens)
|
||||
totalTokens = await estimateLocalContextTokens(history, runMessageTokens)
|
||||
emitContextUsage(totalTokens)
|
||||
logger.info({
|
||||
sessionId,
|
||||
profile,
|
||||
messages: history.length,
|
||||
messageOnlyTokens: messageOnlyTotalTokens,
|
||||
messageOnlyTokens: runMessageTokens,
|
||||
fullContextTokens: totalTokens,
|
||||
triggerTokens,
|
||||
decision: totalTokens > triggerTokens ? 'compress' : 'skip',
|
||||
@@ -318,7 +330,7 @@ export async function buildCompressedHistory(
|
||||
if (totalTokens <= triggerTokens) {
|
||||
logger.info('[context-compress] session=%s: %d messages, ~%d tokens — under threshold, skip', sessionId, history.length, totalTokens)
|
||||
} else {
|
||||
history = await compressHistory(history, null, sessionId, upstream, apiKey, cState, totalTokens, emit, sessionMap, modelContext, compressionConfig.compressor)
|
||||
history = await compressHistory(history, null, sessionId, upstream, apiKey, cState, totalTokens, emit, sessionMap, modelContext, compressionConfig.compressor, currentRunInputTokens)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -342,8 +354,12 @@ export async function compressHistory(
|
||||
sessionMap: Map<string, SessionState>,
|
||||
modelContext: { model?: string | null; provider?: string | null } = {},
|
||||
compressionConfig?: Partial<CompressorConfig>,
|
||||
currentInputTokens = 0,
|
||||
): Promise<ChatMessage[]> {
|
||||
const msgCount = newMessagesOnly ? newMessagesOnly.length : history.length
|
||||
const currentRunInputTokens = typeof currentInputTokens === 'number' && Number.isFinite(currentInputTokens) && currentInputTokens > 0
|
||||
? Math.floor(currentInputTokens)
|
||||
: 0
|
||||
pushState(sessionMap, sessionId, 'compression.started', {
|
||||
event: 'compression.started', message_count: msgCount, token_count: totalTokens,
|
||||
})
|
||||
@@ -353,14 +369,22 @@ export async function compressHistory(
|
||||
|
||||
try {
|
||||
const session = getSession(sessionId)
|
||||
const summarizerProfile = session?.profile || 'default'
|
||||
const compressor = new ChatContextCompressor({ config: compressionConfig })
|
||||
const result = await compressor.compress(history, upstream, apiKey, sessionId, {
|
||||
profile: session?.profile,
|
||||
profile: summarizerProfile,
|
||||
model: modelContext.model || session?.model,
|
||||
provider: modelContext.provider || session?.provider,
|
||||
workerKey: `${summarizerProfile}:compression:${sessionId}`,
|
||||
})
|
||||
const afterTokens = await calcAndUpdateUsage(sessionId, cState, emit)
|
||||
const compressedAfterTokens = afterTokens.inputTokens + afterTokens.outputTokens
|
||||
const resultUsage = estimateUsageTokensFromMessages(result.messages)
|
||||
const resultMessageTokens = resultUsage.inputTokens + resultUsage.outputTokens
|
||||
const compressedRunMessageTokens = Math.max(
|
||||
compressedAfterTokens,
|
||||
resultMessageTokens + currentRunInputTokens,
|
||||
)
|
||||
const compressedMeta: any = {
|
||||
event: 'compression.completed' as const,
|
||||
compressed: result.meta.compressed,
|
||||
@@ -368,15 +392,15 @@ export async function compressHistory(
|
||||
totalMessages: result.meta.totalMessages,
|
||||
resultMessages: result.messages.length,
|
||||
beforeTokens: totalTokens,
|
||||
afterTokens: compressedAfterTokens,
|
||||
afterTokens: compressedRunMessageTokens,
|
||||
summaryTokens: result.meta.summaryTokenEstimate,
|
||||
verbatimCount: result.meta.verbatimCount,
|
||||
compressedStartIndex: result.meta.compressedStartIndex,
|
||||
}
|
||||
replaceState(sessionMap, sessionId, 'compression.completed', compressedMeta)
|
||||
logger.info('[context-compress] AFTER session=%s: %d messages, ~%d tokens (was %d)',
|
||||
sessionId, result.messages.length, compressedAfterTokens, totalTokens)
|
||||
const compressedContextTokens = updateMessageContextTokenUsage(sessionId, cState, emit, compressedAfterTokens, afterTokens)
|
||||
sessionId, result.messages.length, compressedRunMessageTokens, totalTokens)
|
||||
const compressedContextTokens = updateMessageContextTokenUsage(sessionId, cState, emit, compressedRunMessageTokens, afterTokens)
|
||||
if (compressedContextTokens != null) {
|
||||
compressedMeta.contextTokens = compressedContextTokens
|
||||
}
|
||||
@@ -403,6 +427,7 @@ export async function compressHistory(
|
||||
resultMessages: msgCount,
|
||||
beforeTokens: totalTokens,
|
||||
afterTokens: totalTokens,
|
||||
contextTokens: totalTokens,
|
||||
summaryTokens: 0,
|
||||
verbatimCount: msgCount,
|
||||
compressedStartIndex: -1,
|
||||
@@ -458,10 +483,12 @@ export async function forceCompressBridgeHistory(
|
||||
}, '[chat-run-socket] bridge forced compression started')
|
||||
|
||||
const compressor = new ChatContextCompressor({ config: compressionConfig.compressor })
|
||||
const summarizerProfile = session?.profile || profile || 'default'
|
||||
const result = await compressor.compress(history, upstream, apiKey, sessionId, {
|
||||
profile: session?.profile || profile,
|
||||
profile: summarizerProfile,
|
||||
model: session?.model,
|
||||
provider: session?.provider,
|
||||
workerKey: `${summarizerProfile}:compression:${sessionId}`,
|
||||
})
|
||||
const compressedMessages = result.messages.map(m => {
|
||||
const msg: any = { role: m.role, content: m.content }
|
||||
|
||||
@@ -18,7 +18,7 @@ import { convertHistoryFormat } from './message-format'
|
||||
import { readSseFrames } from './sse-utils'
|
||||
import { extractResponseText } from './response-utils'
|
||||
import { applyResponseStreamEvent, flushResponseRunToDb } from './response-stream'
|
||||
import { buildCompressedHistory, getOrCreateSession } from './compression'
|
||||
import { buildCompressedHistory, buildDbHistory, buildSnapshotAwareHistory, getOrCreateSession } from './compression'
|
||||
import { calcAndUpdateUsage, estimateUsageTokensFromMessages } from './usage'
|
||||
import { handleMessage } from './message-format'
|
||||
import { countTokens, SUMMARY_PREFIX } from '../../../lib/context-compressor'
|
||||
@@ -37,6 +37,7 @@ export async function loadSessionStateFromDb(sid: string, _sessionMap: Map<strin
|
||||
|
||||
let inputTokens: number
|
||||
let outputTokens: number
|
||||
let contextTokens: number | undefined
|
||||
const snapshot = getCompressionSnapshot(sid)
|
||||
if (snapshot && snapshot.lastMessageIndex >= 0 && snapshot.lastMessageIndex < messages.length) {
|
||||
const newMessages = messages.slice(snapshot.lastMessageIndex + 1)
|
||||
@@ -49,6 +50,20 @@ export async function loadSessionStateFromDb(sid: string, _sessionMap: Map<strin
|
||||
inputTokens = usage.inputTokens
|
||||
outputTokens = usage.outputTokens
|
||||
}
|
||||
try {
|
||||
const session = getSession(sid)
|
||||
const dbHistory = await buildDbHistory(sid, { excludeLastUser: false })
|
||||
const snapshotHistory = await buildSnapshotAwareHistory(
|
||||
sid,
|
||||
session?.profile || 'default',
|
||||
dbHistory,
|
||||
{ model: session?.model, provider: session?.provider },
|
||||
)
|
||||
const contextUsage = estimateUsageTokensFromMessages(snapshotHistory)
|
||||
contextTokens = contextUsage.inputTokens + contextUsage.outputTokens
|
||||
} catch (err) {
|
||||
logger.warn(err, '[chat-run-socket] failed to calculate snapshot-aware context tokens for session %s', sid)
|
||||
}
|
||||
|
||||
logger.info('[chat-run-socket] loaded session %s from DB (%d messages)', sid, messages.length)
|
||||
return {
|
||||
@@ -57,6 +72,7 @@ export async function loadSessionStateFromDb(sid: string, _sessionMap: Map<strin
|
||||
events: [],
|
||||
inputTokens,
|
||||
outputTokens,
|
||||
contextTokens,
|
||||
queue: [],
|
||||
}
|
||||
} catch (err) {
|
||||
|
||||
@@ -16,7 +16,6 @@ import {
|
||||
contextTokensWithCachedOverhead,
|
||||
estimateUsageTokensFromMessages,
|
||||
getCachedBridgeContextOverhead,
|
||||
updateContextTokenUsage,
|
||||
updateMessageContextTokenUsage,
|
||||
} from './usage'
|
||||
import {
|
||||
@@ -123,6 +122,66 @@ function cacheBridgeContext(state: SessionState, data: Record<string, unknown> |
|
||||
}
|
||||
}
|
||||
|
||||
function bridgeContextMatches(
|
||||
state: SessionState,
|
||||
expected: { profile: string; model?: string | null; provider?: string | null },
|
||||
): boolean {
|
||||
const context = state.bridgeContext
|
||||
if (!context) return false
|
||||
if (context.profile && context.profile !== expected.profile) return false
|
||||
if (expected.model && context.model && context.model !== expected.model) return false
|
||||
if (expected.provider && context.provider && context.provider !== expected.provider) return false
|
||||
return true
|
||||
}
|
||||
|
||||
async function ensureBridgeFixedContext(args: {
|
||||
sessionId: string
|
||||
profile: string
|
||||
model?: string | null
|
||||
provider?: string | null
|
||||
instructions: string
|
||||
state: SessionState
|
||||
bridge: AgentBridgeClient
|
||||
refresh?: boolean
|
||||
}): Promise<number | undefined> {
|
||||
const cached = bridgeContextMatches(args.state, args)
|
||||
? getCachedBridgeContextOverhead(args.state)
|
||||
: undefined
|
||||
if (!args.refresh && cached != null) return cached
|
||||
|
||||
try {
|
||||
const estimate = await args.bridge.contextEstimate(
|
||||
args.sessionId,
|
||||
[],
|
||||
args.instructions,
|
||||
args.profile,
|
||||
{ model: args.model ?? undefined, provider: args.provider ?? undefined },
|
||||
)
|
||||
cacheBridgeContext(args.state, estimate)
|
||||
const fixedContextTokens = getCachedBridgeContextOverhead(args.state)
|
||||
bridgeLogger.info({
|
||||
sessionId: args.sessionId,
|
||||
profile: args.profile,
|
||||
model: args.model,
|
||||
provider: args.provider,
|
||||
toolCount: estimate.tool_count,
|
||||
systemPromptChars: estimate.system_prompt_chars,
|
||||
fixedContextTokens,
|
||||
}, '[chat-run-socket] fixed context estimate')
|
||||
return fixedContextTokens
|
||||
} catch (err) {
|
||||
bridgeLogger.warn({
|
||||
err: err instanceof Error ? { message: err.message, name: err.name } : err,
|
||||
sessionId: args.sessionId,
|
||||
profile: args.profile,
|
||||
model: args.model,
|
||||
provider: args.provider,
|
||||
cachedFixedContextTokens: cached,
|
||||
}, '[chat-run-socket] fixed context estimate failed')
|
||||
return cached
|
||||
}
|
||||
}
|
||||
|
||||
export async function handleBridgeRun(
|
||||
nsp: ReturnType<Server['of']>,
|
||||
socket: Socket,
|
||||
@@ -195,6 +254,9 @@ export async function handleBridgeRun(
|
||||
|
||||
const displayInput = data.display_input === undefined ? input : data.display_input
|
||||
const inputStr = displayInput == null ? '' : contentBlocksToString(displayInput)
|
||||
const actualInputStr = contentBlocksToString(input)
|
||||
const currentInputUsage = estimateUsageTokensFromMessages([{ role: 'user', content: actualInputStr }])
|
||||
const currentInputTokens = currentInputUsage.inputTokens
|
||||
const shouldPersistUserMessage = !skipUserMessage && displayInput !== null
|
||||
const displayRole = data.display_role === 'command' ? 'command' : 'user'
|
||||
let messageId: number | string | undefined
|
||||
@@ -257,33 +319,32 @@ export async function handleBridgeRun(
|
||||
emit,
|
||||
sessionMap,
|
||||
{ model: resolvedModel, provider: resolvedProvider },
|
||||
async (messages) => {
|
||||
const cachedOverhead = getCachedBridgeContextOverhead(state)
|
||||
if (cachedOverhead != null) {
|
||||
const messageUsage = estimateUsageTokensFromMessages(messages)
|
||||
return cachedOverhead + messageUsage.inputTokens + messageUsage.outputTokens
|
||||
}
|
||||
const estimate = await bridge.contextEstimate(
|
||||
session_id,
|
||||
messages,
|
||||
fullInstructions,
|
||||
async (_messages, localMessageTokens) => {
|
||||
const fixedContextTokens = await ensureBridgeFixedContext({
|
||||
sessionId: session_id,
|
||||
profile,
|
||||
{ model: resolvedModel, provider: resolvedProvider },
|
||||
)
|
||||
cacheBridgeContext(state, estimate)
|
||||
model: resolvedModel,
|
||||
provider: resolvedProvider,
|
||||
instructions: fullInstructions,
|
||||
state,
|
||||
bridge,
|
||||
refresh: true,
|
||||
})
|
||||
const contextTokens = fixedContextTokens == null
|
||||
? localMessageTokens
|
||||
: fixedContextTokens + localMessageTokens
|
||||
bridgeLogger.info({
|
||||
sessionId: session_id,
|
||||
profile,
|
||||
model: resolvedModel,
|
||||
provider: resolvedProvider,
|
||||
messages: estimate.message_count,
|
||||
toolCount: estimate.tool_count,
|
||||
systemPromptChars: estimate.system_prompt_chars,
|
||||
fixedContextTokens: estimate.fixed_context_tokens,
|
||||
fullContextTokens: estimate.token_count,
|
||||
}, '[chat-run-socket] full context estimate')
|
||||
return estimate.token_count
|
||||
fixedContextTokens,
|
||||
messageTokens: localMessageTokens,
|
||||
contextTokens,
|
||||
}, '[chat-run-socket] local context estimate')
|
||||
return contextTokens
|
||||
},
|
||||
currentInputTokens,
|
||||
)
|
||||
const bridgeHistory = history
|
||||
|
||||
@@ -349,6 +410,8 @@ export async function handleBridgeRun(
|
||||
dequeueNextQueuedRun,
|
||||
fullInstructions,
|
||||
{ model: resolvedModel, provider: resolvedProvider },
|
||||
currentInputTokens,
|
||||
shouldPersistUserMessage && displayRole === 'user',
|
||||
data.model_groups,
|
||||
)
|
||||
if (chunk.done) break
|
||||
@@ -417,61 +480,68 @@ async function refreshFinalContextUsage(args: {
|
||||
)
|
||||
const finalMessageUsage = estimateUsageTokensFromMessages(finalHistory)
|
||||
const finalMessageTokens = finalMessageUsage.inputTokens + finalMessageUsage.outputTokens
|
||||
if (getCachedBridgeContextOverhead(args.state) != null) {
|
||||
const contextTokens = updateMessageContextTokenUsage(
|
||||
args.sessionId,
|
||||
args.state,
|
||||
args.emit,
|
||||
finalMessageTokens,
|
||||
args.usage,
|
||||
)
|
||||
bridgeLogger.info({
|
||||
sessionId: args.sessionId,
|
||||
profile: args.profile,
|
||||
model: args.model,
|
||||
provider: args.provider,
|
||||
messages: finalHistory.length,
|
||||
fixedContextTokens: args.state.bridgeContext?.fixedContextTokens,
|
||||
messageTokens: finalMessageTokens,
|
||||
fullContextTokens: contextTokens,
|
||||
}, '[chat-run-socket] final cached context estimate')
|
||||
return contextTokens
|
||||
}
|
||||
const estimate = await args.bridge.contextEstimate(
|
||||
await ensureBridgeFixedContext({
|
||||
sessionId: args.sessionId,
|
||||
profile: args.profile,
|
||||
model: args.model,
|
||||
provider: args.provider,
|
||||
instructions: args.instructions,
|
||||
state: args.state,
|
||||
bridge: args.bridge,
|
||||
})
|
||||
const contextTokens = updateMessageContextTokenUsage(
|
||||
args.sessionId,
|
||||
finalHistory,
|
||||
args.instructions,
|
||||
args.profile,
|
||||
{ model: args.model ?? undefined, provider: args.provider ?? undefined },
|
||||
args.state,
|
||||
args.emit,
|
||||
finalMessageTokens,
|
||||
args.usage,
|
||||
)
|
||||
cacheBridgeContext(args.state, estimate)
|
||||
const contextTokens = typeof estimate.token_count === 'number' && Number.isFinite(estimate.token_count) && estimate.token_count > 0
|
||||
? Math.floor(estimate.token_count)
|
||||
: undefined
|
||||
if (contextTokens == null) return args.state.contextTokens
|
||||
|
||||
updateContextTokenUsage(args.sessionId, args.state, args.emit, contextTokens, args.usage)
|
||||
bridgeLogger.info({
|
||||
sessionId: args.sessionId,
|
||||
profile: args.profile,
|
||||
model: args.model,
|
||||
provider: args.provider,
|
||||
messages: estimate.message_count,
|
||||
toolCount: estimate.tool_count,
|
||||
systemPromptChars: estimate.system_prompt_chars,
|
||||
fullContextTokens: contextTokens,
|
||||
}, '[chat-run-socket] final full context estimate')
|
||||
messages: finalHistory.length,
|
||||
fixedContextTokens: args.state.bridgeContext?.fixedContextTokens,
|
||||
messageTokens: finalMessageTokens,
|
||||
contextTokens,
|
||||
}, '[chat-run-socket] final local context estimate')
|
||||
return contextTokens
|
||||
} catch (err) {
|
||||
bridgeLogger.warn({
|
||||
err: err instanceof Error ? { message: err.message, name: err.name } : err,
|
||||
sessionId: args.sessionId,
|
||||
profile: args.profile,
|
||||
}, '[chat-run-socket] final full context estimate failed')
|
||||
}, '[chat-run-socket] final local context estimate failed')
|
||||
return args.state.contextTokens
|
||||
}
|
||||
}
|
||||
|
||||
async function estimateSnapshotAwareMessageTokens(args: {
|
||||
sessionId: string
|
||||
profile: string
|
||||
model?: string | null
|
||||
provider?: string | null
|
||||
currentInputTokens?: number
|
||||
currentInputIncludedInDb?: boolean
|
||||
}): Promise<{ messageTokens: number; messages: number }> {
|
||||
const dbHistory = await buildDbHistory(args.sessionId, { excludeLastUser: false })
|
||||
const snapshotHistory = await buildSnapshotAwareHistory(
|
||||
args.sessionId,
|
||||
args.profile,
|
||||
dbHistory,
|
||||
{ model: args.model, provider: args.provider },
|
||||
)
|
||||
const usage = estimateUsageTokensFromMessages(snapshotHistory)
|
||||
const extraInputTokens = args.currentInputIncludedInDb
|
||||
? 0
|
||||
: finiteToken(args.currentInputTokens) ?? 0
|
||||
return {
|
||||
messageTokens: usage.inputTokens + usage.outputTokens + extraInputTokens,
|
||||
messages: snapshotHistory.length,
|
||||
}
|
||||
}
|
||||
|
||||
async function applyBridgeChunkAsync(
|
||||
nsp: ReturnType<Server['of']>,
|
||||
socket: Socket,
|
||||
@@ -486,6 +556,8 @@ async function applyBridgeChunkAsync(
|
||||
dequeueNextQueuedRun: (socket: Socket, sessionId: string, fallbackProfile?: string) => void,
|
||||
instructions: string,
|
||||
modelContext: { model?: string | null; provider?: string | null },
|
||||
currentInputTokens = 0,
|
||||
currentInputIncludedInDb = true,
|
||||
modelGroups?: RunModelGroup[],
|
||||
): Promise<void> {
|
||||
if (state.activeRunMarker !== runMarker) {
|
||||
@@ -505,11 +577,19 @@ async function applyBridgeChunkAsync(
|
||||
if (evType === 'bridge.context.ready') {
|
||||
cacheBridgeContext(state, ev)
|
||||
const usage = await calcAndUpdateUsage(sessionId, state, emit)
|
||||
const snapshotAware = await estimateSnapshotAwareMessageTokens({
|
||||
sessionId,
|
||||
profile,
|
||||
model: modelContext.model,
|
||||
provider: modelContext.provider,
|
||||
currentInputTokens,
|
||||
currentInputIncludedInDb,
|
||||
})
|
||||
updateMessageContextTokenUsage(
|
||||
sessionId,
|
||||
state,
|
||||
emit,
|
||||
usage.inputTokens + usage.outputTokens,
|
||||
snapshotAware.messageTokens,
|
||||
usage,
|
||||
)
|
||||
} else if (evType === 'tool.started') {
|
||||
@@ -646,17 +726,22 @@ async function applyBridgeChunkAsync(
|
||||
const bridgeHistory = await buildDbHistory(sessionId, { excludeLastUser: true })
|
||||
const bridgeUsage = estimateUsageTokensFromMessages(bridgeHistory)
|
||||
const messageOnlyTokens = bridgeUsage.inputTokens + bridgeUsage.outputTokens
|
||||
const tokenCount = typeof ev.approx_tokens === 'number' && Number.isFinite(ev.approx_tokens) && ev.approx_tokens > 0
|
||||
? ev.approx_tokens
|
||||
: messageOnlyTokens
|
||||
const runInputTokens = typeof currentInputTokens === 'number' && Number.isFinite(currentInputTokens) && currentInputTokens > 0
|
||||
? Math.floor(currentInputTokens)
|
||||
: 0
|
||||
const runMessageTokens = messageOnlyTokens + runInputTokens
|
||||
const tokenCount = contextTokensWithCachedOverhead(state, runMessageTokens)
|
||||
bridgeLogger.info({
|
||||
sessionId,
|
||||
profile,
|
||||
bridgeMessages: ev.message_count,
|
||||
dbMessages: bridgeHistory.length,
|
||||
messageOnlyTokens,
|
||||
fullContextTokens: tokenCount,
|
||||
source: typeof ev.approx_tokens === 'number' ? 'bridge' : 'message-only-fallback',
|
||||
currentInputTokens: runInputTokens,
|
||||
fixedContextTokens: state.bridgeContext?.fixedContextTokens,
|
||||
contextTokens: tokenCount,
|
||||
bridgeApproxTokens: ev.approx_tokens,
|
||||
source: 'local',
|
||||
}, '[chat-run-socket] bridge compression token estimate')
|
||||
const payload = {
|
||||
event: 'compression.started',
|
||||
@@ -674,7 +759,7 @@ async function applyBridgeChunkAsync(
|
||||
sessionId,
|
||||
profile,
|
||||
ev.messages as ChatMessage[],
|
||||
typeof ev.approx_tokens === 'number' ? ev.approx_tokens : undefined,
|
||||
tokenCount,
|
||||
)
|
||||
state.bridgeCompressionResults = state.bridgeCompressionResults || {}
|
||||
state.bridgeCompressionResults[String(ev.request_id)] = compressed
|
||||
@@ -689,11 +774,16 @@ async function applyBridgeChunkAsync(
|
||||
const compressionResult = ev.request_id
|
||||
? state.bridgeCompressionResults?.[String(ev.request_id)]
|
||||
: undefined
|
||||
const bridgeAfterContextTokens = finiteToken(ev.result_approx_tokens)
|
||||
const messageAfterTokens = finiteToken(compressionResult?.afterTokens)
|
||||
const afterContextTokens = messageAfterTokens != null && getCachedBridgeContextOverhead(state) != null
|
||||
? contextTokensWithCachedOverhead(state, messageAfterTokens)
|
||||
: bridgeAfterContextTokens ?? messageAfterTokens
|
||||
const runInputTokens = typeof currentInputTokens === 'number' && Number.isFinite(currentInputTokens) && currentInputTokens > 0
|
||||
? Math.floor(currentInputTokens)
|
||||
: 0
|
||||
const messageAfterTokensWithInput = messageAfterTokens != null
|
||||
? messageAfterTokens + runInputTokens
|
||||
: undefined
|
||||
const afterContextTokens = messageAfterTokensWithInput != null
|
||||
? contextTokensWithCachedOverhead(state, messageAfterTokensWithInput)
|
||||
: undefined
|
||||
const payload = {
|
||||
event: 'compression.completed',
|
||||
run_id: chunk.run_id,
|
||||
@@ -703,7 +793,7 @@ async function applyBridgeChunkAsync(
|
||||
totalMessages: compressionResult?.beforeMessages ?? ev.message_count,
|
||||
resultMessages: compressionResult?.resultMessages ?? ev.result_messages,
|
||||
beforeTokens: compressionResult?.beforeTokens ?? ev.approx_tokens,
|
||||
afterTokens: messageAfterTokens ?? bridgeAfterContextTokens,
|
||||
afterTokens: messageAfterTokensWithInput,
|
||||
contextTokens: afterContextTokens,
|
||||
summaryTokens: compressionResult?.summaryTokens,
|
||||
verbatimCount: compressionResult?.verbatimCount,
|
||||
@@ -716,10 +806,8 @@ async function applyBridgeChunkAsync(
|
||||
replaceState(sessionMap, sessionId, 'compression.completed', payload)
|
||||
emit('compression.completed', payload)
|
||||
const usage = await calcAndUpdateUsage(sessionId, state, emit)
|
||||
if (messageAfterTokens != null && getCachedBridgeContextOverhead(state) != null) {
|
||||
updateMessageContextTokenUsage(sessionId, state, emit, messageAfterTokens, usage)
|
||||
} else {
|
||||
updateContextTokenUsage(sessionId, state, emit, afterContextTokens, usage)
|
||||
if (messageAfterTokensWithInput != null) {
|
||||
updateMessageContextTokenUsage(sessionId, state, emit, messageAfterTokensWithInput, usage)
|
||||
}
|
||||
} else if (evType === 'bridge.compression.failed') {
|
||||
const payload = {
|
||||
|
||||
Reference in New Issue
Block a user