Add session-level bridge model settings (#811)
This commit is contained in:
@@ -28,6 +28,13 @@ export interface AgentBridgeRequestOptions {
|
||||
timeoutMs?: number
|
||||
}
|
||||
|
||||
export interface AgentBridgeChatOptions {
|
||||
force_compress?: boolean
|
||||
storage_message?: AgentBridgeMessage
|
||||
model?: string
|
||||
provider?: string
|
||||
}
|
||||
|
||||
export type AgentBridgeMessage =
|
||||
| string
|
||||
| Array<Record<string, unknown>>
|
||||
@@ -306,7 +313,7 @@ export class AgentBridgeClient {
|
||||
conversationHistory?: unknown[],
|
||||
instructions?: string,
|
||||
profile?: string,
|
||||
options: { force_compress?: boolean; storage_message?: AgentBridgeMessage } = {},
|
||||
options: AgentBridgeChatOptions = {},
|
||||
): Promise<AgentBridgeChatStarted> {
|
||||
return this.request<AgentBridgeChatStarted>({
|
||||
action: 'chat',
|
||||
@@ -316,6 +323,8 @@ export class AgentBridgeClient {
|
||||
...(conversationHistory ? { conversation_history: conversationHistory } : {}),
|
||||
...(instructions ? { instructions } : {}),
|
||||
...(profile ? { profile } : {}),
|
||||
...(options.model ? { model: options.model } : {}),
|
||||
...(options.provider ? { provider: options.provider } : {}),
|
||||
...(options.force_compress ? { force_compress: true } : {}),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -491,12 +491,21 @@ class AgentPool:
|
||||
self,
|
||||
session_id: str,
|
||||
profile: str | None = None,
|
||||
model: str | None = None,
|
||||
provider: str | None = None,
|
||||
) -> AgentSession:
|
||||
requested_model = str(model or "").strip()
|
||||
requested_provider = str(provider or "").strip()
|
||||
with self._lock:
|
||||
existing = self._sessions.get(session_id)
|
||||
if existing is not None:
|
||||
# If profile changed, destroy old session and recreate
|
||||
if profile and existing.config.get("profile") != profile:
|
||||
config_changed = bool(
|
||||
(profile and existing.config.get("profile") != profile)
|
||||
or (requested_model and existing.config.get("model") != requested_model)
|
||||
or (requested_provider and existing.config.get("provider") != requested_provider)
|
||||
)
|
||||
if config_changed:
|
||||
if not existing.running:
|
||||
self._destroy_session(session_id)
|
||||
else:
|
||||
@@ -512,8 +521,8 @@ class AgentPool:
|
||||
|
||||
with _profile_env(profile):
|
||||
cfg = _load_cfg()
|
||||
resolved_model = _resolve_model(cfg)
|
||||
runtime = _resolve_runtime(resolved_model)
|
||||
resolved_model = requested_model or _resolve_model(cfg)
|
||||
runtime = _resolve_runtime(resolved_model, requested_provider or None)
|
||||
agent_cfg = cfg.get("agent") or {}
|
||||
prompt = str(agent_cfg.get("system_prompt", "") or "").strip() or None
|
||||
|
||||
@@ -949,8 +958,10 @@ class AgentPool:
|
||||
conversation_history: list[dict[str, Any]] | None = None,
|
||||
profile: str | None = None,
|
||||
force_compress: bool = False,
|
||||
model: str | None = None,
|
||||
provider: str | None = None,
|
||||
) -> RunRecord:
|
||||
session = self.get_or_create(session_id, profile=profile)
|
||||
session = self.get_or_create(session_id, profile=profile, model=model, provider=provider)
|
||||
with session.lock:
|
||||
if session.running:
|
||||
raise RuntimeError(f"session {session_id} is already running")
|
||||
@@ -1265,6 +1276,8 @@ class BridgeServer:
|
||||
instructions = req.get("instructions") or req.get("system_message")
|
||||
conversation_history = req.get("conversation_history")
|
||||
profile = req.get("profile")
|
||||
model = req.get("model")
|
||||
provider = req.get("provider")
|
||||
record = self.pool.start_chat(
|
||||
session_id,
|
||||
message,
|
||||
@@ -1273,6 +1286,8 @@ class BridgeServer:
|
||||
conversation_history,
|
||||
profile,
|
||||
bool(req.get("force_compress")),
|
||||
model,
|
||||
provider,
|
||||
)
|
||||
if req.get("wait"):
|
||||
timeout = float(req.get("timeout", 0) or 0)
|
||||
|
||||
@@ -75,14 +75,14 @@ export async function loadSessionStateFromDb(sid: string, _sessionMap: Map<strin
|
||||
export async function handleApiRun(
|
||||
nsp: ReturnType<Server['of']>,
|
||||
socket: Socket,
|
||||
data: { input: string | ContentBlock[]; session_id?: string; model?: string; instructions?: string; source?: string },
|
||||
data: { input: string | ContentBlock[]; session_id?: string; model?: string; provider?: string; instructions?: string; source?: string },
|
||||
profile: string,
|
||||
sessionMap: Map<string, SessionState>,
|
||||
gatewayManager: any,
|
||||
skipUserMessage = false,
|
||||
dequeueNextQueuedRun: (socket: Socket, sessionId: string, fallbackProfile?: string) => void,
|
||||
) {
|
||||
const { input, session_id, model, instructions } = data
|
||||
const { input, session_id, model, provider, instructions } = data
|
||||
|
||||
// Build full instructions with system prompt + workspace context
|
||||
let fullInstructions = instructions
|
||||
@@ -131,7 +131,7 @@ export async function handleApiRun(
|
||||
if (!getSession(session_id)) {
|
||||
const previewText = extractTextForPreview(input)
|
||||
const preview = previewText.replace(/[\r\n]/g, ' ').substring(0, 100)
|
||||
createSession({ id: session_id, profile, source: 'api_server', model, title: preview })
|
||||
createSession({ id: session_id, profile, source: 'api_server', model, provider, title: preview })
|
||||
}
|
||||
|
||||
addMessage({
|
||||
@@ -153,7 +153,7 @@ export async function handleApiRun(
|
||||
if (!getSession(session_id)) {
|
||||
const previewText = extractTextForPreview(input)
|
||||
const preview = previewText.replace(/[\r\n]/g, ' ').substring(0, 100)
|
||||
createSession({ id: session_id, profile, source: 'api_server', model, title: preview })
|
||||
createSession({ id: session_id, profile, source: 'api_server', model, provider, title: preview })
|
||||
}
|
||||
addMessage({
|
||||
session_id,
|
||||
|
||||
@@ -5,10 +5,11 @@
|
||||
|
||||
import type { Server, Socket } from 'socket.io'
|
||||
import { getSystemPrompt } from '../../../lib/llm-prompt'
|
||||
import { getSession, createSession, addMessage, updateSessionStats } from '../../../db/hermes/session-store'
|
||||
import { getSession, createSession, addMessage, updateSession, updateSessionStats } from '../../../db/hermes/session-store'
|
||||
import { updateUsage } from '../../../db/hermes/usage-store'
|
||||
import { logger, bridgeLogger } from '../../logger'
|
||||
import { AgentBridgeClient, type AgentBridgeMessage, type AgentBridgeOutput } from '../agent-bridge'
|
||||
import { readConfigYaml } from '../../config-helpers'
|
||||
import { contentBlocksToString, convertContentBlocksForAgent, extractTextForPreview, isContentBlockArray } from './content-blocks'
|
||||
import { buildCompressedHistory } from './compression'
|
||||
import { pushState, replaceState } from './compression'
|
||||
@@ -29,10 +30,34 @@ import type { ChatMessage } from '../../../lib/context-compressor'
|
||||
|
||||
const BRIDGE_USAGE_FLUSH_DELAY_MS = 200
|
||||
|
||||
type RunModelGroup = { provider: string; models: string[] }
|
||||
|
||||
async function resolveDefaultModelConfig(): Promise<{ model: string; provider: string }> {
|
||||
try {
|
||||
const config = await readConfigYaml()
|
||||
const modelConfig = config?.model
|
||||
const model = typeof modelConfig === 'string'
|
||||
? modelConfig.trim()
|
||||
: String(modelConfig?.default || '').trim()
|
||||
const provider = typeof modelConfig === 'object'
|
||||
? String(modelConfig?.provider || '').trim()
|
||||
: ''
|
||||
return { model, provider }
|
||||
} catch {
|
||||
return { model: '', provider: '' }
|
||||
}
|
||||
}
|
||||
|
||||
function hasModelInGroups(groups: RunModelGroup[] | undefined, provider: string, model: string): boolean {
|
||||
if (!groups?.length || !provider || !model) return false
|
||||
const group = groups.find(item => item.provider === provider)
|
||||
return Array.isArray(group?.models) && group.models.includes(model)
|
||||
}
|
||||
|
||||
export async function handleBridgeRun(
|
||||
nsp: ReturnType<Server['of']>,
|
||||
socket: Socket,
|
||||
data: { input: string | ContentBlock[]; session_id?: string; model?: string; instructions?: string; source?: string },
|
||||
data: { input: string | ContentBlock[]; session_id?: string; model?: string; provider?: string; model_groups?: RunModelGroup[]; instructions?: string; source?: string },
|
||||
profile: string,
|
||||
sessionMap: Map<string, SessionState>,
|
||||
gatewayManager: any,
|
||||
@@ -41,7 +66,7 @@ export async function handleBridgeRun(
|
||||
loadSessionStateFromDbFn: (sid: string, sessionMap: Map<string, SessionState>) => Promise<SessionState>,
|
||||
dequeueNextQueuedRun: (socket: Socket, sessionId: string, fallbackProfile?: string) => void,
|
||||
) {
|
||||
const { input, session_id, model, instructions } = data
|
||||
const { input, session_id, instructions } = data
|
||||
if (!session_id) {
|
||||
socket.emit('run.failed', { event: 'run.failed', error: 'session_id is required for cli source' })
|
||||
return
|
||||
@@ -51,6 +76,22 @@ export async function handleBridgeRun(
|
||||
? `${getSystemPrompt()}\n${instructions}`
|
||||
: getSystemPrompt()
|
||||
const sessionRow = getSession(session_id)
|
||||
const sessionModel = sessionRow?.model || ''
|
||||
const sessionProvider = sessionRow?.provider || ''
|
||||
const hasGroups = Array.isArray(data.model_groups) && data.model_groups.length > 0
|
||||
const sessionModelAvailable = hasGroups && hasModelInGroups(data.model_groups, sessionProvider, sessionModel)
|
||||
const shouldUseDefault = !sessionModel || !sessionProvider || !sessionModelAvailable
|
||||
const defaultModelConfig = shouldUseDefault
|
||||
? await resolveDefaultModelConfig()
|
||||
: { model: '', provider: '' }
|
||||
const resolvedModel = shouldUseDefault ? defaultModelConfig.model : sessionModel
|
||||
const resolvedProvider = shouldUseDefault ? defaultModelConfig.provider : sessionProvider
|
||||
if (sessionRow) {
|
||||
const updates: { model?: string; provider?: string } = {}
|
||||
if (resolvedModel && sessionRow.model !== resolvedModel) updates.model = resolvedModel
|
||||
if (resolvedProvider && sessionRow.provider !== resolvedProvider) updates.provider = resolvedProvider
|
||||
if (Object.keys(updates).length > 0) updateSession(session_id, updates)
|
||||
}
|
||||
if (sessionRow?.workspace) {
|
||||
const workspaceCtx = `[Current working directory: ${sessionRow.workspace}]`
|
||||
fullInstructions = `\n${workspaceCtx}\n${fullInstructions}`
|
||||
@@ -93,7 +134,7 @@ export async function handleBridgeRun(
|
||||
if (!getSession(session_id)) {
|
||||
const previewText = extractTextForPreview(input)
|
||||
const preview = previewText.replace(/[\r\n]/g, ' ').substring(0, 100)
|
||||
createSession({ id: session_id, profile, source: 'cli', model, title: preview })
|
||||
createSession({ id: session_id, profile, source: 'cli', model: resolvedModel, provider: resolvedProvider, title: preview })
|
||||
}
|
||||
addMessage({
|
||||
session_id,
|
||||
@@ -142,7 +183,11 @@ export async function handleBridgeRun(
|
||||
bridgeHistory,
|
||||
fullInstructions,
|
||||
profile,
|
||||
bridgeStorageInput !== undefined ? { storage_message: bridgeStorageInput } : {},
|
||||
{
|
||||
...(bridgeStorageInput !== undefined ? { storage_message: bridgeStorageInput } : {}),
|
||||
...(resolvedModel ? { model: resolvedModel } : {}),
|
||||
...(resolvedProvider ? { provider: resolvedProvider } : {}),
|
||||
},
|
||||
)
|
||||
state.runId = started.run_id
|
||||
bridgeLogger.info({
|
||||
|
||||
@@ -66,6 +66,8 @@ export class ChatRunSocket {
|
||||
session_id?: string
|
||||
model?: string
|
||||
instructions?: string
|
||||
provider?: string
|
||||
model_groups?: Array<{ provider: string; models: string[] }>
|
||||
queue_id?: string
|
||||
source?: string
|
||||
}) => {
|
||||
@@ -102,6 +104,8 @@ export class ChatRunSocket {
|
||||
queue_id: data.queue_id || `queue_${Date.now().toString(36)}_${Math.random().toString(36).slice(2, 8)}`,
|
||||
input: data.input,
|
||||
model: data.model,
|
||||
provider: data.provider,
|
||||
model_groups: data.model_groups,
|
||||
instructions: data.instructions,
|
||||
profile: currentProfile(),
|
||||
source,
|
||||
@@ -191,7 +195,15 @@ export class ChatRunSocket {
|
||||
|
||||
private async handleRun(
|
||||
socket: Socket,
|
||||
data: { input: string | ContentBlock[]; session_id?: string; model?: string; instructions?: string; source?: string },
|
||||
data: {
|
||||
input: string | ContentBlock[]
|
||||
session_id?: string
|
||||
model?: string
|
||||
provider?: string
|
||||
model_groups?: Array<{ provider: string; models: string[] }>
|
||||
instructions?: string
|
||||
source?: string
|
||||
},
|
||||
profile: string,
|
||||
skipUserMessage = false,
|
||||
) {
|
||||
@@ -273,6 +285,8 @@ export class ChatRunSocket {
|
||||
input: next.input,
|
||||
session_id: sessionId,
|
||||
model: next.model,
|
||||
provider: next.provider,
|
||||
model_groups: next.model_groups,
|
||||
instructions: next.instructions,
|
||||
source: next.source,
|
||||
}, next.profile || fallbackProfile, true)
|
||||
|
||||
@@ -29,6 +29,8 @@ export interface QueuedRun {
|
||||
queue_id: string
|
||||
input: string | ContentBlock[]
|
||||
model?: string
|
||||
provider?: string
|
||||
model_groups?: Array<{ provider: string; models: string[] }>
|
||||
instructions?: string
|
||||
profile: string
|
||||
source?: ChatRunSource
|
||||
|
||||
Reference in New Issue
Block a user