Fix custom provider context resolution (#468)

Co-authored-by: ekko <152005280+EKKOLearnAI@users.noreply.github.com>
This commit is contained in:
Zhicheng Han
2026-05-06 09:16:44 +02:00
committed by GitHub
parent 479e1feef6
commit 0209372a6d
2 changed files with 274 additions and 14 deletions
@@ -2,6 +2,7 @@ import { resolve, join } from 'path'
import { homedir } from 'os'
import { readFileSync, existsSync, statSync } from 'fs'
import yaml from 'js-yaml'
import { PROVIDER_PRESETS } from '../../shared/providers'
import { getDb } from '../../db'
import { MODEL_CONTEXT_TABLE } from '../../db/hermes/schemas'
@@ -25,6 +26,13 @@ interface ProviderEntry {
models?: Record<string, ModelEntry>
}
interface CustomProviderEntry {
name?: string
base_url?: string
model?: string
models?: Record<string, { context_length?: number }>
}
const MODEL_CACHE_PROVIDER_ALIASES: Record<string, string[]> = {
gemini: ['google'],
moonshot: ['moonshotai'],
@@ -107,24 +115,54 @@ function getConfigContextLength(config: any): number | null {
return val
}
function normalizeCustomProviderName(name: string): string {
return name.trim().toLowerCase().replace(/ /g, '-')
}
function normalizeBaseUrl(url: string): string {
return url.trim().toLowerCase().replace(/\/+$/, '')
}
function getModelBaseUrl(config: any): string | null {
const model = config?.model
if (!model || typeof model !== 'object') return null
return typeof model.base_url === 'string' ? model.base_url.trim() || null : null
}
function getCustomProviders(config: any): CustomProviderEntry[] {
return Array.isArray(config?.custom_providers) ? config.custom_providers as CustomProviderEntry[] : []
}
function resolveCustomProviderEntry(config: any, modelName: string, provider: string | null): CustomProviderEntry | null {
if (!provider || !provider.startsWith('custom')) return null
const providers = getCustomProviders(config)
if (provider !== 'custom') {
const suffix = normalizeCustomProviderName(provider.slice('custom:'.length))
return providers.find((cp) => normalizeCustomProviderName(String(cp?.name || '')) === suffix) || null
}
const modelBaseUrl = getModelBaseUrl(config)
if (modelBaseUrl) {
const normalizedBaseUrl = normalizeBaseUrl(modelBaseUrl)
const exactByBaseUrl = providers.find((cp) =>
normalizeBaseUrl(String(cp?.base_url || '')) === normalizedBaseUrl
&& String(cp?.model || '').trim() === modelName,
)
if (exactByBaseUrl) return exactByBaseUrl
}
const matchesByModel = providers.filter((cp) => String(cp?.model || '').trim() === modelName)
return matchesByModel.length === 1 ? matchesByModel[0] : null
}
/**
* Lookup context_length from custom_providers in config.yaml.
* - "custom:xxx" → strip prefix, match by name
* - "custom" → match by model name
*/
function lookupCustomProviderContextLength(config: any, modelName: string, provider: string | null): number | null {
const providers: any[] = Array.isArray(config?.custom_providers) ? config.custom_providers : []
if (!provider || !provider.startsWith('custom')) return null
let matched: any = null
if (provider === 'custom') {
matched = providers.find((cp: any) => cp.model === modelName)
} else {
const suffix = provider.slice('custom:'.length)
matched = providers.find((cp: any) => cp.name === suffix)
}
const matched = resolveCustomProviderEntry(config, modelName, provider)
if (!matched) return null
const models = matched.models
@@ -216,11 +254,68 @@ function lookupContextGloballyByModelName(data: Record<string, ProviderEntry>, m
return null
}
function lookupContextFromCache(modelName: string, provider: string | null): number | null {
function lookupUniqueContextGloballyByModelName(data: Record<string, ProviderEntry>, modelName: string): number | null {
const exactMatches: number[] = []
for (const prov of Object.values(data)) {
const context = getCachedContext(prov.models?.[modelName])
if (context) exactMatches.push(context)
if (exactMatches.length > 1) return null
}
if (exactMatches.length === 1) return exactMatches[0]
const lower = modelName.toLowerCase()
const ciMatches: number[] = []
for (const prov of Object.values(data)) {
const models = prov.models || {}
for (const [name, entry] of Object.entries(models)) {
if (name.toLowerCase() !== lower) continue
const context = getCachedContext(entry)
if (context) ciMatches.push(context)
break
}
if (ciMatches.length > 1) return null
}
return ciMatches[0] || null
}
function resolveCacheProviderFromBaseUrl(baseUrl: string | null): string | null {
if (!baseUrl) return null
const normalizedBaseUrl = normalizeBaseUrl(baseUrl)
const preset = PROVIDER_PRESETS.find((entry) => normalizeBaseUrl(entry.base_url) === normalizedBaseUrl)
return preset?.value || null
}
function resolveCustomCacheProvider(config: any, modelName: string, provider: string): string | null {
const customEntry = resolveCustomProviderEntry(config, modelName, provider)
const entryBaseUrl = typeof customEntry?.base_url === 'string' ? customEntry.base_url : null
const providerFromEntryBaseUrl = resolveCacheProviderFromBaseUrl(entryBaseUrl)
if (providerFromEntryBaseUrl) return providerFromEntryBaseUrl
return resolveCacheProviderFromBaseUrl(getModelBaseUrl(config))
}
function lookupContextFromCache(config: any, modelName: string, provider: string | null): number | null {
const data = loadModelsDevCache()
if (!data) return null
if (provider) {
if (provider === 'custom' || provider.startsWith('custom:')) {
const inferredProvider = resolveCustomCacheProvider(config, modelName, provider)
if (inferredProvider) {
const scoped = lookupContextInProvider(getProviderEntry(data, inferredProvider), modelName)
if (scoped) return scoped
return null
}
if (provider === 'custom') {
return lookupUniqueContextGloballyByModelName(data, modelName)
}
return null
}
return lookupContextInProvider(getProviderEntry(data, provider), modelName)
}
@@ -278,7 +373,7 @@ export function getModelContextLength(profile?: string): number {
if (customCtx && customCtx > 0) return customCtx
// 3. models_dev_cache.json
const cached = lookupContextFromCache(model, provider)
const cached = lookupContextFromCache(config, model, provider)
if (cached) return cached
// 4. Fallback
+165
View File
@@ -158,4 +158,169 @@ describe('getModelContextLength', () => {
expect(getModelContextLength()).toBe(1_000_000)
})
it('resolves provider: custom through model.base_url before falling back to the default context length', async () => {
writeConfig(`model:\n default: deepseek-v4-pro\n provider: custom\n base_url: https://api.deepseek.com\n`)
writeModelsCache({
deepseek: {
models: {
'deepseek-v4-pro': { limit: { context: 1_000_000 } },
},
},
})
const { getModelContextLength } = await loadModelContext()
expect(getModelContextLength()).toBe(1_000_000)
})
it('resolves custom:name providers when the matched custom provider base_url points at a builtin provider', async () => {
writeConfig(`model:\n default: deepseek-v4-pro\n provider: custom:deepseek\n\ncustom_providers:\n - name: deepseek\n base_url: https://api.deepseek.com\n model: deepseek-v4-pro\n`)
writeModelsCache({
deepseek: {
models: {
'deepseek-v4-pro': { limit: { context: 1_000_000 } },
},
},
})
const { getModelContextLength } = await loadModelContext()
expect(getModelContextLength()).toBe(1_000_000)
})
it('prefers the builtin provider inferred from a matched custom provider base_url over an arbitrary custom provider name', async () => {
writeConfig(`model:\n default: shared-model\n provider: custom:corp-proxy\n\ncustom_providers:\n - name: corp-proxy\n base_url: https://api.deepseek.com\n model: shared-model\n`)
writeModelsCache({
deepseek: {
models: {
'shared-model': { limit: { context: 1_000_000 } },
},
},
openai: {
models: {
'shared-model': { limit: { context: 400_000 } },
},
},
})
const { getModelContextLength } = await loadModelContext()
expect(getModelContextLength()).toBe(1_000_000)
})
it('does not trust a stale custom:name provider hint without a matching custom provider entry', async () => {
writeConfig(`model:\n default: deepseek-v4-pro\n provider: custom:deepseek\n`)
writeModelsCache({
deepseek: {
models: {
'deepseek-v4-pro': { limit: { context: 1_000_000 } },
},
},
})
const { getModelContextLength } = await loadModelContext()
expect(getModelContextLength()).toBe(200_000)
})
it('does not trust custom:name alone when the matched custom provider entry points at an unknown proxy url', async () => {
writeConfig(`model:\n default: deepseek-v4-pro\n provider: custom:deepseek\n\ncustom_providers:\n - name: deepseek\n base_url: https://proxy.example.com/v1\n model: deepseek-v4-pro\n`)
writeModelsCache({
deepseek: {
models: {
'deepseek-v4-pro': { limit: { context: 1_000_000 } },
},
},
})
const { getModelContextLength } = await loadModelContext()
expect(getModelContextLength()).toBe(200_000)
})
it('does not fall through to a unique global match after a resolved custom:name provider misses in its scoped cache provider', async () => {
writeConfig(`model:\n default: gpt-5.5\n provider: custom:deepseek\n\ncustom_providers:\n - name: deepseek\n base_url: https://api.deepseek.com\n model: gpt-5.5\n`)
writeModelsCache({
openai: {
models: {
'gpt-5.5': { limit: { context: 400_000 } },
},
},
deepseek: {
models: {
'deepseek-v4-pro': { limit: { context: 1_000_000 } },
},
},
})
const { getModelContextLength } = await loadModelContext()
expect(getModelContextLength()).toBe(200_000)
})
it('allows a unique global model-name fallback for unresolved custom providers', async () => {
writeConfig(`model:\n default: deepseek-v4-pro\n provider: custom\n base_url: https://proxy.example.com/v1\n`)
writeModelsCache({
deepseek: {
models: {
'deepseek-v4-pro': { limit: { context: 1_000_000 } },
},
},
})
const { getModelContextLength } = await loadModelContext()
expect(getModelContextLength()).toBe(1_000_000)
})
it('still allows the unique global fallback when provider: custom matches a custom provider entry that cannot be mapped to a builtin cache provider', async () => {
writeConfig(`model:\n default: deepseek-v4-pro\n provider: custom\n\ncustom_providers:\n - name: corp-proxy\n base_url: https://proxy.example.com/v1\n model: deepseek-v4-pro\n`)
writeModelsCache({
deepseek: {
models: {
'deepseek-v4-pro': { limit: { context: 1_000_000 } },
},
},
})
const { getModelContextLength } = await loadModelContext()
expect(getModelContextLength()).toBe(1_000_000)
})
it('keeps the unresolved custom-provider fallback strict to exact or case-insensitive model-name matches', async () => {
writeConfig(`model:\n default: gpt-5\n provider: custom\n base_url: https://proxy.example.com/v1\n`)
writeModelsCache({
vercel: {
models: {
'openai/gpt-5': { limit: { context: 1_000_000 } },
},
},
})
const { getModelContextLength } = await loadModelContext()
expect(getModelContextLength()).toBe(200_000)
})
it('does not guess across multiple cache providers when a custom provider remains unresolved', async () => {
writeConfig(`model:\n default: shared-model\n provider: custom\n base_url: https://proxy.example.com/v1\n`)
writeModelsCache({
deepseek: {
models: {
'shared-model': { limit: { context: 1_000_000 } },
},
},
openai: {
models: {
'shared-model': { limit: { context: 400_000 } },
},
},
})
const { getModelContextLength } = await loadModelContext()
expect(getModelContextLength()).toBe(200_000)
})
})