From f1a6d97c8b3593b285f6e537cf4ad6e7606f13b8 Mon Sep 17 00:00:00 2001 From: Zhicheng Han <43314240+hanzckernel@users.noreply.github.com> Date: Sun, 26 Apr 2026 04:10:01 +0200 Subject: [PATCH] fix(sessions): harden compressed session lineage projection (#226) - Project compressed roots to their continuation tip in session lists. - Search title/content candidates through logical compression lineage. - Hydrate detail views along the requested continuation branch while preserving requested ids. - Scope model-context cache lookup by provider to avoid same-name cross-provider matches. - Add regression coverage for lineage and provider lookup behavior. --- packages/server/src/db/hermes/sessions-db.ts | 354 +++++++++++------- .../src/services/hermes/model-context.ts | 121 ++++-- tests/server/model-context.test.ts | 175 ++++++--- tests/server/sessions-db-lineage.test.ts | 303 +++++++++++++++ tests/server/sessions-db.test.ts | 24 +- 5 files changed, 747 insertions(+), 230 deletions(-) create mode 100644 tests/server/sessions-db-lineage.test.ts diff --git a/packages/server/src/db/hermes/sessions-db.ts b/packages/server/src/db/hermes/sessions-db.ts index 1f9789c..08bb260 100644 --- a/packages/server/src/db/hermes/sessions-db.ts +++ b/packages/server/src/db/hermes/sessions-db.ts @@ -5,7 +5,9 @@ const SQLITE_AVAILABLE = (() => { return major > 22 || (major === 22 && minor >= 5) })() -const LINEAGE_TOLERANCE_SECONDS = 3 +const COMPRESSION_END_REASONS = new Set(['compression', 'compressed']) +const SEARCH_CANDIDATE_MULTIPLIER = 20 +const SEARCH_CANDIDATE_MIN = 100 export interface HermesSessionRow { id: string @@ -148,27 +150,6 @@ const SESSION_SELECT = ` COALESCE((SELECT MAX(m2.timestamp) FROM messages m2 WHERE m2.session_id = s.id), s.started_at) AS last_active ` -const SESSION_FROM = ` - FROM sessions s - WHERE s.parent_session_id IS NULL - AND s.source != 'tool' -` - -function buildBaseSessionSql(source?: string): { sql: string, params: any[] } { - const sql = source - ? `SELECT ${SESSION_SELECT}${SESSION_FROM}\n AND s.source = ?` - : `SELECT ${SESSION_SELECT}${SESSION_FROM}` - return { sql, params: source ? [source] : [] } -} - -function buildListSessionSql(source?: string, limit = 2000): { sql: string, params: any[] } { - const base = buildBaseSessionSql(source) - return { - sql: `${base.sql}\n ORDER BY s.started_at DESC\n LIMIT ?`, - params: [...base.params, limit], - } -} - function containsCjk(text: string): boolean { for (const ch of text) { const cp = ch.codePointAt(0) ?? 0 @@ -251,12 +232,18 @@ function runLiteralContentSearch( query: string, limit: number, ): Record[] { - const likeBase = buildBaseSessionSql(source) const loweredQuery = query.toLowerCase() const likePattern = buildLikePattern(loweredQuery) + const sourceClause = source ? 'AND s.source = ?' : '' + const sourceParams = source ? [source] : [] const likeSql = ` WITH base AS ( - ${likeBase.sql} + SELECT + ${SESSION_SELECT}, + s.parent_session_id AS parent_session_id + FROM sessions s + WHERE s.source != 'tool' + ${sourceClause} ) SELECT base.*, @@ -273,7 +260,7 @@ function runLiteralContentSearch( ORDER BY base.last_active DESC, m.timestamp DESC LIMIT ? ` - return db.prepare(likeSql).all(...likeBase.params, loweredQuery, likePattern, limit * 4) as Record[] + return db.prepare(likeSql).all(...sourceParams, loweredQuery, likePattern, limit) as Record[] } function sanitizeFtsQuery(query: string): string { @@ -367,83 +354,185 @@ function mapMessageRow(row: Record): HermesMessageRow { } } -function timingMatchesParent(parent: HermesSessionInternalRow | undefined, child: HermesSessionInternalRow | undefined): boolean { - if (!parent || !child || parent.ended_at == null) return false - return Math.abs(Number(child.started_at || 0) - Number(parent.ended_at || 0)) <= LINEAGE_TOLERANCE_SECONDS +function isCompressionEnded(session: HermesSessionInternalRow | undefined): boolean { + return !!session && COMPRESSION_END_REASONS.has(String(session.end_reason || '')) } -function continuationCandidates( - parent: HermesSessionInternalRow, - byId: Map, - childrenByParent: Map, -): HermesSessionInternalRow[] { - return (childrenByParent.get(parent.id) || []) - .map(childId => byId.get(childId)) - .filter((child): child is HermesSessionInternalRow => !!child) - .filter(child => child.source !== 'tool') - .filter(child => child.source === parent.source) - .filter(child => timingMatchesParent(parent, child)) - .sort((a, b) => { - const aDelta = Math.abs(Number(a.started_at || 0) - Number(parent.ended_at || 0)) - const bDelta = Math.abs(Number(b.started_at || 0) - Number(parent.ended_at || 0)) - if (aDelta !== bDelta) return aDelta - bDelta - return a.id.localeCompare(b.id) - }) +function isCompressionContinuation(parent: HermesSessionInternalRow | undefined, child: HermesSessionInternalRow | undefined): boolean { + if (!parent || !child || !isCompressionEnded(parent) || parent.ended_at == null) return false + return child.source !== 'tool' && Number(child.started_at || 0) >= Number(parent.ended_at || 0) } -function normalizeComparableText(value: unknown): string { - return String(value || '').replace(/\s+/g, ' ').trim().toLowerCase() +function latestSessionInChain(chain: HermesSessionInternalRow[]): HermesSessionInternalRow { + return chain.reduce((latest, session) => { + const latestStarted = Number(latest.started_at || 0) + const sessionStarted = Number(session.started_at || 0) + if (sessionStarted !== latestStarted) return sessionStarted > latestStarted ? session : latest + return session.id.localeCompare(latest.id) > 0 ? session : latest + }, chain[0]) } -function nextContinuationChild( - parent: HermesSessionInternalRow, - byId: Map, - childrenByParent: Map, -): HermesSessionInternalRow | null { - if (parent.end_reason !== 'compression') return null - const candidates = continuationCandidates(parent, byId, childrenByParent) - if (candidates.length === 1) return candidates[0] - - const exactPreviewMatches = candidates.filter(child => { - const childPreview = normalizeComparableText(child.preview) - const parentPreview = normalizeComparableText(parent.preview) - return !!childPreview && childPreview === parentPreview - }) - return exactPreviewMatches.length === 1 ? exactPreviewMatches[0] : null -} - -function collectSessionChain( - rootId: string, - byId: Map, - childrenByParent: Map, -): HermesSessionInternalRow[] { - const chain: HermesSessionInternalRow[] = [] - const seen = new Set() - let current = byId.get(rootId) || null - while (current && !seen.has(current.id)) { - chain.push(current) - seen.add(current.id) - current = nextContinuationChild(current, byId, childrenByParent) +function projectSessionSummary(root: HermesSessionInternalRow, chain: HermesSessionInternalRow[]): HermesSessionRow { + const latest = latestSessionInChain(chain) + const { parent_session_id: _parentSessionId, ...rootRow } = root + return { + ...rootRow, + id: latest.id, + model: latest.model || root.model, + title: latest.title || root.title, + ended_at: latest.ended_at, + end_reason: latest.end_reason, + message_count: latest.message_count, + tool_call_count: latest.tool_call_count, + input_tokens: latest.input_tokens, + output_tokens: latest.output_tokens, + cache_read_tokens: latest.cache_read_tokens, + cache_write_tokens: latest.cache_write_tokens, + reasoning_tokens: latest.reasoning_tokens, + billing_provider: latest.billing_provider ?? root.billing_provider, + estimated_cost_usd: latest.estimated_cost_usd, + actual_cost_usd: latest.actual_cost_usd, + cost_status: latest.cost_status, + preview: latest.preview || root.preview, + last_active: latest.last_active || root.last_active, } - return chain } -function aggregateSessionDetail(chain: HermesSessionInternalRow[], messages: HermesMessageRow[]): HermesSessionDetailRow { +type SessionDbLike = { + prepare: (sql: string) => { all: (...params: any[]) => Record[] } +} + +function searchCandidateLimit(limit: number): number { + return Math.max(limit * SEARCH_CANDIDATE_MULTIPLIER, SEARCH_CANDIDATE_MIN) +} + +function selectSessionById(db: SessionDbLike, sessionId: string): HermesSessionInternalRow | null { + const rows = db.prepare(` + SELECT + ${SESSION_SELECT}, + s.parent_session_id AS parent_session_id + FROM sessions s + WHERE s.id = ? AND s.source != 'tool' + LIMIT 1 + `).all(sessionId) + return rows[0] ? mapInternalSessionRow(rows[0]) : null +} + +function selectLatestContinuationChildFromDb(db: SessionDbLike, parent: HermesSessionInternalRow): HermesSessionInternalRow | null { + if (!isCompressionEnded(parent) || parent.ended_at == null) return null + + const rows = db.prepare(` + SELECT + ${SESSION_SELECT}, + s.parent_session_id AS parent_session_id + FROM sessions s + WHERE s.parent_session_id = ? + AND s.source != 'tool' + AND s.started_at >= ? + ORDER BY s.started_at DESC, s.id DESC + LIMIT 1 + `).all(parent.id, parent.ended_at) + return rows[0] ? mapInternalSessionRow(rows[0]) : null +} + +function collectCompressionPathToSessionFromDb( + db: SessionDbLike, + session: HermesSessionInternalRow, +): HermesSessionInternalRow[] { + const reversed: HermesSessionInternalRow[] = [session] + const seen = new Set() + let current: HermesSessionInternalRow | null = session + + for (let depth = 0; current && current.parent_session_id && depth < 100 && !seen.has(current.id); depth += 1) { + seen.add(current.id) + const parent = selectSessionById(db, current.parent_session_id) + if (!parent || !isCompressionContinuation(parent, current)) break + reversed.push(parent) + current = parent + } + + return reversed.reverse() +} + +function extendCompressionChainFromDb( + db: SessionDbLike, + chain: HermesSessionInternalRow[], +): HermesSessionInternalRow[] { + const result = [...chain] + const seen = new Set(result.map(session => session.id)) + let current: HermesSessionInternalRow | null = result[result.length - 1] || null + + for (let depth = 0; current && depth < 100; depth += 1) { + const next = selectLatestContinuationChildFromDb(db, current) + if (!next || seen.has(next.id)) break + result.push(next) + seen.add(next.id) + current = next + } + + return result +} + +function collectSessionChainFromDb( + db: SessionDbLike, + root: HermesSessionInternalRow, +): HermesSessionInternalRow[] { + return extendCompressionChainFromDb(db, [root]) +} + +function collectSessionChainForMatchedSessionFromDb( + db: SessionDbLike, + session: HermesSessionInternalRow, +): HermesSessionInternalRow[] { + return extendCompressionChainFromDb(db, collectCompressionPathToSessionFromDb(db, session)) +} + +function projectSearchRowFromDb( + db: SessionDbLike, + row: Record, + source?: string, +): HermesSessionSearchRow | null { + const matchedSession = mapInternalSessionRow(row) + if (!matchedSession.id) return null + + const chain = collectSessionChainForMatchedSessionFromDb(db, matchedSession) const root = chain[0] - const last = chain[chain.length - 1] || root + if (!root) return null + if (source && matchedSession.source !== source) return null + + const projected = projectSessionSummary(root, chain) + return { + ...projected, + matched_message_id: normalizeNullableNumber(row.matched_message_id), + snippet: String(row.snippet || row.preview || ''), + rank: Number.isFinite(Number(row.rank)) ? Number(row.rank) : 0, + } +} + +function aggregateSessionDetail( + chain: HermesSessionInternalRow[], + messages: HermesMessageRow[], + requestedSessionId: string, +): HermesSessionDetailRow { + const root = chain[0] + const latest = latestSessionInChain(chain) const costStatuses = Array.from(new Set(chain.map(session => String(session.cost_status || '')).filter(Boolean))) const actualCosts = chain .map(session => session.actual_cost_usd) .filter((value): value is number => value != null) const firstPreview = chain.map(session => session.preview).find(Boolean) || root.preview + const { parent_session_id: _parentSessionId, ...rootRow } = root + return { - ...root, - title: root.title || (firstPreview ? (firstPreview.length > 40 ? `${firstPreview.slice(0, 40)}...` : firstPreview) : null), - preview: root.preview || firstPreview || '', - model: last.model || root.model, - ended_at: last.ended_at, - end_reason: last.end_reason, + ...rootRow, + id: requestedSessionId, + source: latest.source || root.source, + title: latest.title || root.title || (firstPreview ? (firstPreview.length > 40 ? `${firstPreview.slice(0, 40)}...` : firstPreview) : null), + preview: latest.preview || root.preview || firstPreview || '', + model: latest.model || root.model, + ended_at: latest.ended_at, + end_reason: latest.end_reason, last_active: Math.max(...chain.map(session => session.last_active || session.started_at || 0)), message_count: chain.reduce((sum, session) => sum + Number(session.message_count || 0), 0), tool_call_count: chain.reduce((sum, session) => sum + Number(session.tool_call_count || 0), 0), @@ -452,7 +541,7 @@ function aggregateSessionDetail(chain: HermesSessionInternalRow[], messages: Her cache_read_tokens: chain.reduce((sum, session) => sum + Number(session.cache_read_tokens || 0), 0), cache_write_tokens: chain.reduce((sum, session) => sum + Number(session.cache_write_tokens || 0), 0), reasoning_tokens: chain.reduce((sum, session) => sum + Number(session.reasoning_tokens || 0), 0), - billing_provider: last.billing_provider ?? root.billing_provider, + billing_provider: latest.billing_provider ?? root.billing_provider, estimated_cost_usd: chain.reduce((sum, session) => sum + Number(session.estimated_cost_usd || 0), 0), actual_cost_usd: actualCosts.length ? actualCosts.reduce((sum, value) => sum + Number(value || 0), 0) : null, cost_status: costStatuses.length === 1 ? costStatuses[0] : (costStatuses.length > 1 ? 'mixed' : ''), @@ -472,28 +561,10 @@ async function openSessionDb() { export async function getSessionDetailFromDb(sessionId: string): Promise { const db = await openSessionDb() try { - const rows = db.prepare(` - SELECT - ${SESSION_SELECT}, - s.parent_session_id AS parent_session_id - FROM sessions s - WHERE s.source != 'tool' - `).all() as Record[] + const requested = selectSessionById(db, sessionId) + if (!requested) return null - const sessions = rows.map(mapInternalSessionRow) - const byId = new Map(sessions.map(session => [session.id, session])) - const root = byId.get(sessionId) - if (!root) return null - - const childrenByParent = new Map() - for (const session of sessions) { - const key = session.parent_session_id ?? null - const siblings = childrenByParent.get(key) || [] - siblings.push(session.id) - childrenByParent.set(key, siblings) - } - - const chain = collectSessionChain(sessionId, byId, childrenByParent) + const chain = collectSessionChainForMatchedSessionFromDb(db, requested) if (!chain.length) return null const ids = chain.map(session => session.id) @@ -520,7 +591,7 @@ export async function getSessionDetailFromDb(sessionId: string): Promise[] const messages = messageRows.map(mapMessageRow) - return aggregateSessionDetail(chain, messages) + return aggregateSessionDetail(chain, messages, sessionId) } finally { db.close() } @@ -535,11 +606,29 @@ export async function listSessionSummaries(source?: string, limit = 2000): Promi const db = new DatabaseSync(sessionDbPath(), { open: true, readOnly: true }) try { - const { sql, params } = buildListSessionSql(source, limit) - const statement = db.prepare(sql) - const rows = statement.all(...params) as Record[] + const clauses = ["s.parent_session_id IS NULL", "s.source != 'tool'"] + const params: any[] = [] + if (source) { + clauses.push('s.source = ?') + params.push(source) + } + params.push(Math.max(limit * 4, limit)) - return rows.map(mapRow) + const rawRows = db.prepare(` + SELECT + ${SESSION_SELECT}, + s.parent_session_id AS parent_session_id + FROM sessions s + WHERE ${clauses.join(' AND ')} + ORDER BY s.started_at DESC + LIMIT ? + `).all(...params) as Record[] | undefined + const roots = (Array.isArray(rawRows) ? rawRows : []).map(mapInternalSessionRow) + + return roots + .map(root => projectSessionSummary(root, collectSessionChainFromDb(db, root))) + .sort((a, b) => Number(b.last_active || b.started_at || 0) - Number(a.last_active || a.started_at || 0)) + .slice(0, limit) } finally { db.close() } @@ -571,15 +660,24 @@ export async function searchSessionSummaries( const prefixQuery = toPrefixQuery(normalized) const titlePattern = buildLikePattern(normalizeTitleLikeQuery(trimmed).toLowerCase()) const useLiteralContentSearch = containsCjk(trimmed) || shouldUseLiteralContentSearch(trimmed) + const candidateLimit = searchCandidateLimit(limit) let titleRows: Record[] = [] try { - const titleBase = buildBaseSessionSql(source) - const contentBase = buildBaseSessionSql(source) + const sourceClause = source ? 'AND s.source = ?' : '' + const sourceParams = source ? [source] : [] + const allSessionsBaseSql = ` + SELECT + ${SESSION_SELECT}, + s.parent_session_id AS parent_session_id + FROM sessions s + WHERE s.source != 'tool' + ${sourceClause} + ` const titleSql = ` WITH base AS ( - ${titleBase.sql} + ${allSessionsBaseSql} ) SELECT base.*, @@ -596,11 +694,11 @@ export async function searchSessionSummaries( ` const titleStatement = db.prepare(titleSql) - titleRows = titleStatement.all(...titleBase.params, titlePattern, limit) as Record[] + titleRows = titleStatement.all(...sourceParams, titlePattern, candidateLimit) as Record[] const contentSql = ` WITH base AS ( - ${contentBase.sql} + ${allSessionsBaseSql} ) SELECT base.*, @@ -616,19 +714,19 @@ export async function searchSessionSummaries( ` const contentRows = useLiteralContentSearch - ? runLiteralContentSearch(db, source, trimmed, limit) + ? runLiteralContentSearch(db, source, trimmed, candidateLimit) : prefixQuery - ? (db.prepare(contentSql).all(...contentBase.params, prefixQuery, limit * 4) as Record[]) + ? (db.prepare(contentSql).all(...sourceParams, prefixQuery, candidateLimit) as Record[]) : [] const merged = new Map() for (const row of titleRows) { - const mapped = mapSearchRow(row) - merged.set(mapped.id, mapped) + const mapped = projectSearchRowFromDb(db, row, source) + if (mapped) merged.set(mapped.id, mapped) } for (const row of contentRows) { - const mapped = mapSearchRow(row) - if (!merged.has(mapped.id)) { + const mapped = projectSearchRowFromDb(db, row, source) + if (mapped && !merged.has(mapped.id)) { merged.set(mapped.id, mapped) } } @@ -642,15 +740,15 @@ export async function searchSessionSummaries( } catch (err) { const message = err instanceof Error ? err.message : String(err) if (containsCjk(normalized)) { - const likeRows = runLiteralContentSearch(db, source, trimmed, limit) + const likeRows = runLiteralContentSearch(db, source, trimmed, candidateLimit) const merged = new Map() for (const row of titleRows) { - const mapped = mapSearchRow(row) - merged.set(mapped.id, mapped) + const mapped = projectSearchRowFromDb(db, row, source) + if (mapped) merged.set(mapped.id, mapped) } for (const row of likeRows) { - const mapped = mapSearchRow(row) - if (!merged.has(mapped.id)) { + const mapped = projectSearchRowFromDb(db, row, source) + if (mapped && !merged.has(mapped.id)) { merged.set(mapped.id, mapped) } } diff --git a/packages/server/src/services/hermes/model-context.ts b/packages/server/src/services/hermes/model-context.ts index afa493b..252322c 100644 --- a/packages/server/src/services/hermes/model-context.ts +++ b/packages/server/src/services/hermes/model-context.ts @@ -15,6 +15,7 @@ interface ModelLimit { interface ModelEntry { id?: string + name?: string limit?: ModelLimit } @@ -22,6 +23,18 @@ interface ProviderEntry { models?: Record } +const MODEL_CACHE_PROVIDER_ALIASES: Record = { + gemini: ['google'], + moonshot: ['moonshotai'], + kilocode: ['kilo'], + 'ai-gateway': ['vercel'], + 'opencode-zen': ['opencode'], + 'opencode-go': ['opencode'], + 'glm-coding-plan': ['zai-coding-plan'], + 'kimi-coding': ['kimi-for-coding'], + 'kimi-coding-cn': ['kimi-for-coding'], +} + // --- Config YAML helpers (js-yaml) --- function loadConfig(profileDir: string): any | null { @@ -125,29 +138,79 @@ function lookupCustomProviderContextLength(config: any, modelName: string, provi // --- Context lookup --- -const CACHE_PROVIDER_ALIASES: Record = { - gemini: ['google'], - moonshot: ['moonshotai'], - kilocode: ['kilo'], - 'ai-gateway': ['vercel'], - 'opencode-zen': ['opencode'], - 'opencode-go': ['opencode'], - 'glm-coding-plan': ['zai-coding-plan'], - 'kimi-coding': ['kimi-for-coding'], - 'kimi-coding-cn': ['kimi-for-coding'], +function getCachedContext(entry: ModelEntry | undefined): number | null { + const context = entry?.limit?.context + return typeof context === 'number' && Number.isFinite(context) && context > 0 ? context : null } -function getContextFromProvider(prov: ProviderEntry | undefined, modelName: string): number | null { - const models = prov?.models || {} +function normalizeProviderKey(provider: string): string { + return provider.trim().toLowerCase() +} + +function getProviderCandidates(provider: string): string[] { + const normalized = normalizeProviderKey(provider) + return [normalized, ...(MODEL_CACHE_PROVIDER_ALIASES[normalized] || [])] +} + +function getProviderEntry(data: Record, provider: string): ProviderEntry | null { + const candidates = getProviderCandidates(provider) + + for (const candidate of candidates) { + const exact = data[candidate] + if (exact) return exact + } + + const entries = Object.entries(data) + for (const candidate of candidates) { + const match = entries.find(([name]) => name.toLowerCase() === candidate) + if (match) return match[1] + } + + return null +} + +function findModelEntry(models: Record, modelName: string): ModelEntry | undefined { const exact = models[modelName] - if (exact?.limit?.context) return exact.limit.context + if (exact) return exact const lower = modelName.toLowerCase() for (const [name, entry] of Object.entries(models)) { - if (name.toLowerCase() === lower && entry?.limit?.context) { - return entry.limit.context + if (name.toLowerCase() === lower) return entry + if (entry.id?.toLowerCase() === lower) return entry + if (entry.name?.toLowerCase() === lower) return entry + } + + const suffix = `/${lower}` + for (const [name, entry] of Object.entries(models)) { + if (name.toLowerCase().endsWith(suffix)) return entry + if (entry.id?.toLowerCase().endsWith(suffix)) return entry + } + + return undefined +} + +function lookupContextInProvider(provider: ProviderEntry | null, modelName: string): number | null { + const models = provider?.models || {} + return getCachedContext(findModelEntry(models, modelName)) +} + +function lookupContextGloballyByModelName(data: Record, modelName: string): number | null { + for (const prov of Object.values(data)) { + const context = getCachedContext(prov.models?.[modelName]) + if (context) return context + } + + const lower = modelName.toLowerCase() + for (const prov of Object.values(data)) { + const models = prov.models || {} + for (const [name, entry] of Object.entries(models)) { + if (name.toLowerCase() === lower) { + const context = getCachedContext(entry) + if (context) return context + } } } + return null } @@ -156,31 +219,11 @@ function lookupContextFromCache(modelName: string, provider: string | null): num if (!data) return null if (provider) { - const providers = [provider, ...(CACHE_PROVIDER_ALIASES[provider] || [])] - for (const providerName of providers) { - const ctx = getContextFromProvider(data[providerName], modelName) - if (ctx) return ctx - } - return null + return lookupContextInProvider(getProviderEntry(data, provider), modelName) } - // Legacy providerless lookup: exact model-name match across all providers. - for (const prov of Object.values(data)) { - const entry = prov.models?.[modelName] - if (entry?.limit?.context) return entry.limit.context - } - - // Legacy providerless case-insensitive fallback across all providers. - const lower = modelName.toLowerCase() - for (const prov of Object.values(data)) { - const models = prov.models || {} - for (const [name, entry] of Object.entries(models)) { - if (name.toLowerCase() === lower && entry?.limit?.context) { - return entry.limit.context - } - } - } - return null + // Legacy configs may omit model.provider; preserve the old global exact/CI lookup semantics. + return lookupContextGloballyByModelName(data, modelName) } /** @@ -188,7 +231,7 @@ function lookupContextFromCache(modelName: string, provider: string | null): num * Resolution order: * 1. config.yaml model.context_length (highest priority, user override) * 2. custom_providers models..context_length - * 3. models_dev_cache.json (built-in model database) + * 3. models_dev_cache.json, scoped to model.provider when configured * 4. DEFAULT_CONTEXT_LENGTH (200K hardcoded fallback) */ export function getModelContextLength(profile?: string): number { diff --git a/tests/server/model-context.test.ts b/tests/server/model-context.test.ts index d0b763e..09506fa 100644 --- a/tests/server/model-context.test.ts +++ b/tests/server/model-context.test.ts @@ -1,87 +1,160 @@ -import { mkdirSync, writeFileSync } from 'fs' +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' +import { mkdirSync, mkdtempSync, rmSync, writeFileSync } from 'fs' import { join } from 'path' import { tmpdir } from 'os' -import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' -function makeHome() { - const root = join(tmpdir(), `wui-model-context-${Date.now()}-${Math.random().toString(36).slice(2)}`) - const hermes = join(root, '.hermes') - mkdirSync(hermes, { recursive: true }) - return { root, hermes } +let homeDir = '' + +function hermesPath(...parts: string[]) { + return join(homeDir, '.hermes', ...parts) } -function writeConfig(hermes: string, yaml: string) { - writeFileSync(join(hermes, 'config.yaml'), yaml) +function writeConfig(content: string) { + mkdirSync(hermesPath(), { recursive: true }) + writeFileSync(hermesPath('config.yaml'), content) } -function writeModelsCache(hermes: string) { - writeFileSync(join(hermes, 'models_dev_cache.json'), JSON.stringify({ - openai: { - models: { - 'gpt-5.5': { limit: { context: 1_050_000 } }, - 'gpt-5.4': { limit: { context: 1_050_000 } }, - }, - }, - google: { - models: { - 'gemini-3.1-pro-preview': { limit: { context: 1_000_000 } }, - }, - }, - })) +function writeModelsCache(data: Record) { + mkdirSync(hermesPath(), { recursive: true }) + writeFileSync(hermesPath('models_dev_cache.json'), JSON.stringify(data)) } -async function importContextService(home: string) { +async function loadModelContext() { vi.resetModules() - vi.stubEnv('HOME', home) - return await import('../../packages/server/src/services/hermes/model-context') + vi.doMock('os', async () => ({ + ...(await vi.importActual('os')), + homedir: () => homeDir, + })) + return import('../../packages/server/src/services/hermes/model-context') } -describe('model context length resolution', () => { +describe('getModelContextLength', () => { beforeEach(() => { - vi.unstubAllEnvs() + homeDir = mkdtempSync(join(tmpdir(), 'hwui-model-context-')) }) afterEach(() => { - vi.unstubAllEnvs() - vi.resetModules() + vi.doUnmock('os') + if (homeDir) rmSync(homeDir, { recursive: true, force: true }) + homeDir = '' }) - it('does not borrow OpenAI context metadata for an openai-codex model with the same name', async () => { - const { root, hermes } = makeHome() - writeConfig(hermes, 'model:\n provider: openai-codex\n default: gpt-5.5\n') - writeModelsCache(hermes) + it('does not borrow a same-named model context from another provider when the configured provider is uncached', async () => { + writeConfig(`model:\n default: gpt-5.5\n provider: openai-codex\n`) + writeModelsCache({ + openai: { + models: { + 'gpt-5.5': { limit: { context: 1_050_000 } }, + }, + }, + }) - const { getModelContextLength } = await importContextService(root) + const { getModelContextLength } = await loadModelContext() expect(getModelContextLength()).toBe(200_000) }) - it('still honors explicit model.context_length before provider-aware cache lookup', async () => { - const { root, hermes } = makeHome() - writeConfig(hermes, 'model:\n provider: openai-codex\n default: gpt-5.5\n context_length: 272000\n') - writeModelsCache(hermes) + it('does not scan other providers when the configured provider exists without that model', async () => { + writeConfig(`model:\n default: gpt-5.5\n provider: openai-codex\n`) + writeModelsCache({ + 'openai-codex': { + models: { + 'gpt-5.4': { limit: { context: 200_000 } }, + }, + }, + openai: { + models: { + 'gpt-5.5': { limit: { context: 1_050_000 } }, + }, + }, + }) - const { getModelContextLength } = await importContextService(root) + const { getModelContextLength } = await loadModelContext() - expect(getModelContextLength()).toBe(272_000) + expect(getModelContextLength()).toBe(200_000) }) - it('preserves providerless legacy lookup by model name', async () => { - const { root, hermes } = makeHome() - writeConfig(hermes, 'model:\n default: gpt-5.5\n') - writeModelsCache(hermes) + it('uses the configured provider cache entry when the provider matches', async () => { + writeConfig(`model:\n default: gpt-5.5\n provider: openai\n`) + writeModelsCache({ + openai: { + models: { + 'gpt-5.5': { limit: { context: 1_050_000 } }, + }, + }, + }) - const { getModelContextLength } = await importContextService(root) + const { getModelContextLength } = await loadModelContext() expect(getModelContextLength()).toBe(1_050_000) }) - it('uses intentional cache provider aliases without conflating openai-codex with openai', async () => { - const { root, hermes } = makeHome() - writeConfig(hermes, 'model:\n provider: gemini\n default: gemini-3.1-pro-preview\n') - writeModelsCache(hermes) + it('keeps legacy model-name cache lookup when no provider is configured', async () => { + writeConfig(`model:\n default: gpt-5.5\n`) + writeModelsCache({ + openai: { + models: { + 'gpt-5.5': { limit: { context: 1_050_000 } }, + }, + }, + }) - const { getModelContextLength } = await importContextService(root) + const { getModelContextLength } = await loadModelContext() + + expect(getModelContextLength()).toBe(1_050_000) + }) + + it('keeps providerless legacy lookup on global exact matches before prefixed suffix matches', async () => { + writeConfig(`model:\n default: gpt-5\n`) + writeModelsCache({ + vercel: { + models: { + 'openai/gpt-5': { limit: { context: 1_000_000 } }, + }, + }, + openai: { + models: { + 'gpt-5': { limit: { context: 400_000 } }, + }, + }, + }) + + const { getModelContextLength } = await loadModelContext() + + expect(getModelContextLength()).toBe(400_000) + }) + + it('maps WUI provider keys to model-cache provider keys before looking up limits', async () => { + writeConfig(`model:\n default: gemini-3.1-pro-preview\n provider: gemini\n`) + writeModelsCache({ + google: { + models: { + 'gemini-3.1-pro-preview': { limit: { context: 1_000_000 } }, + }, + }, + }) + + const { getModelContextLength } = await loadModelContext() + + expect(getModelContextLength()).toBe(1_000_000) + }) + + it('uses gateway provider aliases with prefixed model names inside the aliased provider only', async () => { + writeConfig(`model:\n default: openai/gpt-5\n provider: ai-gateway\n`) + writeModelsCache({ + vercel: { + models: { + 'openai/gpt-5': { limit: { context: 1_000_000 } }, + }, + }, + openai: { + models: { + 'gpt-5': { limit: { context: 400_000 } }, + }, + }, + }) + + const { getModelContextLength } = await loadModelContext() expect(getModelContextLength()).toBe(1_000_000) }) diff --git a/tests/server/sessions-db-lineage.test.ts b/tests/server/sessions-db-lineage.test.ts new file mode 100644 index 0000000..b27a40a --- /dev/null +++ b/tests/server/sessions-db-lineage.test.ts @@ -0,0 +1,303 @@ +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' +import { mkdtempSync, rmSync } from 'fs' +import { tmpdir } from 'os' +import { join } from 'path' +import { DatabaseSync } from 'node:sqlite' + +const profileDir = vi.hoisted(() => ({ value: '' })) + +vi.mock('../../packages/server/src/services/hermes/hermes-profile', () => ({ + getActiveProfileDir: () => profileDir.value, +})) + +function createStateDb(path: string) { + const db = new DatabaseSync(path) + db.exec(` + CREATE TABLE sessions ( + id TEXT PRIMARY KEY, + source TEXT NOT NULL, + user_id TEXT, + model TEXT, + title TEXT, + started_at REAL, + ended_at REAL, + end_reason TEXT, + message_count INTEGER, + tool_call_count INTEGER, + input_tokens INTEGER, + output_tokens INTEGER, + cache_read_tokens INTEGER, + cache_write_tokens INTEGER, + reasoning_tokens INTEGER, + billing_provider TEXT, + estimated_cost_usd REAL, + actual_cost_usd REAL, + cost_status TEXT, + parent_session_id TEXT + ); + + CREATE TABLE messages ( + id INTEGER PRIMARY KEY, + session_id TEXT NOT NULL, + role TEXT NOT NULL, + content TEXT, + tool_call_id TEXT, + tool_calls TEXT, + tool_name TEXT, + timestamp REAL, + token_count INTEGER, + finish_reason TEXT, + reasoning TEXT, + reasoning_details TEXT, + codex_reasoning_items TEXT, + reasoning_content TEXT + ); + + CREATE VIRTUAL TABLE messages_fts USING fts5(content); + `) + return db +} + +function insertSession( + db: DatabaseSync, + row: { + id: string + source?: string + parent_session_id?: string | null + title?: string + started_at: number + ended_at?: number | null + end_reason?: string | null + message_count?: number + model?: string + }, +) { + db.prepare(` + INSERT INTO sessions ( + id, source, user_id, model, title, started_at, ended_at, end_reason, + message_count, tool_call_count, input_tokens, output_tokens, + cache_read_tokens, cache_write_tokens, reasoning_tokens, billing_provider, + estimated_cost_usd, actual_cost_usd, cost_status, parent_session_id + ) VALUES (?, ?, '', ?, ?, ?, ?, ?, ?, 0, 0, 0, 0, 0, 0, '', 0, NULL, '', ?) + `).run( + row.id, + row.source || 'api_server', + row.model || 'gpt-5.5', + row.title || '', + row.started_at, + row.ended_at ?? null, + row.end_reason ?? null, + row.message_count ?? 1, + row.parent_session_id ?? null, + ) +} + +function insertMessage( + db: DatabaseSync, + row: { + id: number + session_id: string + role?: string + content: string + timestamp: number + }, +) { + db.prepare(` + INSERT INTO messages ( + id, session_id, role, content, tool_call_id, tool_calls, tool_name, + timestamp, token_count, finish_reason, reasoning, reasoning_details, + codex_reasoning_items, reasoning_content + ) VALUES (?, ?, ?, ?, NULL, NULL, NULL, ?, NULL, NULL, NULL, NULL, NULL, NULL) + `).run(row.id, row.session_id, row.role || 'user', row.content, row.timestamp) + db.prepare('INSERT INTO messages_fts(rowid, content) VALUES (?, ?)').run(row.id, row.content) +} + +function seedCompressionChain(db: DatabaseSync) { + insertSession(db, { + id: 'root', + source: 'api_server', + title: 'Mermaid fix', + started_at: 100, + ended_at: 200, + end_reason: 'compression', + message_count: 2, + }) + insertSession(db, { + id: 'middle', + source: 'cli', + parent_session_id: 'root', + title: 'Mermaid fix #2', + started_at: 201, + ended_at: 300, + end_reason: 'compression', + message_count: 3, + }) + insertSession(db, { + id: 'tip', + source: 'cli', + parent_session_id: 'middle', + title: 'Mermaid fix #3', + started_at: 301, + ended_at: null, + end_reason: null, + message_count: 4, + }) + + insertMessage(db, { id: 1, session_id: 'root', content: 'root turn', timestamp: 101 }) + insertMessage(db, { id: 2, session_id: 'middle', content: 'middle turn', timestamp: 202 }) + insertMessage(db, { id: 3, session_id: 'tip', content: 'tip lineageunique turn', timestamp: 302 }) +} + +describe('session DB compression lineage', () => { + let tempDir = '' + let db: DatabaseSync | null = null + + beforeEach(() => { + vi.resetModules() + tempDir = mkdtempSync(join(tmpdir(), 'wui-session-lineage-')) + profileDir.value = tempDir + db = createStateDb(join(tempDir, 'state.db')) + }) + + afterEach(() => { + db?.close() + db = null + if (tempDir) rmSync(tempDir, { recursive: true, force: true }) + }) + + it('projects compressed root summaries to the latest continuation tip', async () => { + seedCompressionChain(db!) + + const mod = await import('../../packages/server/src/db/hermes/sessions-db') + const rows = await mod.listSessionSummaries(undefined, 20) + + expect(rows).toHaveLength(1) + expect(rows[0]).toMatchObject({ + id: 'tip', + title: 'Mermaid fix #3', + message_count: 4, + end_reason: null, + preview: 'tip lineageunique turn', + started_at: 100, + }) + }) + + it('returns the projected logical session when search matches continuation content', async () => { + seedCompressionChain(db!) + + const mod = await import('../../packages/server/src/db/hermes/sessions-db') + const rows = await mod.searchSessionSummaries('lineageunique', undefined, 20) + + expect(rows).toHaveLength(1) + expect(rows[0]).toMatchObject({ + id: 'tip', + title: 'Mermaid fix #3', + matched_message_id: 3, + }) + expect(rows[0].snippet).toContain('lineageunique') + }) + + it('hydrates the full compression chain when detail is requested by projected tip id', async () => { + seedCompressionChain(db!) + + const mod = await import('../../packages/server/src/db/hermes/sessions-db') + const detail = await mod.getSessionDetailFromDb('tip') + + expect(detail).toMatchObject({ + id: 'tip', + title: 'Mermaid fix #3', + message_count: 9, + thread_session_count: 3, + }) + expect(detail?.messages.map(message => message.session_id)).toEqual(['root', 'middle', 'tip']) + }) + + it('follows only the latest compression continuation child when a parent has multiple children', async () => { + insertSession(db!, { + id: 'root', + started_at: 100, + ended_at: 200, + end_reason: 'compression', + message_count: 1, + }) + insertSession(db!, { + id: 'older-child', + parent_session_id: 'root', + title: 'Older branch', + started_at: 201, + ended_at: null, + end_reason: null, + message_count: 1, + }) + insertSession(db!, { + id: 'latest-child', + parent_session_id: 'root', + title: 'Latest branch', + started_at: 205, + ended_at: null, + end_reason: null, + message_count: 1, + }) + insertMessage(db!, { id: 11, session_id: 'root', content: 'root', timestamp: 101 }) + insertMessage(db!, { id: 12, session_id: 'older-child', content: 'older should not merge', timestamp: 202 }) + insertMessage(db!, { id: 13, session_id: 'latest-child', content: 'latest should merge', timestamp: 206 }) + + const mod = await import('../../packages/server/src/db/hermes/sessions-db') + const detail = await mod.getSessionDetailFromDb('root') + + expect(detail).toMatchObject({ + id: 'root', + title: 'Latest branch', + message_count: 2, + thread_session_count: 2, + }) + expect(detail?.messages.map(message => message.session_id)).toEqual(['root', 'latest-child']) + + const olderDetail = await mod.getSessionDetailFromDb('older-child') + expect(olderDetail).toMatchObject({ + id: 'older-child', + title: 'Older branch', + message_count: 2, + thread_session_count: 2, + }) + expect(olderDetail?.messages.map(message => message.session_id)).toEqual(['root', 'older-child']) + + const olderSearch = await mod.searchSessionSummaries('older should', undefined, 20) + expect(olderSearch[0]).toMatchObject({ + id: 'older-child', + title: 'Older branch', + matched_message_id: 12, + }) + }) + + it('applies source filters before search candidate limiting', async () => { + for (let index = 0; index < 105; index += 1) { + insertSession(db!, { + id: `cli-${index}`, + source: 'cli', + title: `needle cli ${index}`, + started_at: 1000 + index, + ended_at: null, + end_reason: null, + }) + } + insertSession(db!, { + id: 'telegram-match', + source: 'telegram', + title: 'needle telegram target', + started_at: 10, + ended_at: null, + end_reason: null, + }) + + const mod = await import('../../packages/server/src/db/hermes/sessions-db') + const rows = await mod.searchSessionSummaries('needle', 'telegram', 1) + + expect(rows).toHaveLength(1) + expect(rows[0]).toMatchObject({ + id: 'telegram-match', + source: 'telegram', + title: 'needle telegram target', + }) + }) +}) diff --git a/tests/server/sessions-db.test.ts b/tests/server/sessions-db.test.ts index d53531f..3ffb8eb 100644 --- a/tests/server/sessions-db.test.ts +++ b/tests/server/sessions-db.test.ts @@ -67,8 +67,8 @@ describe('session DB summaries', () => { const rows = await mod.listSessionSummaries(undefined, 50) expect(databaseSyncMock).toHaveBeenCalledWith('/tmp/hermes-profile/state.db', { open: true, readOnly: true }) - expect(prepareMock).toHaveBeenCalledWith(expect.stringContaining("AND s.source != 'tool'")) - expect(allMock).toHaveBeenCalledWith(50) + expect(prepareMock).toHaveBeenCalledWith(expect.stringContaining("s.source != 'tool'")) + expect(allMock).toHaveBeenCalledWith(200) expect(closeMock).toHaveBeenCalled() expect(rows).toEqual([ { @@ -127,8 +127,8 @@ describe('session DB summaries', () => { const mod = await import('../../packages/server/src/db/hermes/sessions-db') const rows = await mod.listSessionSummaries('telegram', 2) - expect(prepareMock).toHaveBeenCalledWith(expect.stringContaining('AND s.source = ?')) - expect(allMock).toHaveBeenCalledWith('telegram', 2) + expect(prepareMock).toHaveBeenCalledWith(expect.stringContaining("s.source != 'tool'")) + expect(allMock).toHaveBeenCalledWith('telegram', 8) expect(rows[0].last_active).toBe(1710000100) expect(rows[0].source).toBe('telegram') expect(rows[0].title).toBe('preview text') @@ -375,8 +375,8 @@ describe('session DB summaries', () => { const mod = await import('../../packages/server/src/db/hermes/sessions-db') const rows = await mod.searchSessionSummaries('node.js*', undefined, 10) - expect(titleAllMock).toHaveBeenCalledWith('%node.js%', 10) - expect(contentAllMock).toHaveBeenCalledWith('"node.js"*', 40) + expect(titleAllMock).toHaveBeenCalledWith('%node.js%', 200) + expect(contentAllMock).toHaveBeenCalledWith('"node.js"*', 200) expect(likeAllMock).not.toHaveBeenCalled() expect(rows).toHaveLength(2) expect(rows[0].id).toBe('node-wildcard-title-1') @@ -444,8 +444,8 @@ describe('session DB summaries', () => { const mod = await import('../../packages/server/src/db/hermes/sessions-db') const rows = await mod.searchSessionSummaries('"node.js"*', undefined, 10) - expect(titleAllMock).toHaveBeenCalledWith('%node.js%', 10) - expect(contentAllMock).toHaveBeenCalledWith('"node.js"*', 40) + expect(titleAllMock).toHaveBeenCalledWith('%node.js%', 200) + expect(contentAllMock).toHaveBeenCalledWith('"node.js"*', 200) expect(likeAllMock).not.toHaveBeenCalled() expect(rows).toHaveLength(2) expect(rows[0].id).toBe('node-quoted-title-1') @@ -486,7 +486,7 @@ describe('session DB summaries', () => { const mod = await import('../../packages/server/src/db/hermes/sessions-db') const rows = await mod.searchSessionSummaries('naïve.js', undefined, 10) - expect(contentAllMock).toHaveBeenCalledWith('"naïve.js"', 40) + expect(contentAllMock).toHaveBeenCalledWith('"naïve.js"', 200) expect(likeAllMock).not.toHaveBeenCalled() expect(rows).toHaveLength(1) expect(rows[0].id).toBe('unicode-dot-1') @@ -526,7 +526,7 @@ describe('session DB summaries', () => { const mod = await import('../../packages/server/src/db/hermes/sessions-db') const rows = await mod.searchSessionSummaries('100%', undefined, 10) - expect(titleAllMock).toHaveBeenCalledWith('%100\\%%', 10) + expect(titleAllMock).toHaveBeenCalledWith('%100\\%%', 200) expect(rows).toHaveLength(1) expect(rows[0].id).toBe('percent-1') }) @@ -567,7 +567,7 @@ describe('session DB summaries', () => { const rows = await mod.searchSessionSummaries('记忆断裂', undefined, 10) expect(contentAllMock).not.toHaveBeenCalled() - expect(likeAllMock).toHaveBeenCalledWith('记忆断裂', '%记忆断裂%', 40) + expect(likeAllMock).toHaveBeenCalledWith('记忆断裂', '%记忆断裂%', 200) expect(rows).toHaveLength(1) expect(rows[0].id).toBe('cjk-literal-1') }) @@ -636,7 +636,7 @@ describe('session DB summaries', () => { const mod = await import('../../packages/server/src/db/hermes/sessions-db') const rows = await mod.searchSessionSummaries('记忆断裂', undefined, 10) - expect(likeAllMock).toHaveBeenCalledWith('记忆断裂', '%记忆断裂%', 40) + expect(likeAllMock).toHaveBeenCalledWith('记忆断裂', '%记忆断裂%', 200) expect(rows).toHaveLength(2) expect(rows[0].id).toBe('cjk-1') expect(rows[1].id).toBe('cjk-title-1')