Add session-level bridge model settings (#811)
This commit is contained in:
@@ -5,10 +5,11 @@
|
||||
|
||||
import type { Server, Socket } from 'socket.io'
|
||||
import { getSystemPrompt } from '../../../lib/llm-prompt'
|
||||
import { getSession, createSession, addMessage, updateSessionStats } from '../../../db/hermes/session-store'
|
||||
import { getSession, createSession, addMessage, updateSession, updateSessionStats } from '../../../db/hermes/session-store'
|
||||
import { updateUsage } from '../../../db/hermes/usage-store'
|
||||
import { logger, bridgeLogger } from '../../logger'
|
||||
import { AgentBridgeClient, type AgentBridgeMessage, type AgentBridgeOutput } from '../agent-bridge'
|
||||
import { readConfigYaml } from '../../config-helpers'
|
||||
import { contentBlocksToString, convertContentBlocksForAgent, extractTextForPreview, isContentBlockArray } from './content-blocks'
|
||||
import { buildCompressedHistory } from './compression'
|
||||
import { pushState, replaceState } from './compression'
|
||||
@@ -29,10 +30,34 @@ import type { ChatMessage } from '../../../lib/context-compressor'
|
||||
|
||||
const BRIDGE_USAGE_FLUSH_DELAY_MS = 200
|
||||
|
||||
type RunModelGroup = { provider: string; models: string[] }
|
||||
|
||||
async function resolveDefaultModelConfig(): Promise<{ model: string; provider: string }> {
|
||||
try {
|
||||
const config = await readConfigYaml()
|
||||
const modelConfig = config?.model
|
||||
const model = typeof modelConfig === 'string'
|
||||
? modelConfig.trim()
|
||||
: String(modelConfig?.default || '').trim()
|
||||
const provider = typeof modelConfig === 'object'
|
||||
? String(modelConfig?.provider || '').trim()
|
||||
: ''
|
||||
return { model, provider }
|
||||
} catch {
|
||||
return { model: '', provider: '' }
|
||||
}
|
||||
}
|
||||
|
||||
function hasModelInGroups(groups: RunModelGroup[] | undefined, provider: string, model: string): boolean {
|
||||
if (!groups?.length || !provider || !model) return false
|
||||
const group = groups.find(item => item.provider === provider)
|
||||
return Array.isArray(group?.models) && group.models.includes(model)
|
||||
}
|
||||
|
||||
export async function handleBridgeRun(
|
||||
nsp: ReturnType<Server['of']>,
|
||||
socket: Socket,
|
||||
data: { input: string | ContentBlock[]; session_id?: string; model?: string; instructions?: string; source?: string },
|
||||
data: { input: string | ContentBlock[]; session_id?: string; model?: string; provider?: string; model_groups?: RunModelGroup[]; instructions?: string; source?: string },
|
||||
profile: string,
|
||||
sessionMap: Map<string, SessionState>,
|
||||
gatewayManager: any,
|
||||
@@ -41,7 +66,7 @@ export async function handleBridgeRun(
|
||||
loadSessionStateFromDbFn: (sid: string, sessionMap: Map<string, SessionState>) => Promise<SessionState>,
|
||||
dequeueNextQueuedRun: (socket: Socket, sessionId: string, fallbackProfile?: string) => void,
|
||||
) {
|
||||
const { input, session_id, model, instructions } = data
|
||||
const { input, session_id, instructions } = data
|
||||
if (!session_id) {
|
||||
socket.emit('run.failed', { event: 'run.failed', error: 'session_id is required for cli source' })
|
||||
return
|
||||
@@ -51,6 +76,22 @@ export async function handleBridgeRun(
|
||||
? `${getSystemPrompt()}\n${instructions}`
|
||||
: getSystemPrompt()
|
||||
const sessionRow = getSession(session_id)
|
||||
const sessionModel = sessionRow?.model || ''
|
||||
const sessionProvider = sessionRow?.provider || ''
|
||||
const hasGroups = Array.isArray(data.model_groups) && data.model_groups.length > 0
|
||||
const sessionModelAvailable = hasGroups && hasModelInGroups(data.model_groups, sessionProvider, sessionModel)
|
||||
const shouldUseDefault = !sessionModel || !sessionProvider || !sessionModelAvailable
|
||||
const defaultModelConfig = shouldUseDefault
|
||||
? await resolveDefaultModelConfig()
|
||||
: { model: '', provider: '' }
|
||||
const resolvedModel = shouldUseDefault ? defaultModelConfig.model : sessionModel
|
||||
const resolvedProvider = shouldUseDefault ? defaultModelConfig.provider : sessionProvider
|
||||
if (sessionRow) {
|
||||
const updates: { model?: string; provider?: string } = {}
|
||||
if (resolvedModel && sessionRow.model !== resolvedModel) updates.model = resolvedModel
|
||||
if (resolvedProvider && sessionRow.provider !== resolvedProvider) updates.provider = resolvedProvider
|
||||
if (Object.keys(updates).length > 0) updateSession(session_id, updates)
|
||||
}
|
||||
if (sessionRow?.workspace) {
|
||||
const workspaceCtx = `[Current working directory: ${sessionRow.workspace}]`
|
||||
fullInstructions = `\n${workspaceCtx}\n${fullInstructions}`
|
||||
@@ -93,7 +134,7 @@ export async function handleBridgeRun(
|
||||
if (!getSession(session_id)) {
|
||||
const previewText = extractTextForPreview(input)
|
||||
const preview = previewText.replace(/[\r\n]/g, ' ').substring(0, 100)
|
||||
createSession({ id: session_id, profile, source: 'cli', model, title: preview })
|
||||
createSession({ id: session_id, profile, source: 'cli', model: resolvedModel, provider: resolvedProvider, title: preview })
|
||||
}
|
||||
addMessage({
|
||||
session_id,
|
||||
@@ -142,7 +183,11 @@ export async function handleBridgeRun(
|
||||
bridgeHistory,
|
||||
fullInstructions,
|
||||
profile,
|
||||
bridgeStorageInput !== undefined ? { storage_message: bridgeStorageInput } : {},
|
||||
{
|
||||
...(bridgeStorageInput !== undefined ? { storage_message: bridgeStorageInput } : {}),
|
||||
...(resolvedModel ? { model: resolvedModel } : {}),
|
||||
...(resolvedProvider ? { provider: resolvedProvider } : {}),
|
||||
},
|
||||
)
|
||||
state.runId = started.run_id
|
||||
bridgeLogger.info({
|
||||
|
||||
Reference in New Issue
Block a user