From 0209372a6d99fe28b8b4c494d427ae5b25692da3 Mon Sep 17 00:00:00 2001 From: Zhicheng Han <43314240+hanzckernel@users.noreply.github.com> Date: Wed, 6 May 2026 09:16:44 +0200 Subject: [PATCH] Fix custom provider context resolution (#468) Co-authored-by: ekko <152005280+EKKOLearnAI@users.noreply.github.com> --- .../src/services/hermes/model-context.ts | 123 +++++++++++-- tests/server/model-context.test.ts | 165 ++++++++++++++++++ 2 files changed, 274 insertions(+), 14 deletions(-) diff --git a/packages/server/src/services/hermes/model-context.ts b/packages/server/src/services/hermes/model-context.ts index 1b103ba..303421e 100644 --- a/packages/server/src/services/hermes/model-context.ts +++ b/packages/server/src/services/hermes/model-context.ts @@ -2,6 +2,7 @@ import { resolve, join } from 'path' import { homedir } from 'os' import { readFileSync, existsSync, statSync } from 'fs' import yaml from 'js-yaml' +import { PROVIDER_PRESETS } from '../../shared/providers' import { getDb } from '../../db' import { MODEL_CONTEXT_TABLE } from '../../db/hermes/schemas' @@ -25,6 +26,13 @@ interface ProviderEntry { models?: Record } +interface CustomProviderEntry { + name?: string + base_url?: string + model?: string + models?: Record +} + const MODEL_CACHE_PROVIDER_ALIASES: Record = { gemini: ['google'], moonshot: ['moonshotai'], @@ -107,24 +115,54 @@ function getConfigContextLength(config: any): number | null { return val } +function normalizeCustomProviderName(name: string): string { + return name.trim().toLowerCase().replace(/ /g, '-') +} + +function normalizeBaseUrl(url: string): string { + return url.trim().toLowerCase().replace(/\/+$/, '') +} + +function getModelBaseUrl(config: any): string | null { + const model = config?.model + if (!model || typeof model !== 'object') return null + return typeof model.base_url === 'string' ? model.base_url.trim() || null : null +} + +function getCustomProviders(config: any): CustomProviderEntry[] { + return Array.isArray(config?.custom_providers) ? config.custom_providers as CustomProviderEntry[] : [] +} + +function resolveCustomProviderEntry(config: any, modelName: string, provider: string | null): CustomProviderEntry | null { + if (!provider || !provider.startsWith('custom')) return null + + const providers = getCustomProviders(config) + if (provider !== 'custom') { + const suffix = normalizeCustomProviderName(provider.slice('custom:'.length)) + return providers.find((cp) => normalizeCustomProviderName(String(cp?.name || '')) === suffix) || null + } + + const modelBaseUrl = getModelBaseUrl(config) + if (modelBaseUrl) { + const normalizedBaseUrl = normalizeBaseUrl(modelBaseUrl) + const exactByBaseUrl = providers.find((cp) => + normalizeBaseUrl(String(cp?.base_url || '')) === normalizedBaseUrl + && String(cp?.model || '').trim() === modelName, + ) + if (exactByBaseUrl) return exactByBaseUrl + } + + const matchesByModel = providers.filter((cp) => String(cp?.model || '').trim() === modelName) + return matchesByModel.length === 1 ? matchesByModel[0] : null +} + /** * Lookup context_length from custom_providers in config.yaml. * - "custom:xxx" → strip prefix, match by name * - "custom" → match by model name */ function lookupCustomProviderContextLength(config: any, modelName: string, provider: string | null): number | null { - const providers: any[] = Array.isArray(config?.custom_providers) ? config.custom_providers : [] - if (!provider || !provider.startsWith('custom')) return null - - let matched: any = null - - if (provider === 'custom') { - matched = providers.find((cp: any) => cp.model === modelName) - } else { - const suffix = provider.slice('custom:'.length) - matched = providers.find((cp: any) => cp.name === suffix) - } - + const matched = resolveCustomProviderEntry(config, modelName, provider) if (!matched) return null const models = matched.models @@ -216,11 +254,68 @@ function lookupContextGloballyByModelName(data: Record, m return null } -function lookupContextFromCache(modelName: string, provider: string | null): number | null { +function lookupUniqueContextGloballyByModelName(data: Record, modelName: string): number | null { + const exactMatches: number[] = [] + for (const prov of Object.values(data)) { + const context = getCachedContext(prov.models?.[modelName]) + if (context) exactMatches.push(context) + if (exactMatches.length > 1) return null + } + if (exactMatches.length === 1) return exactMatches[0] + + const lower = modelName.toLowerCase() + const ciMatches: number[] = [] + for (const prov of Object.values(data)) { + const models = prov.models || {} + for (const [name, entry] of Object.entries(models)) { + if (name.toLowerCase() !== lower) continue + const context = getCachedContext(entry) + if (context) ciMatches.push(context) + break + } + if (ciMatches.length > 1) return null + } + + return ciMatches[0] || null +} + +function resolveCacheProviderFromBaseUrl(baseUrl: string | null): string | null { + if (!baseUrl) return null + const normalizedBaseUrl = normalizeBaseUrl(baseUrl) + const preset = PROVIDER_PRESETS.find((entry) => normalizeBaseUrl(entry.base_url) === normalizedBaseUrl) + return preset?.value || null +} + +function resolveCustomCacheProvider(config: any, modelName: string, provider: string): string | null { + const customEntry = resolveCustomProviderEntry(config, modelName, provider) + const entryBaseUrl = typeof customEntry?.base_url === 'string' ? customEntry.base_url : null + const providerFromEntryBaseUrl = resolveCacheProviderFromBaseUrl(entryBaseUrl) + if (providerFromEntryBaseUrl) return providerFromEntryBaseUrl + + return resolveCacheProviderFromBaseUrl(getModelBaseUrl(config)) +} + +function lookupContextFromCache(config: any, modelName: string, provider: string | null): number | null { const data = loadModelsDevCache() if (!data) return null if (provider) { + if (provider === 'custom' || provider.startsWith('custom:')) { + const inferredProvider = resolveCustomCacheProvider(config, modelName, provider) + + if (inferredProvider) { + const scoped = lookupContextInProvider(getProviderEntry(data, inferredProvider), modelName) + if (scoped) return scoped + return null + } + + if (provider === 'custom') { + return lookupUniqueContextGloballyByModelName(data, modelName) + } + + return null + } + return lookupContextInProvider(getProviderEntry(data, provider), modelName) } @@ -278,7 +373,7 @@ export function getModelContextLength(profile?: string): number { if (customCtx && customCtx > 0) return customCtx // 3. models_dev_cache.json - const cached = lookupContextFromCache(model, provider) + const cached = lookupContextFromCache(config, model, provider) if (cached) return cached // 4. Fallback diff --git a/tests/server/model-context.test.ts b/tests/server/model-context.test.ts index 09506fa..1689203 100644 --- a/tests/server/model-context.test.ts +++ b/tests/server/model-context.test.ts @@ -158,4 +158,169 @@ describe('getModelContextLength', () => { expect(getModelContextLength()).toBe(1_000_000) }) + + it('resolves provider: custom through model.base_url before falling back to the default context length', async () => { + writeConfig(`model:\n default: deepseek-v4-pro\n provider: custom\n base_url: https://api.deepseek.com\n`) + writeModelsCache({ + deepseek: { + models: { + 'deepseek-v4-pro': { limit: { context: 1_000_000 } }, + }, + }, + }) + + const { getModelContextLength } = await loadModelContext() + + expect(getModelContextLength()).toBe(1_000_000) + }) + + it('resolves custom:name providers when the matched custom provider base_url points at a builtin provider', async () => { + writeConfig(`model:\n default: deepseek-v4-pro\n provider: custom:deepseek\n\ncustom_providers:\n - name: deepseek\n base_url: https://api.deepseek.com\n model: deepseek-v4-pro\n`) + writeModelsCache({ + deepseek: { + models: { + 'deepseek-v4-pro': { limit: { context: 1_000_000 } }, + }, + }, + }) + + const { getModelContextLength } = await loadModelContext() + + expect(getModelContextLength()).toBe(1_000_000) + }) + + it('prefers the builtin provider inferred from a matched custom provider base_url over an arbitrary custom provider name', async () => { + writeConfig(`model:\n default: shared-model\n provider: custom:corp-proxy\n\ncustom_providers:\n - name: corp-proxy\n base_url: https://api.deepseek.com\n model: shared-model\n`) + writeModelsCache({ + deepseek: { + models: { + 'shared-model': { limit: { context: 1_000_000 } }, + }, + }, + openai: { + models: { + 'shared-model': { limit: { context: 400_000 } }, + }, + }, + }) + + const { getModelContextLength } = await loadModelContext() + + expect(getModelContextLength()).toBe(1_000_000) + }) + + it('does not trust a stale custom:name provider hint without a matching custom provider entry', async () => { + writeConfig(`model:\n default: deepseek-v4-pro\n provider: custom:deepseek\n`) + writeModelsCache({ + deepseek: { + models: { + 'deepseek-v4-pro': { limit: { context: 1_000_000 } }, + }, + }, + }) + + const { getModelContextLength } = await loadModelContext() + + expect(getModelContextLength()).toBe(200_000) + }) + + it('does not trust custom:name alone when the matched custom provider entry points at an unknown proxy url', async () => { + writeConfig(`model:\n default: deepseek-v4-pro\n provider: custom:deepseek\n\ncustom_providers:\n - name: deepseek\n base_url: https://proxy.example.com/v1\n model: deepseek-v4-pro\n`) + writeModelsCache({ + deepseek: { + models: { + 'deepseek-v4-pro': { limit: { context: 1_000_000 } }, + }, + }, + }) + + const { getModelContextLength } = await loadModelContext() + + expect(getModelContextLength()).toBe(200_000) + }) + + it('does not fall through to a unique global match after a resolved custom:name provider misses in its scoped cache provider', async () => { + writeConfig(`model:\n default: gpt-5.5\n provider: custom:deepseek\n\ncustom_providers:\n - name: deepseek\n base_url: https://api.deepseek.com\n model: gpt-5.5\n`) + writeModelsCache({ + openai: { + models: { + 'gpt-5.5': { limit: { context: 400_000 } }, + }, + }, + deepseek: { + models: { + 'deepseek-v4-pro': { limit: { context: 1_000_000 } }, + }, + }, + }) + + const { getModelContextLength } = await loadModelContext() + + expect(getModelContextLength()).toBe(200_000) + }) + + it('allows a unique global model-name fallback for unresolved custom providers', async () => { + writeConfig(`model:\n default: deepseek-v4-pro\n provider: custom\n base_url: https://proxy.example.com/v1\n`) + writeModelsCache({ + deepseek: { + models: { + 'deepseek-v4-pro': { limit: { context: 1_000_000 } }, + }, + }, + }) + + const { getModelContextLength } = await loadModelContext() + + expect(getModelContextLength()).toBe(1_000_000) + }) + + it('still allows the unique global fallback when provider: custom matches a custom provider entry that cannot be mapped to a builtin cache provider', async () => { + writeConfig(`model:\n default: deepseek-v4-pro\n provider: custom\n\ncustom_providers:\n - name: corp-proxy\n base_url: https://proxy.example.com/v1\n model: deepseek-v4-pro\n`) + writeModelsCache({ + deepseek: { + models: { + 'deepseek-v4-pro': { limit: { context: 1_000_000 } }, + }, + }, + }) + + const { getModelContextLength } = await loadModelContext() + + expect(getModelContextLength()).toBe(1_000_000) + }) + + it('keeps the unresolved custom-provider fallback strict to exact or case-insensitive model-name matches', async () => { + writeConfig(`model:\n default: gpt-5\n provider: custom\n base_url: https://proxy.example.com/v1\n`) + writeModelsCache({ + vercel: { + models: { + 'openai/gpt-5': { limit: { context: 1_000_000 } }, + }, + }, + }) + + const { getModelContextLength } = await loadModelContext() + + expect(getModelContextLength()).toBe(200_000) + }) + + it('does not guess across multiple cache providers when a custom provider remains unresolved', async () => { + writeConfig(`model:\n default: shared-model\n provider: custom\n base_url: https://proxy.example.com/v1\n`) + writeModelsCache({ + deepseek: { + models: { + 'shared-model': { limit: { context: 1_000_000 } }, + }, + }, + openai: { + models: { + 'shared-model': { limit: { context: 400_000 } }, + }, + }, + }) + + const { getModelContextLength } = await loadModelContext() + + expect(getModelContextLength()).toBe(200_000) + }) })