Fix custom provider context resolution (#468)
Co-authored-by: ekko <152005280+EKKOLearnAI@users.noreply.github.com>
This commit is contained in:
@@ -2,6 +2,7 @@ import { resolve, join } from 'path'
|
||||
import { homedir } from 'os'
|
||||
import { readFileSync, existsSync, statSync } from 'fs'
|
||||
import yaml from 'js-yaml'
|
||||
import { PROVIDER_PRESETS } from '../../shared/providers'
|
||||
import { getDb } from '../../db'
|
||||
import { MODEL_CONTEXT_TABLE } from '../../db/hermes/schemas'
|
||||
|
||||
@@ -25,6 +26,13 @@ interface ProviderEntry {
|
||||
models?: Record<string, ModelEntry>
|
||||
}
|
||||
|
||||
/**
 * Shape of one element of the config's `custom_providers` array, limited to
 * the fields this module reads. Every field is optional, so consumers must
 * tolerate partially filled entries.
 */
interface CustomProviderEntry {
  /** Provider label; compared against the part after "custom:" in a provider id. */
  name?: string
  /** Endpoint URL of the provider; compared after URL normalization. */
  base_url?: string
  /** Model identifier associated with this entry. */
  model?: string
  /** Per-model settings keyed by model name. */
  models?: { [modelName: string]: { context_length?: number } }
}
|
||||
|
||||
const MODEL_CACHE_PROVIDER_ALIASES: Record<string, string[]> = {
|
||||
gemini: ['google'],
|
||||
moonshot: ['moonshotai'],
|
||||
@@ -107,24 +115,54 @@ function getConfigContextLength(config: any): number | null {
|
||||
return val
|
||||
}
|
||||
|
||||
/**
 * Canonicalizes a custom provider name so two spellings can be compared.
 *
 * Surrounding whitespace is removed, the text is lower-cased, and every
 * space character is replaced by a hyphen.
 *
 * @param name Provider name as written in the config or provider id.
 * @returns The canonical form of `name`.
 */
function normalizeCustomProviderName(name: string): string {
  const lowered = name.trim().toLowerCase()
  // Only the plain space character is replaced, one hyphen per space, so
  // consecutive spaces become consecutive hyphens and tabs are left alone.
  return lowered.replace(/ /g, '-')
}
|
||||
|
||||
/**
 * Canonicalizes a base URL for equality comparison.
 *
 * The value is trimmed, lower-cased, and stripped of any run of trailing
 * slashes, so "https://Host/v1/" and "https://host/v1" compare equal.
 *
 * @param url Base URL text.
 * @returns The canonical form of `url`.
 */
function normalizeBaseUrl(url: string): string {
  const lowered = url.trim().toLowerCase()
  // Drop every trailing "/" so the presence of a final slash is irrelevant.
  return lowered.replace(/\/+$/, '')
}
|
||||
|
||||
/**
 * Reads `model.base_url` from a parsed config object.
 *
 * @param config Parsed config; may be null/undefined or lack a `model` section.
 * @returns The trimmed base URL, or null when `model` is missing or not an
 *   object, when `base_url` is not a string, or when it is blank after trimming.
 */
function getModelBaseUrl(config: any): string | null {
  const modelSection = config?.model
  if (!modelSection) return null
  if (typeof modelSection !== 'object') return null

  const rawUrl = modelSection.base_url
  if (typeof rawUrl !== 'string') return null

  // A whitespace-only value counts as "not configured".
  const trimmed = rawUrl.trim()
  return trimmed === '' ? null : trimmed
}
|
||||
|
||||
/**
 * Returns the `custom_providers` list from a parsed config object.
 *
 * @param config Parsed config; may be null/undefined.
 * @returns The configured array itself (not a copy), or a new empty array
 *   when `custom_providers` is absent or not an array. Elements are not
 *   validated, so callers should tolerate malformed entries.
 */
function getCustomProviders(config: any): CustomProviderEntry[] {
  const raw = config?.custom_providers
  if (!Array.isArray(raw)) return []
  return raw as CustomProviderEntry[]
}
|
||||
|
||||
/**
 * Finds the `custom_providers` entry that a custom provider id refers to.
 *
 * Resolution rules:
 * - "custom:<name>": the first entry whose normalized `name` equals the
 *   normalized `<name>`.
 * - "custom": first, an entry whose normalized `base_url` equals the
 *   normalized `model.base_url` of the config and whose trimmed `model`
 *   equals `modelName`; otherwise the entry whose trimmed `model` equals
 *   `modelName`, but only when exactly one entry matches.
 *
 * @param config Parsed config object (may be null/undefined).
 * @param modelName Model name to match against each entry's `model`.
 * @param provider Provider id; anything other than "custom" or
 *   "custom:<name>" resolves to null.
 * @returns The matching entry, or null when nothing matches unambiguously.
 */
function resolveCustomProviderEntry(config: any, modelName: string, provider: string | null): CustomProviderEntry | null {
  if (!provider) return null

  const namedPrefix = 'custom:'
  const isNamed = provider.startsWith(namedPrefix)
  // Only the bare "custom" id and the "custom:<name>" form are custom
  // provider ids. An id that merely begins with "custom" (e.g. "customized")
  // must not be sliced as if it carried the "custom:" prefix.
  if (provider !== 'custom' && !isNamed) return null

  const providers = getCustomProviders(config)

  if (isNamed) {
    const suffix = normalizeCustomProviderName(provider.slice(namedPrefix.length))
    // An empty suffix ("custom:") would otherwise match any entry that has
    // no name at all, which is never the intended target.
    if (!suffix) return null
    return providers.find((cp) => normalizeCustomProviderName(String(cp?.name || '')) === suffix) || null
  }

  const hasModel = (cp: CustomProviderEntry): boolean => String(cp?.model || '').trim() === modelName

  // Bare "custom": prefer the entry that agrees on both endpoint and model.
  const modelBaseUrl = getModelBaseUrl(config)
  if (modelBaseUrl) {
    const normalizedBaseUrl = normalizeBaseUrl(modelBaseUrl)
    const exactByBaseUrl = providers.find((cp) =>
      normalizeBaseUrl(String(cp?.base_url || '')) === normalizedBaseUrl && hasModel(cp),
    )
    if (exactByBaseUrl) return exactByBaseUrl
  }

  // Fall back to the model name alone, refusing to guess between several
  // entries that serve the same model.
  const matchesByModel = providers.filter(hasModel)
  return matchesByModel.length === 1 ? matchesByModel[0] : null
}
|
||||
|
||||
/**
|
||||
* Lookup context_length from custom_providers in config.yaml.
|
||||
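* - "custom:xxx" → strip prefix, match by normalized name (trimmed, lower-cased, spaces → hyphens)
* - "custom" → match by base_url + model name first, then by model name if exactly one entry matches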
|
||||
*/
|
||||
function lookupCustomProviderContextLength(config: any, modelName: string, provider: string | null): number | null {
|
||||
const providers: any[] = Array.isArray(config?.custom_providers) ? config.custom_providers : []
|
||||
if (!provider || !provider.startsWith('custom')) return null
|
||||
|
||||
let matched: any = null
|
||||
|
||||
if (provider === 'custom') {
|
||||
matched = providers.find((cp: any) => cp.model === modelName)
|
||||
} else {
|
||||
const suffix = provider.slice('custom:'.length)
|
||||
matched = providers.find((cp: any) => cp.name === suffix)
|
||||
}
|
||||
|
||||
const matched = resolveCustomProviderEntry(config, modelName, provider)
|
||||
if (!matched) return null
|
||||
|
||||
const models = matched.models
|
||||
@@ -216,11 +254,68 @@ function lookupContextGloballyByModelName(data: Record<string, ProviderEntry>, m
|
||||
return null
|
||||
}
|
||||
|
||||
function lookupContextFromCache(modelName: string, provider: string | null): number | null {
|
||||
/**
 * Searches every provider in the cache for `modelName` and returns its
 * context value only when the answer is unambiguous.
 *
 * Two passes are made over the providers:
 * 1. Exact key match on `modelName`.
 * 2. Case-insensitive key match, run only when pass 1 found nothing.
 *
 * In either pass, a second provider yielding a truthy context makes the
 * lookup ambiguous and the function returns null, even if both providers
 * report the same value.
 *
 * @param data Cache contents keyed by provider id.
 * @param modelName Model name to look for.
 * @returns The single context value found, or null when none or several
 *   providers supply one.
 */
function lookupUniqueContextGloballyByModelName(data: Record<string, ProviderEntry>, modelName: string): number | null {
  const providerEntries = Object.values(data)

  // Pass 1: exact model key.
  let exactHit: number | null = null
  let exactCount = 0
  for (const providerEntry of providerEntries) {
    const ctx = getCachedContext(providerEntry.models?.[modelName])
    if (!ctx) continue
    exactCount += 1
    if (exactCount > 1) return null
    exactHit = ctx
  }
  if (exactCount === 1) return exactHit

  // Pass 2: case-insensitive model key.
  const wanted = modelName.toLowerCase()
  let looseHit: number | null = null
  let looseCount = 0
  for (const providerEntry of providerEntries) {
    const models = providerEntry.models || {}
    // Only the first key that matches case-insensitively is considered for
    // a provider, whether or not it yields a usable context.
    const matchedKey = Object.keys(models).find((key) => key.toLowerCase() === wanted)
    if (matchedKey === undefined) continue
    const ctx = getCachedContext(models[matchedKey])
    if (!ctx) continue
    looseCount += 1
    if (looseCount > 1) return null
    looseHit = ctx
  }

  return looseHit
}
|
||||
|
||||
/**
 * Maps a base URL to the id of the provider preset that uses the same URL.
 *
 * URLs are compared after normalization via `normalizeBaseUrl`.
 *
 * @param baseUrl Base URL to look up; null or empty yields null.
 * @returns The `value` of the first preset whose `base_url` matches, or null
 *   when no preset matches or the matching preset has no usable `value`.
 */
function resolveCacheProviderFromBaseUrl(baseUrl: string | null): string | null {
  if (!baseUrl) return null

  const target = normalizeBaseUrl(baseUrl)
  const match = PROVIDER_PRESETS.find((candidate) => normalizeBaseUrl(candidate.base_url) === target)
  if (!match) return null

  return match.value || null
}
|
||||
|
||||
/**
 * Infers which cache provider a custom provider id corresponds to by
 * matching base URLs against the provider presets.
 *
 * The base URL of the resolved `custom_providers` entry is tried first; if
 * that yields nothing, the config's `model.base_url` is tried.
 *
 * @param config Parsed config object.
 * @param modelName Model name used to resolve the custom provider entry.
 * @param provider Custom provider id ("custom" or "custom:<name>").
 * @returns The inferred preset provider id, or null when neither URL matches
 *   a preset.
 */
function resolveCustomCacheProvider(config: any, modelName: string, provider: string): string | null {
  const entry = resolveCustomProviderEntry(config, modelName, provider)
  const entryUrl = entry && typeof entry.base_url === 'string' ? entry.base_url : null

  // The entry's own endpoint takes precedence over the model-level endpoint.
  return (
    resolveCacheProviderFromBaseUrl(entryUrl)
    || resolveCacheProviderFromBaseUrl(getModelBaseUrl(config))
  )
}
|
||||
|
||||
function lookupContextFromCache(config: any, modelName: string, provider: string | null): number | null {
|
||||
const data = loadModelsDevCache()
|
||||
if (!data) return null
|
||||
|
||||
if (provider) {
|
||||
if (provider === 'custom' || provider.startsWith('custom:')) {
|
||||
const inferredProvider = resolveCustomCacheProvider(config, modelName, provider)
|
||||
|
||||
if (inferredProvider) {
|
||||
const scoped = lookupContextInProvider(getProviderEntry(data, inferredProvider), modelName)
|
||||
if (scoped) return scoped
|
||||
return null
|
||||
}
|
||||
|
||||
if (provider === 'custom') {
|
||||
return lookupUniqueContextGloballyByModelName(data, modelName)
|
||||
}
|
||||
|
||||
return null
|
||||
}
|
||||
|
||||
return lookupContextInProvider(getProviderEntry(data, provider), modelName)
|
||||
}
|
||||
|
||||
@@ -278,7 +373,7 @@ export function getModelContextLength(profile?: string): number {
|
||||
if (customCtx && customCtx > 0) return customCtx
|
||||
|
||||
// 3. models_dev_cache.json
|
||||
const cached = lookupContextFromCache(model, provider)
|
||||
const cached = lookupContextFromCache(config, model, provider)
|
||||
if (cached) return cached
|
||||
|
||||
// 4. Fallback
|
||||
|
||||
Reference in New Issue
Block a user