Fix custom provider context resolution (#468)
Co-authored-by: ekko <152005280+EKKOLearnAI@users.noreply.github.com>
This commit is contained in:
@@ -2,6 +2,7 @@ import { resolve, join } from 'path'
|
||||
import { homedir } from 'os'
|
||||
import { readFileSync, existsSync, statSync } from 'fs'
|
||||
import yaml from 'js-yaml'
|
||||
import { PROVIDER_PRESETS } from '../../shared/providers'
|
||||
import { getDb } from '../../db'
|
||||
import { MODEL_CONTEXT_TABLE } from '../../db/hermes/schemas'
|
||||
|
||||
@@ -25,6 +26,13 @@ interface ProviderEntry {
|
||||
models?: Record<string, ModelEntry>
|
||||
}
|
||||
|
||||
/**
 * One item of the `custom_providers` list in config.yaml.
 *
 * Every field is optional: the list is user-authored and is cast rather than
 * schema-validated, so consumers must tolerate missing values.
 */
interface CustomProviderEntry {
  /** Provider name; compared (after normalization) with the suffix of a `custom:<name>` provider id. */
  name?: string
  /** Endpoint URL of the provider; compared (after normalization) with known base URLs. */
  base_url?: string
  /** Model name served by this entry; compared with the configured model name. */
  model?: string
  /** Per-model settings carrying an optional context length — presumably keyed by model name (confirm against the consumer). */
  models?: Record<string, { context_length?: number }>
}
|
||||
|
||||
const MODEL_CACHE_PROVIDER_ALIASES: Record<string, string[]> = {
|
||||
gemini: ['google'],
|
||||
moonshot: ['moonshotai'],
|
||||
@@ -107,24 +115,54 @@ function getConfigContextLength(config: any): number | null {
|
||||
return val
|
||||
}
|
||||
|
||||
/**
 * Canonicalizes a custom provider name so two spellings can be compared.
 *
 * The name is trimmed, lowercased, and every remaining space is replaced by a
 * hyphen. Runs of spaces are not collapsed: each space becomes its own hyphen.
 *
 * @param name Raw provider name.
 * @returns The normalized name (may be empty).
 */
function normalizeCustomProviderName(name: string): string {
  return name.trim().toLowerCase().replace(/ /g, '-')
}
|
||||
|
||||
/**
 * Canonicalizes a base URL for equality comparison.
 *
 * The URL is trimmed, lowercased as a whole (including any path), and all
 * trailing slashes are removed. The result is meant for comparison only.
 *
 * @param url Raw base URL.
 * @returns The normalized URL (may be empty).
 */
function normalizeBaseUrl(url: string): string {
  return url.trim().toLowerCase().replace(/\/+$/, '')
}
|
||||
|
||||
/**
 * Reads `model.base_url` from the parsed config.
 *
 * @param config Parsed config object (may be null/undefined).
 * @returns The trimmed base URL, or null when `config.model` is not an object,
 *          `base_url` is not a string, or it is empty after trimming.
 */
function getModelBaseUrl(config: any): string | null {
  const model = config?.model
  // `model` may be absent or a non-object value; only an object can carry base_url.
  if (!model || typeof model !== 'object') return null
  if (typeof model.base_url !== 'string') return null
  return model.base_url.trim() || null
}
|
||||
|
||||
/**
 * Returns the `custom_providers` list from the parsed config.
 *
 * Items that are not plain objects (null, scalars, nested arrays) are dropped,
 * so every returned element really has the shape the return type promises.
 *
 * @param config Parsed config object (may be null/undefined).
 * @returns The object entries of `config.custom_providers`, or an empty array
 *          when the key is missing or not an array.
 */
function getCustomProviders(config: any): CustomProviderEntry[] {
  const raw: unknown = config?.custom_providers
  if (!Array.isArray(raw)) return []
  return raw.filter(
    (entry): entry is CustomProviderEntry =>
      typeof entry === 'object' && entry !== null && !Array.isArray(entry),
  )
}
|
||||
|
||||
/**
 * Finds the `custom_providers` entry that a custom provider id refers to.
 *
 * - `custom:<name>`: matched by normalized name only. An empty `<name>` never
 *   matches, so an entry without a name cannot be selected by accident.
 * - `custom`: first an entry whose normalized `base_url` equals `model.base_url`
 *   and whose `model` equals `modelName`; otherwise the entry whose `model`
 *   equals `modelName`, but only when exactly one entry matches.
 *
 * Any other provider id, including ids that merely start with "custom" such as
 * `customfoo`, resolves to null.
 *
 * @param config Parsed config object.
 * @param modelName Configured model name.
 * @param provider Configured provider id, or null.
 * @returns The matching entry, or null when nothing matches unambiguously.
 */
function resolveCustomProviderEntry(config: any, modelName: string, provider: string | null): CustomProviderEntry | null {
  if (!provider) return null

  const isBareCustom = provider === 'custom'
  if (!isBareCustom && !provider.startsWith('custom:')) return null

  const providers = getCustomProviders(config)

  if (!isBareCustom) {
    const suffix = normalizeCustomProviderName(provider.slice('custom:'.length))
    // `custom:` with nothing after the colon names no provider at all.
    if (!suffix) return null
    return providers.find((cp) => normalizeCustomProviderName(String(cp?.name || '')) === suffix) || null
  }

  const modelBaseUrl = getModelBaseUrl(config)
  if (modelBaseUrl) {
    const normalizedBaseUrl = normalizeBaseUrl(modelBaseUrl)
    const exactByBaseUrl = providers.find(
      (cp) =>
        normalizeBaseUrl(String(cp?.base_url || '')) === normalizedBaseUrl &&
        String(cp?.model || '').trim() === modelName,
    )
    if (exactByBaseUrl) return exactByBaseUrl
  }

  // Without a base-URL match, accept a model-name match only when it is unique.
  const matchesByModel = providers.filter((cp) => String(cp?.model || '').trim() === modelName)
  return matchesByModel.length === 1 ? matchesByModel[0] : null
}
|
||||
|
||||
/**
|
||||
* Lookup context_length from custom_providers in config.yaml.
|
||||
* - "custom:xxx" → strip prefix, match the entry by normalized name (trimmed, lowercased, spaces → hyphens)
|
||||
* - "custom" → match by model.base_url + model name; otherwise by model name when exactly one entry matches
|
||||
*/
|
||||
function lookupCustomProviderContextLength(config: any, modelName: string, provider: string | null): number | null {
|
||||
const providers: any[] = Array.isArray(config?.custom_providers) ? config.custom_providers : []
|
||||
if (!provider || !provider.startsWith('custom')) return null
|
||||
|
||||
let matched: any = null
|
||||
|
||||
if (provider === 'custom') {
|
||||
matched = providers.find((cp: any) => cp.model === modelName)
|
||||
} else {
|
||||
const suffix = provider.slice('custom:'.length)
|
||||
matched = providers.find((cp: any) => cp.name === suffix)
|
||||
}
|
||||
|
||||
const matched = resolveCustomProviderEntry(config, modelName, provider)
|
||||
if (!matched) return null
|
||||
|
||||
const models = matched.models
|
||||
@@ -216,11 +254,68 @@ function lookupContextGloballyByModelName(data: Record<string, ProviderEntry>, m
|
||||
return null
|
||||
}
|
||||
|
||||
function lookupContextFromCache(modelName: string, provider: string | null): number | null {
|
||||
/**
 * Searches every provider in the cache for `modelName` and returns its context
 * length only when the answer is unambiguous.
 *
 * Pass 1 looks the model up by exact key. Pass 2 runs only when pass 1 found
 * nothing and compares keys case-insensitively, considering at most the first
 * matching key per provider. In either pass, a second provider contributing a
 * context length makes the result ambiguous and the function returns null.
 *
 * @param data Cache contents keyed by provider id.
 * @param modelName Model name to look up.
 * @returns The single context length found, or null when there is none or
 *          more than one provider supplies one.
 */
function lookupUniqueContextGloballyByModelName(data: Record<string, ProviderEntry>, modelName: string): number | null {
  const providers = Object.values(data)

  // Pass 1: exact model-name key.
  const exactMatches: number[] = []
  for (const prov of providers) {
    // `prov?.` tolerates a null provider value in a malformed cache file.
    // getCachedContext presumably yields a context length or a falsy value — confirm against its definition.
    const context = getCachedContext(prov?.models?.[modelName])
    if (context) exactMatches.push(context)
    if (exactMatches.length > 1) return null
  }
  if (exactMatches.length === 1) return exactMatches[0]

  // Pass 2: case-insensitive key comparison.
  const lower = modelName.toLowerCase()
  const ciMatches: number[] = []
  for (const prov of providers) {
    for (const [name, entry] of Object.entries(prov?.models || {})) {
      if (name.toLowerCase() !== lower) continue
      const context = getCachedContext(entry)
      if (context) ciMatches.push(context)
      // Only the first case-insensitive hit per provider is considered.
      break
    }
    if (ciMatches.length > 1) return null
  }

  return ciMatches[0] || null
}
|
||||
|
||||
/**
 * Maps a base URL to the `value` of the provider preset that has the same
 * normalized `base_url`.
 *
 * @param baseUrl Base URL to look up, or null.
 * @returns The matching preset's `value`, or null when the URL is missing,
 *          normalizes to an empty string, or no preset matches.
 */
function resolveCacheProviderFromBaseUrl(baseUrl: string | null): string | null {
  if (!baseUrl) return null

  const normalizedBaseUrl = normalizeBaseUrl(baseUrl)
  // A whitespace- or slash-only URL normalizes to '' and must not match a
  // preset whose base_url is empty.
  if (!normalizedBaseUrl) return null

  const preset = PROVIDER_PRESETS.find(
    (entry) => typeof entry.base_url === 'string' && normalizeBaseUrl(entry.base_url) === normalizedBaseUrl,
  )
  return preset?.value || null
}
|
||||
|
||||
/**
 * Infers which cache provider a custom provider actually talks to, based on
 * base URLs.
 *
 * The `base_url` of the resolved `custom_providers` entry is tried first; if
 * that does not map to a preset, `model.base_url` from the config is tried.
 *
 * @param config Parsed config object.
 * @param modelName Configured model name.
 * @param provider Custom provider id (`custom` or `custom:<name>`).
 * @returns The preset value for the first base URL that maps to one, else null.
 */
function resolveCustomCacheProvider(config: any, modelName: string, provider: string): string | null {
  const customEntry = resolveCustomProviderEntry(config, modelName, provider)
  const entryBaseUrl = typeof customEntry?.base_url === 'string' ? customEntry.base_url : null

  return resolveCacheProviderFromBaseUrl(entryBaseUrl) || resolveCacheProviderFromBaseUrl(getModelBaseUrl(config))
}
|
||||
|
||||
function lookupContextFromCache(config: any, modelName: string, provider: string | null): number | null {
|
||||
const data = loadModelsDevCache()
|
||||
if (!data) return null
|
||||
|
||||
if (provider) {
|
||||
if (provider === 'custom' || provider.startsWith('custom:')) {
|
||||
const inferredProvider = resolveCustomCacheProvider(config, modelName, provider)
|
||||
|
||||
if (inferredProvider) {
|
||||
const scoped = lookupContextInProvider(getProviderEntry(data, inferredProvider), modelName)
|
||||
if (scoped) return scoped
|
||||
return null
|
||||
}
|
||||
|
||||
if (provider === 'custom') {
|
||||
return lookupUniqueContextGloballyByModelName(data, modelName)
|
||||
}
|
||||
|
||||
return null
|
||||
}
|
||||
|
||||
return lookupContextInProvider(getProviderEntry(data, provider), modelName)
|
||||
}
|
||||
|
||||
@@ -278,7 +373,7 @@ export function getModelContextLength(profile?: string): number {
|
||||
if (customCtx && customCtx > 0) return customCtx
|
||||
|
||||
// 3. models_dev_cache.json
|
||||
const cached = lookupContextFromCache(model, provider)
|
||||
const cached = lookupContextFromCache(config, model, provider)
|
||||
if (cached) return cached
|
||||
|
||||
// 4. Fallback
|
||||
|
||||
@@ -158,4 +158,169 @@ describe('getModelContextLength', () => {
|
||||
|
||||
expect(getModelContextLength()).toBe(1_000_000)
|
||||
})
|
||||
|
||||
it('resolves provider: custom through model.base_url before falling back to the default context length', async () => {
  // Bare `custom` provider with no custom_providers list: model.base_url is the only hint.
  writeConfig(`model:\n default: deepseek-v4-pro\n provider: custom\n base_url: https://api.deepseek.com\n`)
  writeModelsCache({
    deepseek: {
      models: {
        'deepseek-v4-pro': { limit: { context: 1_000_000 } },
      },
    },
  })

  const { getModelContextLength } = await loadModelContext()

  expect(getModelContextLength()).toBe(1_000_000)
})
|
||||
|
||||
it('resolves custom:name providers when the matched custom provider base_url points at a builtin provider', async () => {
  // The list item's keys must be indented deeper than the `-`, otherwise the YAML does not parse.
  writeConfig(`model:\n  default: deepseek-v4-pro\n  provider: custom:deepseek\n\ncustom_providers:\n  - name: deepseek\n    base_url: https://api.deepseek.com\n    model: deepseek-v4-pro\n`)
  writeModelsCache({
    deepseek: {
      models: {
        'deepseek-v4-pro': { limit: { context: 1_000_000 } },
      },
    },
  })

  const { getModelContextLength } = await loadModelContext()

  expect(getModelContextLength()).toBe(1_000_000)
})
|
||||
|
||||
it('prefers the builtin provider inferred from a matched custom provider base_url over an arbitrary custom provider name', async () => {
  // The custom name ("corp-proxy") matches no cache provider; only base_url can select one.
  // The list item's keys must be indented deeper than the `-`, otherwise the YAML does not parse.
  writeConfig(`model:\n  default: shared-model\n  provider: custom:corp-proxy\n\ncustom_providers:\n  - name: corp-proxy\n    base_url: https://api.deepseek.com\n    model: shared-model\n`)
  writeModelsCache({
    deepseek: {
      models: {
        'shared-model': { limit: { context: 1_000_000 } },
      },
    },
    openai: {
      models: {
        'shared-model': { limit: { context: 400_000 } },
      },
    },
  })

  const { getModelContextLength } = await loadModelContext()

  expect(getModelContextLength()).toBe(1_000_000)
})
|
||||
|
||||
it('does not trust a stale custom:name provider hint without a matching custom provider entry', async () => {
  // No custom_providers list exists, so `custom:deepseek` refers to nothing.
  writeConfig(`model:\n default: deepseek-v4-pro\n provider: custom:deepseek\n`)
  writeModelsCache({
    deepseek: {
      models: {
        'deepseek-v4-pro': { limit: { context: 1_000_000 } },
      },
    },
  })

  const { getModelContextLength } = await loadModelContext()

  // 200_000 is presumably the module's fallback default — confirm against getModelContextLength.
  expect(getModelContextLength()).toBe(200_000)
})
|
||||
|
||||
it('does not trust custom:name alone when the matched custom provider entry points at an unknown proxy url', async () => {
  // The entry is found by name, but its base_url is a proxy rather than the DeepSeek endpoint.
  // The list item's keys must be indented deeper than the `-`, otherwise the YAML does not parse.
  writeConfig(`model:\n  default: deepseek-v4-pro\n  provider: custom:deepseek\n\ncustom_providers:\n  - name: deepseek\n    base_url: https://proxy.example.com/v1\n    model: deepseek-v4-pro\n`)
  writeModelsCache({
    deepseek: {
      models: {
        'deepseek-v4-pro': { limit: { context: 1_000_000 } },
      },
    },
  })

  const { getModelContextLength } = await loadModelContext()

  // 200_000 is presumably the module's fallback default — confirm against getModelContextLength.
  expect(getModelContextLength()).toBe(200_000)
})
|
||||
|
||||
it('does not fall through to a unique global match after a resolved custom:name provider misses in its scoped cache provider', async () => {
  // base_url is the DeepSeek endpoint, but the cache lists gpt-5.5 only under openai;
  // that entry must not be used.
  // The list item's keys must be indented deeper than the `-`, otherwise the YAML does not parse.
  writeConfig(`model:\n  default: gpt-5.5\n  provider: custom:deepseek\n\ncustom_providers:\n  - name: deepseek\n    base_url: https://api.deepseek.com\n    model: gpt-5.5\n`)
  writeModelsCache({
    openai: {
      models: {
        'gpt-5.5': { limit: { context: 400_000 } },
      },
    },
    deepseek: {
      models: {
        'deepseek-v4-pro': { limit: { context: 1_000_000 } },
      },
    },
  })

  const { getModelContextLength } = await loadModelContext()

  // 200_000 is presumably the module's fallback default — confirm against getModelContextLength.
  expect(getModelContextLength()).toBe(200_000)
})
|
||||
|
||||
it('allows a unique global model-name fallback for unresolved custom providers', async () => {
  // The proxy URL is not a builtin endpoint, but only one cache provider lists the model.
  writeConfig(`model:\n default: deepseek-v4-pro\n provider: custom\n base_url: https://proxy.example.com/v1\n`)
  writeModelsCache({
    deepseek: {
      models: {
        'deepseek-v4-pro': { limit: { context: 1_000_000 } },
      },
    },
  })

  const { getModelContextLength } = await loadModelContext()

  expect(getModelContextLength()).toBe(1_000_000)
})
|
||||
|
||||
it('still allows the unique global fallback when provider: custom matches a custom provider entry that cannot be mapped to a builtin cache provider', async () => {
  // The entry matches by model name, but its base_url is a proxy rather than a builtin endpoint.
  // The list item's keys must be indented deeper than the `-`, otherwise the YAML does not parse.
  writeConfig(`model:\n  default: deepseek-v4-pro\n  provider: custom\n\ncustom_providers:\n  - name: corp-proxy\n    base_url: https://proxy.example.com/v1\n    model: deepseek-v4-pro\n`)
  writeModelsCache({
    deepseek: {
      models: {
        'deepseek-v4-pro': { limit: { context: 1_000_000 } },
      },
    },
  })

  const { getModelContextLength } = await loadModelContext()

  expect(getModelContextLength()).toBe(1_000_000)
})
|
||||
|
||||
it('keeps the unresolved custom-provider fallback strict to exact or case-insensitive model-name matches', async () => {
  // The cache key "openai/gpt-5" only contains the model name; it must not count as a match for "gpt-5".
  writeConfig(`model:\n default: gpt-5\n provider: custom\n base_url: https://proxy.example.com/v1\n`)
  writeModelsCache({
    vercel: {
      models: {
        'openai/gpt-5': { limit: { context: 1_000_000 } },
      },
    },
  })

  const { getModelContextLength } = await loadModelContext()

  // 200_000 is presumably the module's fallback default — confirm against getModelContextLength.
  expect(getModelContextLength()).toBe(200_000)
})
|
||||
|
||||
it('does not guess across multiple cache providers when a custom provider remains unresolved', async () => {
  // Two cache providers list the same model with different limits, so no single answer exists.
  writeConfig(`model:\n default: shared-model\n provider: custom\n base_url: https://proxy.example.com/v1\n`)
  writeModelsCache({
    deepseek: {
      models: {
        'shared-model': { limit: { context: 1_000_000 } },
      },
    },
    openai: {
      models: {
        'shared-model': { limit: { context: 400_000 } },
      },
    },
  })

  const { getModelContextLength } = await loadModelContext()

  // 200_000 is presumably the module's fallback default — confirm against getModelContextLength.
  expect(getModelContextLength()).toBe(200_000)
})
|
||||
})
|
||||
|
||||
Reference in New Issue
Block a user