Add session-level bridge model settings (#811)

This commit is contained in:
ekko
2026-05-17 12:20:53 +08:00
committed by GitHub
parent fa035f348e
commit 5e8f8bd4a1
35 changed files with 697 additions and 60 deletions
@@ -28,6 +28,13 @@ export interface AgentBridgeRequestOptions {
timeoutMs?: number
}
/**
 * Optional per-call settings accepted by the bridge `chat` request.
 *
 * Every field is optional; an empty object is a valid value.
 */
export interface AgentBridgeChatOptions {
  /** Model to use for this chat; only forwarded in the request when non-empty. */
  model?: string
  /** Provider to use for this chat; only forwarded in the request when non-empty. */
  provider?: string
  /** When truthy, the request is sent with `force_compress: true`. */
  force_compress?: boolean
  /**
   * Alternate form of the message.
   * NOTE(review): presumably the version to persist instead of the sent
   * message — how it is forwarded is not visible here; confirm.
   */
  storage_message?: AgentBridgeMessage
}
export type AgentBridgeMessage =
| string
| Array<Record<string, unknown>>
@@ -306,7 +313,7 @@ export class AgentBridgeClient {
conversationHistory?: unknown[],
instructions?: string,
profile?: string,
options: { force_compress?: boolean; storage_message?: AgentBridgeMessage } = {},
options: AgentBridgeChatOptions = {},
): Promise<AgentBridgeChatStarted> {
return this.request<AgentBridgeChatStarted>({
action: 'chat',
@@ -316,6 +323,8 @@ export class AgentBridgeClient {
...(conversationHistory ? { conversation_history: conversationHistory } : {}),
...(instructions ? { instructions } : {}),
...(profile ? { profile } : {}),
...(options.model ? { model: options.model } : {}),
...(options.provider ? { provider: options.provider } : {}),
...(options.force_compress ? { force_compress: true } : {}),
})
}
@@ -491,12 +491,21 @@ class AgentPool:
self,
session_id: str,
profile: str | None = None,
model: str | None = None,
provider: str | None = None,
) -> AgentSession:
requested_model = str(model or "").strip()
requested_provider = str(provider or "").strip()
with self._lock:
existing = self._sessions.get(session_id)
if existing is not None:
# If the profile, model, or provider changed, destroy the old session and recreate it
if profile and existing.config.get("profile") != profile:
config_changed = bool(
(profile and existing.config.get("profile") != profile)
or (requested_model and existing.config.get("model") != requested_model)
or (requested_provider and existing.config.get("provider") != requested_provider)
)
if config_changed:
if not existing.running:
self._destroy_session(session_id)
else:
@@ -512,8 +521,8 @@ class AgentPool:
with _profile_env(profile):
cfg = _load_cfg()
resolved_model = _resolve_model(cfg)
runtime = _resolve_runtime(resolved_model)
resolved_model = requested_model or _resolve_model(cfg)
runtime = _resolve_runtime(resolved_model, requested_provider or None)
agent_cfg = cfg.get("agent") or {}
prompt = str(agent_cfg.get("system_prompt", "") or "").strip() or None
@@ -949,8 +958,10 @@ class AgentPool:
conversation_history: list[dict[str, Any]] | None = None,
profile: str | None = None,
force_compress: bool = False,
model: str | None = None,
provider: str | None = None,
) -> RunRecord:
session = self.get_or_create(session_id, profile=profile)
session = self.get_or_create(session_id, profile=profile, model=model, provider=provider)
with session.lock:
if session.running:
raise RuntimeError(f"session {session_id} is already running")
@@ -1265,6 +1276,8 @@ class BridgeServer:
instructions = req.get("instructions") or req.get("system_message")
conversation_history = req.get("conversation_history")
profile = req.get("profile")
model = req.get("model")
provider = req.get("provider")
record = self.pool.start_chat(
session_id,
message,
@@ -1273,6 +1286,8 @@ class BridgeServer:
conversation_history,
profile,
bool(req.get("force_compress")),
model,
provider,
)
if req.get("wait"):
timeout = float(req.get("timeout", 0) or 0)
@@ -75,14 +75,14 @@ export async function loadSessionStateFromDb(sid: string, _sessionMap: Map<strin
export async function handleApiRun(
nsp: ReturnType<Server['of']>,
socket: Socket,
data: { input: string | ContentBlock[]; session_id?: string; model?: string; instructions?: string; source?: string },
data: { input: string | ContentBlock[]; session_id?: string; model?: string; provider?: string; instructions?: string; source?: string },
profile: string,
sessionMap: Map<string, SessionState>,
gatewayManager: any,
skipUserMessage = false,
dequeueNextQueuedRun: (socket: Socket, sessionId: string, fallbackProfile?: string) => void,
) {
const { input, session_id, model, instructions } = data
const { input, session_id, model, provider, instructions } = data
// Build full instructions with system prompt + workspace context
let fullInstructions = instructions
@@ -131,7 +131,7 @@ export async function handleApiRun(
if (!getSession(session_id)) {
const previewText = extractTextForPreview(input)
const preview = previewText.replace(/[\r\n]/g, ' ').substring(0, 100)
createSession({ id: session_id, profile, source: 'api_server', model, title: preview })
createSession({ id: session_id, profile, source: 'api_server', model, provider, title: preview })
}
addMessage({
@@ -153,7 +153,7 @@ export async function handleApiRun(
if (!getSession(session_id)) {
const previewText = extractTextForPreview(input)
const preview = previewText.replace(/[\r\n]/g, ' ').substring(0, 100)
createSession({ id: session_id, profile, source: 'api_server', model, title: preview })
createSession({ id: session_id, profile, source: 'api_server', model, provider, title: preview })
}
addMessage({
session_id,
@@ -5,10 +5,11 @@
import type { Server, Socket } from 'socket.io'
import { getSystemPrompt } from '../../../lib/llm-prompt'
import { getSession, createSession, addMessage, updateSessionStats } from '../../../db/hermes/session-store'
import { getSession, createSession, addMessage, updateSession, updateSessionStats } from '../../../db/hermes/session-store'
import { updateUsage } from '../../../db/hermes/usage-store'
import { logger, bridgeLogger } from '../../logger'
import { AgentBridgeClient, type AgentBridgeMessage, type AgentBridgeOutput } from '../agent-bridge'
import { readConfigYaml } from '../../config-helpers'
import { contentBlocksToString, convertContentBlocksForAgent, extractTextForPreview, isContentBlockArray } from './content-blocks'
import { buildCompressedHistory } from './compression'
import { pushState, replaceState } from './compression'
@@ -29,10 +30,34 @@ import type { ChatMessage } from '../../../lib/context-compressor'
const BRIDGE_USAGE_FLUSH_DELAY_MS = 200
// A provider name paired with the list of model identifiers grouped under it.
type RunModelGroup = { provider: string; models: string[] }
/**
 * Reads the default model and provider from the YAML config.
 *
 * `config.model` may be a plain string, in which case it is the model name and
 * the provider is empty, or an object whose `default` and `provider` fields are
 * used. Both values are trimmed.
 *
 * @returns The configured model and provider; both are empty strings when the
 *   config cannot be read or converting its values throws.
 */
async function resolveDefaultModelConfig(): Promise<{ model: string; provider: string }> {
  try {
    const settings = (await readConfigYaml())?.model
    // Shorthand form: `model: some-model-name`.
    if (typeof settings === 'string') {
      return { model: settings.trim(), provider: '' }
    }
    const model = String(settings?.default || '').trim()
    // Only the object form can carry a provider.
    if (typeof settings !== 'object') {
      return { model, provider: '' }
    }
    return { model, provider: String(settings?.provider || '').trim() }
  } catch {
    // Best effort: an unreadable config means "no default configured".
    return { model: '', provider: '' }
  }
}
/**
 * Reports whether `model` is listed under `provider` in the given model groups.
 *
 * The groups come from a client payload, so malformed entries (null items, or
 * items without a `models` array) are skipped instead of throwing. Every entry
 * for the provider is checked, so a model listed under a second entry for the
 * same provider is still found.
 *
 * @param groups Provider/model groups to search; may be undefined or empty.
 * @param provider Provider name to match exactly.
 * @param model Model identifier to look for within that provider.
 * @returns `true` only when some well-formed group matches both values;
 *   `false` when any argument is empty or `groups` is not an array.
 */
function hasModelInGroups(groups: RunModelGroup[] | undefined, provider: string, model: string): boolean {
  if (!Array.isArray(groups) || !provider || !model) return false
  return groups.some(group =>
    group != null
    && group.provider === provider
    && Array.isArray(group.models)
    && group.models.includes(model))
}
export async function handleBridgeRun(
nsp: ReturnType<Server['of']>,
socket: Socket,
data: { input: string | ContentBlock[]; session_id?: string; model?: string; instructions?: string; source?: string },
data: { input: string | ContentBlock[]; session_id?: string; model?: string; provider?: string; model_groups?: RunModelGroup[]; instructions?: string; source?: string },
profile: string,
sessionMap: Map<string, SessionState>,
gatewayManager: any,
@@ -41,7 +66,7 @@ export async function handleBridgeRun(
loadSessionStateFromDbFn: (sid: string, sessionMap: Map<string, SessionState>) => Promise<SessionState>,
dequeueNextQueuedRun: (socket: Socket, sessionId: string, fallbackProfile?: string) => void,
) {
const { input, session_id, model, instructions } = data
const { input, session_id, instructions } = data
if (!session_id) {
socket.emit('run.failed', { event: 'run.failed', error: 'session_id is required for cli source' })
return
@@ -51,6 +76,22 @@ export async function handleBridgeRun(
? `${getSystemPrompt()}\n${instructions}`
: getSystemPrompt()
const sessionRow = getSession(session_id)
const sessionModel = sessionRow?.model || ''
const sessionProvider = sessionRow?.provider || ''
const hasGroups = Array.isArray(data.model_groups) && data.model_groups.length > 0
const sessionModelAvailable = hasGroups && hasModelInGroups(data.model_groups, sessionProvider, sessionModel)
const shouldUseDefault = !sessionModel || !sessionProvider || !sessionModelAvailable
const defaultModelConfig = shouldUseDefault
? await resolveDefaultModelConfig()
: { model: '', provider: '' }
const resolvedModel = shouldUseDefault ? defaultModelConfig.model : sessionModel
const resolvedProvider = shouldUseDefault ? defaultModelConfig.provider : sessionProvider
if (sessionRow) {
const updates: { model?: string; provider?: string } = {}
if (resolvedModel && sessionRow.model !== resolvedModel) updates.model = resolvedModel
if (resolvedProvider && sessionRow.provider !== resolvedProvider) updates.provider = resolvedProvider
if (Object.keys(updates).length > 0) updateSession(session_id, updates)
}
if (sessionRow?.workspace) {
const workspaceCtx = `[Current working directory: ${sessionRow.workspace}]`
fullInstructions = `\n${workspaceCtx}\n${fullInstructions}`
@@ -93,7 +134,7 @@ export async function handleBridgeRun(
if (!getSession(session_id)) {
const previewText = extractTextForPreview(input)
const preview = previewText.replace(/[\r\n]/g, ' ').substring(0, 100)
createSession({ id: session_id, profile, source: 'cli', model, title: preview })
createSession({ id: session_id, profile, source: 'cli', model: resolvedModel, provider: resolvedProvider, title: preview })
}
addMessage({
session_id,
@@ -142,7 +183,11 @@ export async function handleBridgeRun(
bridgeHistory,
fullInstructions,
profile,
bridgeStorageInput !== undefined ? { storage_message: bridgeStorageInput } : {},
{
...(bridgeStorageInput !== undefined ? { storage_message: bridgeStorageInput } : {}),
...(resolvedModel ? { model: resolvedModel } : {}),
...(resolvedProvider ? { provider: resolvedProvider } : {}),
},
)
state.runId = started.run_id
bridgeLogger.info({
@@ -66,6 +66,8 @@ export class ChatRunSocket {
session_id?: string
model?: string
instructions?: string
provider?: string
model_groups?: Array<{ provider: string; models: string[] }>
queue_id?: string
source?: string
}) => {
@@ -102,6 +104,8 @@ export class ChatRunSocket {
queue_id: data.queue_id || `queue_${Date.now().toString(36)}_${Math.random().toString(36).slice(2, 8)}`,
input: data.input,
model: data.model,
provider: data.provider,
model_groups: data.model_groups,
instructions: data.instructions,
profile: currentProfile(),
source,
@@ -191,7 +195,15 @@ export class ChatRunSocket {
private async handleRun(
socket: Socket,
data: { input: string | ContentBlock[]; session_id?: string; model?: string; instructions?: string; source?: string },
data: {
input: string | ContentBlock[]
session_id?: string
model?: string
provider?: string
model_groups?: Array<{ provider: string; models: string[] }>
instructions?: string
source?: string
},
profile: string,
skipUserMessage = false,
) {
@@ -273,6 +285,8 @@ export class ChatRunSocket {
input: next.input,
session_id: sessionId,
model: next.model,
provider: next.provider,
model_groups: next.model_groups,
instructions: next.instructions,
source: next.source,
}, next.profile || fallbackProfile, true)
@@ -29,6 +29,8 @@ export interface QueuedRun {
queue_id: string
input: string | ContentBlock[]
model?: string
provider?: string
model_groups?: Array<{ provider: string; models: string[] }>
instructions?: string
profile: string
source?: ChatRunSource