Add session-level bridge model settings (#811)

This commit is contained in:
ekko
2026-05-17 12:20:53 +08:00
committed by GitHub
parent fa035f348e
commit 5e8f8bd4a1
35 changed files with 697 additions and 60 deletions
@@ -41,6 +41,7 @@ export async function listConversations(ctx: any) {
id: s.id,
source: s.source,
model: s.model,
provider: s.provider,
title: s.title,
started_at: s.started_at,
ended_at: s.ended_at,
@@ -267,6 +268,28 @@ export async function setWorkspace(ctx: any) {
ctx.body = { ok: true }
}
/**
 * Persist the model/provider selection for the session named by `ctx.params.id`.
 *
 * Request body:
 *  - `model`    (required) non-blank string; stored trimmed.
 *  - `provider` (optional) string, `null`, or omitted; stored trimmed, `''` when absent.
 *
 * Responds with `400` and an `error` message when validation fails, otherwise
 * `{ ok: true }`. If the session row does not exist yet it is created first
 * (active profile, empty title) and then updated.
 */
export async function setModel(ctx: any) {
  // A missing or unparsed body must yield a 400, not a destructuring TypeError.
  const payload = (ctx.request?.body ?? {}) as { model?: unknown; provider?: unknown }
  const { model, provider } = payload

  // Trim BEFORE validating so a whitespace-only model is rejected instead of
  // being persisted as an empty string.
  const normalizedModel = typeof model === 'string' ? model.trim() : ''
  if (!normalizedModel) {
    ctx.status = 400
    ctx.body = { error: 'model is required' }
    return
  }
  if (provider !== undefined && provider !== null && typeof provider !== 'string') {
    ctx.status = 400
    ctx.body = { error: 'provider must be a string' }
    return
  }
  const normalizedProvider = typeof provider === 'string' ? provider.trim() : ''

  // Loaded lazily, as in the original handler.
  const { updateSession, getSession, createSession } = await import('../../db/hermes/session-store')
  const { getActiveProfileName } = await import('../../services/hermes/hermes-profile')

  const id = ctx.params.id
  if (!getSession(id)) {
    createSession({ id, profile: getActiveProfileName(), title: '' })
  }
  // NOTE(review): the `as any` cast is kept from the original; updateSession's
  // patch type is not visible from here — drop the cast if it already accepts
  // `model` / `provider`.
  updateSession(id, { model: normalizedModel, provider: normalizedProvider } as any)
  ctx.body = { ok: true }
}
export async function contextLength(ctx: any) {
const profile = (ctx.query.profile as string) || undefined
ctx.body = { context_length: getModelContextLength(profile) }
+1
View File
@@ -34,6 +34,7 @@ export const SESSIONS_SCHEMA: Record<string, string> = {
source: 'TEXT NOT NULL DEFAULT \'api_server\'',
user_id: 'TEXT',
model: 'TEXT NOT NULL DEFAULT \'\'',
provider: 'TEXT NOT NULL DEFAULT \'\'',
title: 'TEXT',
started_at: 'INTEGER NOT NULL',
ended_at: 'INTEGER',
@@ -12,6 +12,7 @@ export interface HermesSessionRow {
source: string
user_id: string | null
model: string
provider: string
title: string | null
started_at: number
ended_at: number | null
@@ -85,6 +86,7 @@ function mapSessionRow(row: Record<string, unknown>): HermesSessionRow {
source: String(row.source || 'api_server'),
user_id: row.user_id != null ? String(row.user_id) : null,
model: String(row.model || ''),
provider: String(row.provider || ''),
title,
started_at: Number(row.started_at || 0),
ended_at: row.ended_at != null ? Number(row.ended_at) : null,
@@ -131,6 +133,7 @@ export function createSession(data: {
profile?: string
source?: string
model?: string
provider?: string
title?: string
workspace?: string
}): HermesSessionRow {
@@ -139,7 +142,7 @@ export function createSession(data: {
if (!isSqliteAvailable()) {
return {
id: data.id, profile: data.profile || 'default', source,
user_id: null, model: data.model || '', title: data.title || null,
user_id: null, model: data.model || '', provider: data.provider || '', title: data.title || null,
started_at: now, ended_at: null, end_reason: null,
message_count: 0, tool_call_count: 0,
input_tokens: 0, output_tokens: 0, cache_read_tokens: 0, cache_write_tokens: 0, reasoning_tokens: 0,
@@ -149,9 +152,9 @@ export function createSession(data: {
}
const db = getDb()!
db.prepare(
`INSERT INTO ${SESSIONS_TABLE} (id, profile, source, model, title, started_at, last_active, workspace)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)`,
).run(data.id, data.profile || 'default', source, data.model || '', data.title || null, now, now, data.workspace || null)
`INSERT INTO ${SESSIONS_TABLE} (id, profile, source, model, provider, title, started_at, last_active, workspace)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)`,
).run(data.id, data.profile || 'default', source, data.model || '', data.provider || '', data.title || null, now, now, data.workspace || null)
return getSession(data.id)!
}
@@ -21,4 +21,5 @@ sessionRoutes.delete('/api/hermes/sessions/:id', ctrl.remove)
sessionRoutes.post('/api/hermes/sessions/batch-delete', ctrl.batchRemove)
sessionRoutes.post('/api/hermes/sessions/:id/rename', ctrl.rename)
sessionRoutes.post('/api/hermes/sessions/:id/workspace', ctrl.setWorkspace)
sessionRoutes.post('/api/hermes/sessions/:id/model', ctrl.setModel)
sessionRoutes.get('/api/hermes/workspace/folders', ctrl.listWorkspaceFolders)
@@ -28,6 +28,13 @@ export interface AgentBridgeRequestOptions {
timeoutMs?: number
}
/**
 * Optional per-request settings accepted by `AgentBridgeClient.chat`.
 * Every field may be omitted.
 */
export interface AgentBridgeChatOptions {
  /** Model override; forwarded to the bridge as `model` only when non-empty. */
  model?: string
  /** Provider override; forwarded to the bridge as `provider` only when non-empty. */
  provider?: string
  /** When truthy the chat request carries `force_compress: true`. */
  force_compress?: boolean
  /**
   * Alternate form of the message supplied by the caller.
   * NOTE(review): presumably the version to persist instead of the live input —
   * how the client forwards it is not visible here; confirm against `chat`.
   */
  storage_message?: AgentBridgeMessage
}
export type AgentBridgeMessage =
| string
| Array<Record<string, unknown>>
@@ -306,7 +313,7 @@ export class AgentBridgeClient {
conversationHistory?: unknown[],
instructions?: string,
profile?: string,
options: { force_compress?: boolean; storage_message?: AgentBridgeMessage } = {},
options: AgentBridgeChatOptions = {},
): Promise<AgentBridgeChatStarted> {
return this.request<AgentBridgeChatStarted>({
action: 'chat',
@@ -316,6 +323,8 @@ export class AgentBridgeClient {
...(conversationHistory ? { conversation_history: conversationHistory } : {}),
...(instructions ? { instructions } : {}),
...(profile ? { profile } : {}),
...(options.model ? { model: options.model } : {}),
...(options.provider ? { provider: options.provider } : {}),
...(options.force_compress ? { force_compress: true } : {}),
})
}
@@ -491,12 +491,21 @@ class AgentPool:
self,
session_id: str,
profile: str | None = None,
model: str | None = None,
provider: str | None = None,
) -> AgentSession:
requested_model = str(model or "").strip()
requested_provider = str(provider or "").strip()
with self._lock:
existing = self._sessions.get(session_id)
if existing is not None:
# If profile changed, destroy old session and recreate
if profile and existing.config.get("profile") != profile:
config_changed = bool(
(profile and existing.config.get("profile") != profile)
or (requested_model and existing.config.get("model") != requested_model)
or (requested_provider and existing.config.get("provider") != requested_provider)
)
if config_changed:
if not existing.running:
self._destroy_session(session_id)
else:
@@ -512,8 +521,8 @@ class AgentPool:
with _profile_env(profile):
cfg = _load_cfg()
resolved_model = _resolve_model(cfg)
runtime = _resolve_runtime(resolved_model)
resolved_model = requested_model or _resolve_model(cfg)
runtime = _resolve_runtime(resolved_model, requested_provider or None)
agent_cfg = cfg.get("agent") or {}
prompt = str(agent_cfg.get("system_prompt", "") or "").strip() or None
@@ -949,8 +958,10 @@ class AgentPool:
conversation_history: list[dict[str, Any]] | None = None,
profile: str | None = None,
force_compress: bool = False,
model: str | None = None,
provider: str | None = None,
) -> RunRecord:
session = self.get_or_create(session_id, profile=profile)
session = self.get_or_create(session_id, profile=profile, model=model, provider=provider)
with session.lock:
if session.running:
raise RuntimeError(f"session {session_id} is already running")
@@ -1265,6 +1276,8 @@ class BridgeServer:
instructions = req.get("instructions") or req.get("system_message")
conversation_history = req.get("conversation_history")
profile = req.get("profile")
model = req.get("model")
provider = req.get("provider")
record = self.pool.start_chat(
session_id,
message,
@@ -1273,6 +1286,8 @@ class BridgeServer:
conversation_history,
profile,
bool(req.get("force_compress")),
model,
provider,
)
if req.get("wait"):
timeout = float(req.get("timeout", 0) or 0)
@@ -75,14 +75,14 @@ export async function loadSessionStateFromDb(sid: string, _sessionMap: Map<strin
export async function handleApiRun(
nsp: ReturnType<Server['of']>,
socket: Socket,
data: { input: string | ContentBlock[]; session_id?: string; model?: string; instructions?: string; source?: string },
data: { input: string | ContentBlock[]; session_id?: string; model?: string; provider?: string; instructions?: string; source?: string },
profile: string,
sessionMap: Map<string, SessionState>,
gatewayManager: any,
skipUserMessage = false,
dequeueNextQueuedRun: (socket: Socket, sessionId: string, fallbackProfile?: string) => void,
) {
const { input, session_id, model, instructions } = data
const { input, session_id, model, provider, instructions } = data
// Build full instructions with system prompt + workspace context
let fullInstructions = instructions
@@ -131,7 +131,7 @@ export async function handleApiRun(
if (!getSession(session_id)) {
const previewText = extractTextForPreview(input)
const preview = previewText.replace(/[\r\n]/g, ' ').substring(0, 100)
createSession({ id: session_id, profile, source: 'api_server', model, title: preview })
createSession({ id: session_id, profile, source: 'api_server', model, provider, title: preview })
}
addMessage({
@@ -153,7 +153,7 @@ export async function handleApiRun(
if (!getSession(session_id)) {
const previewText = extractTextForPreview(input)
const preview = previewText.replace(/[\r\n]/g, ' ').substring(0, 100)
createSession({ id: session_id, profile, source: 'api_server', model, title: preview })
createSession({ id: session_id, profile, source: 'api_server', model, provider, title: preview })
}
addMessage({
session_id,
@@ -5,10 +5,11 @@
import type { Server, Socket } from 'socket.io'
import { getSystemPrompt } from '../../../lib/llm-prompt'
import { getSession, createSession, addMessage, updateSessionStats } from '../../../db/hermes/session-store'
import { getSession, createSession, addMessage, updateSession, updateSessionStats } from '../../../db/hermes/session-store'
import { updateUsage } from '../../../db/hermes/usage-store'
import { logger, bridgeLogger } from '../../logger'
import { AgentBridgeClient, type AgentBridgeMessage, type AgentBridgeOutput } from '../agent-bridge'
import { readConfigYaml } from '../../config-helpers'
import { contentBlocksToString, convertContentBlocksForAgent, extractTextForPreview, isContentBlockArray } from './content-blocks'
import { buildCompressedHistory } from './compression'
import { pushState, replaceState } from './compression'
@@ -29,10 +30,34 @@ import type { ChatMessage } from '../../../lib/context-compressor'
const BRIDGE_USAGE_FLUSH_DELAY_MS = 200
/** One provider together with the model identifiers listed under it. */
type RunModelGroup = {
  provider: string
  models: string[]
}
/**
 * Read the default model/provider pair from the YAML config.
 *
 * The `model` entry may be either a plain string (model name only, provider
 * left empty) or an object carrying `default` and `provider`. Both values are
 * returned trimmed. Any failure while reading or converting the config is
 * treated as "no default" and yields two empty strings.
 */
async function resolveDefaultModelConfig(): Promise<{ model: string; provider: string }> {
  try {
    const entry = (await readConfigYaml())?.model

    // Shorthand form: `model: "<name>"`.
    if (typeof entry === 'string') {
      return { model: entry.trim(), provider: '' }
    }

    // Object form (or anything else): take `default`, falling back to ''.
    const model = String(entry?.default || '').trim()
    if (typeof entry !== 'object') {
      return { model, provider: '' }
    }
    return { model, provider: String(entry?.provider || '').trim() }
  } catch {
    return { model: '', provider: '' }
  }
}
/**
 * Report whether `model` is listed under `provider` in the supplied groups.
 *
 * @param groups   Provider/model groups; may be undefined or empty.
 * @param provider Provider name to look for; an empty value never matches.
 * @param model    Model identifier to look for; an empty value never matches.
 * @returns `true` when at least one group for `provider` contains `model`.
 *
 * The groups originate from a client payload, so malformed entries (null
 * items, a missing `models` array) are skipped rather than allowed to throw.
 * Every group for the provider is considered, not just the first, so a
 * provider split across several groups is still matched.
 */
function hasModelInGroups(groups: RunModelGroup[] | undefined, provider: string, model: string): boolean {
  if (!Array.isArray(groups) || groups.length === 0 || !provider || !model) return false
  return groups.some(
    group => group?.provider === provider && Array.isArray(group.models) && group.models.includes(model),
  )
}
export async function handleBridgeRun(
nsp: ReturnType<Server['of']>,
socket: Socket,
data: { input: string | ContentBlock[]; session_id?: string; model?: string; instructions?: string; source?: string },
data: { input: string | ContentBlock[]; session_id?: string; model?: string; provider?: string; model_groups?: RunModelGroup[]; instructions?: string; source?: string },
profile: string,
sessionMap: Map<string, SessionState>,
gatewayManager: any,
@@ -41,7 +66,7 @@ export async function handleBridgeRun(
loadSessionStateFromDbFn: (sid: string, sessionMap: Map<string, SessionState>) => Promise<SessionState>,
dequeueNextQueuedRun: (socket: Socket, sessionId: string, fallbackProfile?: string) => void,
) {
const { input, session_id, model, instructions } = data
const { input, session_id, instructions } = data
if (!session_id) {
socket.emit('run.failed', { event: 'run.failed', error: 'session_id is required for cli source' })
return
@@ -51,6 +76,22 @@ export async function handleBridgeRun(
? `${getSystemPrompt()}\n${instructions}`
: getSystemPrompt()
const sessionRow = getSession(session_id)
const sessionModel = sessionRow?.model || ''
const sessionProvider = sessionRow?.provider || ''
const hasGroups = Array.isArray(data.model_groups) && data.model_groups.length > 0
const sessionModelAvailable = hasGroups && hasModelInGroups(data.model_groups, sessionProvider, sessionModel)
const shouldUseDefault = !sessionModel || !sessionProvider || !sessionModelAvailable
const defaultModelConfig = shouldUseDefault
? await resolveDefaultModelConfig()
: { model: '', provider: '' }
const resolvedModel = shouldUseDefault ? defaultModelConfig.model : sessionModel
const resolvedProvider = shouldUseDefault ? defaultModelConfig.provider : sessionProvider
if (sessionRow) {
const updates: { model?: string; provider?: string } = {}
if (resolvedModel && sessionRow.model !== resolvedModel) updates.model = resolvedModel
if (resolvedProvider && sessionRow.provider !== resolvedProvider) updates.provider = resolvedProvider
if (Object.keys(updates).length > 0) updateSession(session_id, updates)
}
if (sessionRow?.workspace) {
const workspaceCtx = `[Current working directory: ${sessionRow.workspace}]`
fullInstructions = `\n${workspaceCtx}\n${fullInstructions}`
@@ -93,7 +134,7 @@ export async function handleBridgeRun(
if (!getSession(session_id)) {
const previewText = extractTextForPreview(input)
const preview = previewText.replace(/[\r\n]/g, ' ').substring(0, 100)
createSession({ id: session_id, profile, source: 'cli', model, title: preview })
createSession({ id: session_id, profile, source: 'cli', model: resolvedModel, provider: resolvedProvider, title: preview })
}
addMessage({
session_id,
@@ -142,7 +183,11 @@ export async function handleBridgeRun(
bridgeHistory,
fullInstructions,
profile,
bridgeStorageInput !== undefined ? { storage_message: bridgeStorageInput } : {},
{
...(bridgeStorageInput !== undefined ? { storage_message: bridgeStorageInput } : {}),
...(resolvedModel ? { model: resolvedModel } : {}),
...(resolvedProvider ? { provider: resolvedProvider } : {}),
},
)
state.runId = started.run_id
bridgeLogger.info({
@@ -66,6 +66,8 @@ export class ChatRunSocket {
session_id?: string
model?: string
instructions?: string
provider?: string
model_groups?: Array<{ provider: string; models: string[] }>
queue_id?: string
source?: string
}) => {
@@ -102,6 +104,8 @@ export class ChatRunSocket {
queue_id: data.queue_id || `queue_${Date.now().toString(36)}_${Math.random().toString(36).slice(2, 8)}`,
input: data.input,
model: data.model,
provider: data.provider,
model_groups: data.model_groups,
instructions: data.instructions,
profile: currentProfile(),
source,
@@ -191,7 +195,15 @@ export class ChatRunSocket {
private async handleRun(
socket: Socket,
data: { input: string | ContentBlock[]; session_id?: string; model?: string; instructions?: string; source?: string },
data: {
input: string | ContentBlock[]
session_id?: string
model?: string
provider?: string
model_groups?: Array<{ provider: string; models: string[] }>
instructions?: string
source?: string
},
profile: string,
skipUserMessage = false,
) {
@@ -273,6 +285,8 @@ export class ChatRunSocket {
input: next.input,
session_id: sessionId,
model: next.model,
provider: next.provider,
model_groups: next.model_groups,
instructions: next.instructions,
source: next.source,
}, next.profile || fallbackProfile, true)
@@ -29,6 +29,8 @@ export interface QueuedRun {
queue_id: string
input: string | ContentBlock[]
model?: string
provider?: string
model_groups?: Array<{ provider: string; models: string[] }>
instructions?: string
profile: string
source?: ChatRunSource