Add session-level bridge model settings (#811)

This commit is contained in:
ekko
2026-05-17 12:20:53 +08:00
committed by GitHub
parent fa035f348e
commit 5e8f8bd4a1
35 changed files with 697 additions and 60 deletions
@@ -5,10 +5,11 @@
import type { Server, Socket } from 'socket.io'
import { getSystemPrompt } from '../../../lib/llm-prompt'
import { getSession, createSession, addMessage, updateSessionStats } from '../../../db/hermes/session-store'
import { getSession, createSession, addMessage, updateSession, updateSessionStats } from '../../../db/hermes/session-store'
import { updateUsage } from '../../../db/hermes/usage-store'
import { logger, bridgeLogger } from '../../logger'
import { AgentBridgeClient, type AgentBridgeMessage, type AgentBridgeOutput } from '../agent-bridge'
import { readConfigYaml } from '../../config-helpers'
import { contentBlocksToString, convertContentBlocksForAgent, extractTextForPreview, isContentBlockArray } from './content-blocks'
import { buildCompressedHistory } from './compression'
import { pushState, replaceState } from './compression'
@@ -29,10 +30,34 @@ import type { ChatMessage } from '../../../lib/context-compressor'
const BRIDGE_USAGE_FLUSH_DELAY_MS = 200
type RunModelGroup = { provider: string; models: string[] }
async function resolveDefaultModelConfig(): Promise<{ model: string; provider: string }> {
try {
const config = await readConfigYaml()
const modelConfig = config?.model
const model = typeof modelConfig === 'string'
? modelConfig.trim()
: String(modelConfig?.default || '').trim()
const provider = typeof modelConfig === 'object'
? String(modelConfig?.provider || '').trim()
: ''
return { model, provider }
} catch {
return { model: '', provider: '' }
}
}
function hasModelInGroups(groups: RunModelGroup[] | undefined, provider: string, model: string): boolean {
if (!groups?.length || !provider || !model) return false
const group = groups.find(item => item.provider === provider)
return Array.isArray(group?.models) && group.models.includes(model)
}
export async function handleBridgeRun(
nsp: ReturnType<Server['of']>,
socket: Socket,
data: { input: string | ContentBlock[]; session_id?: string; model?: string; instructions?: string; source?: string },
data: { input: string | ContentBlock[]; session_id?: string; model?: string; provider?: string; model_groups?: RunModelGroup[]; instructions?: string; source?: string },
profile: string,
sessionMap: Map<string, SessionState>,
gatewayManager: any,
@@ -41,7 +66,7 @@ export async function handleBridgeRun(
loadSessionStateFromDbFn: (sid: string, sessionMap: Map<string, SessionState>) => Promise<SessionState>,
dequeueNextQueuedRun: (socket: Socket, sessionId: string, fallbackProfile?: string) => void,
) {
const { input, session_id, model, instructions } = data
const { input, session_id, instructions } = data
if (!session_id) {
socket.emit('run.failed', { event: 'run.failed', error: 'session_id is required for cli source' })
return
@@ -51,6 +76,22 @@ export async function handleBridgeRun(
? `${getSystemPrompt()}\n${instructions}`
: getSystemPrompt()
const sessionRow = getSession(session_id)
const sessionModel = sessionRow?.model || ''
const sessionProvider = sessionRow?.provider || ''
const hasGroups = Array.isArray(data.model_groups) && data.model_groups.length > 0
const sessionModelAvailable = hasGroups && hasModelInGroups(data.model_groups, sessionProvider, sessionModel)
const shouldUseDefault = !sessionModel || !sessionProvider || !sessionModelAvailable
const defaultModelConfig = shouldUseDefault
? await resolveDefaultModelConfig()
: { model: '', provider: '' }
const resolvedModel = shouldUseDefault ? defaultModelConfig.model : sessionModel
const resolvedProvider = shouldUseDefault ? defaultModelConfig.provider : sessionProvider
if (sessionRow) {
const updates: { model?: string; provider?: string } = {}
if (resolvedModel && sessionRow.model !== resolvedModel) updates.model = resolvedModel
if (resolvedProvider && sessionRow.provider !== resolvedProvider) updates.provider = resolvedProvider
if (Object.keys(updates).length > 0) updateSession(session_id, updates)
}
if (sessionRow?.workspace) {
const workspaceCtx = `[Current working directory: ${sessionRow.workspace}]`
fullInstructions = `\n${workspaceCtx}\n${fullInstructions}`
@@ -93,7 +134,7 @@ export async function handleBridgeRun(
if (!getSession(session_id)) {
const previewText = extractTextForPreview(input)
const preview = previewText.replace(/[\r\n]/g, ' ').substring(0, 100)
createSession({ id: session_id, profile, source: 'cli', model, title: preview })
createSession({ id: session_id, profile, source: 'cli', model: resolvedModel, provider: resolvedProvider, title: preview })
}
addMessage({
session_id,
@@ -142,7 +183,11 @@ export async function handleBridgeRun(
bridgeHistory,
fullInstructions,
profile,
bridgeStorageInput !== undefined ? { storage_message: bridgeStorageInput } : {},
{
...(bridgeStorageInput !== undefined ? { storage_message: bridgeStorageInput } : {}),
...(resolvedModel ? { model: resolvedModel } : {}),
...(resolvedProvider ? { provider: resolvedProvider } : {}),
},
)
state.runId = started.run_id
bridgeLogger.info({