fix(context): prefer provider context lengths (#1184)
This commit is contained in:
@@ -39,6 +39,15 @@ interface CustomProviderEntry {
|
|||||||
models?: Record<string, { context_length?: number }>
|
models?: Record<string, { context_length?: number }>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
 * Shapes accepted for a provider's `models` setting: either a map keyed by
 * model name (each value being a settings object with an optional
 * `context_length`, or a string), or a plain list of model names.
 * NOTE(review): the meaning of a string map value is not visible here — confirm.
 */
type ConfigProviderModels = Record<string, { context_length?: number } | string> | string[]

/** One entry of the config's `providers` mapping. Every field is optional. */
interface ConfigProviderEntry {
  /** Provider-wide context length, used when no model-specific value applies. */
  context_length?: number
  /** Model name treated as belonging to this provider. */
  default_model?: string
  /** Alternative key for a model name treated as belonging to this provider. */
  model?: string
  /** Models listed for this provider, optionally with per-model settings. */
  models?: ConfigProviderModels
}
|
||||||
|
|
||||||
const MODEL_CACHE_PROVIDER_ALIASES: Record<string, string[]> = {
|
const MODEL_CACHE_PROVIDER_ALIASES: Record<string, string[]> = {
|
||||||
gemini: ['google'],
|
gemini: ['google'],
|
||||||
moonshot: ['moonshotai'],
|
moonshot: ['moonshotai'],
|
||||||
@@ -122,6 +131,46 @@ function getConfigContextLength(config: any): number | null {
|
|||||||
return val
|
return val
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
 * Find the entry for `provider` inside the config's `providers` mapping.
 *
 * An exact key match is preferred. Otherwise keys are compared
 * case-insensitively and the first key whose value is a mapping wins.
 *
 * @param config - Parsed config object; any shape (including undefined) is tolerated.
 * @param provider - Provider name to resolve, or null when none was requested.
 * @returns The matching provider entry, or null when the name is empty, the
 *   config has no usable `providers` mapping, or no mapping-valued key matches.
 */
function getConfigProvider(config: any, provider: string | null): ConfigProviderEntry | null {
  if (!provider) return null

  // Only non-null, non-array objects count as mappings; arrays also report
  // typeof 'object' but are never valid here.
  const isMapping = (value: unknown): boolean =>
    typeof value === 'object' && value !== null && !Array.isArray(value)

  const providers = config?.providers
  if (!isMapping(providers)) return null

  // Own-property check: a plain index read would resolve names such as
  // "__proto__" to inherited members and hand back Object.prototype as an entry.
  if (Object.prototype.hasOwnProperty.call(providers, provider)) {
    const exact = providers[provider]
    if (isMapping(exact)) return exact as ConfigProviderEntry
  }

  const wanted = provider.toLowerCase()
  for (const [name, value] of Object.entries(providers)) {
    // Keep scanning past a same-named key whose value is unusable.
    if (name.toLowerCase() === wanted && isMapping(value)) {
      return value as ConfigProviderEntry
    }
  }
  return null
}
|
||||||
|
|
||||||
|
/**
 * Accept a value only if it is a finite number greater than zero.
 *
 * @param value - Arbitrary value, typically read from parsed config.
 * @returns The number itself when it is finite and positive; otherwise null
 *   (non-numbers, NaN, infinities, zero and negatives are all rejected).
 */
function getPositiveNumber(value: unknown): number | null {
  if (typeof value !== 'number') return null
  if (!Number.isFinite(value)) return null
  return value > 0 ? value : null
}
|
||||||
|
|
||||||
|
/**
 * Decide whether a model name belongs to a provider entry.
 *
 * A model belongs when it equals the entry's `default_model` or `model`,
 * appears in a `models` list, or is an own key of a `models` map.
 *
 * @param provider - Provider entry to inspect.
 * @param modelName - Model name to look for; compared exactly (case-sensitive).
 * @returns True when the provider declares the model in any of those ways.
 */
function providerHasModel(provider: ConfigProviderEntry, modelName: string): boolean {
  if (modelName === provider.default_model || modelName === provider.model) return true

  const { models } = provider
  if (!models || typeof models !== 'object') return false

  return Array.isArray(models)
    ? models.includes(modelName)
    // Own keys only, so inherited names like "toString" never count as models.
    : Object.prototype.hasOwnProperty.call(models, modelName)
}
|
||||||
|
|
||||||
|
/**
 * Resolve a context length from the config's `providers` section.
 *
 * A model-specific `models.<model>.context_length` on the provider entry is
 * preferred. Failing that, the provider-level `context_length` is used, but
 * only when the provider declares the model (see `providerHasModel`).
 *
 * @param config - Parsed config object; any shape is tolerated.
 * @param modelName - Model whose context length is wanted.
 * @param provider - Provider to look under, or null when none was requested.
 * @returns A positive finite context length, or null when the provider entry
 *   is missing or yields no usable value for this model.
 */
function lookupProviderConfigContextLength(config: any, modelName: string, provider: string | null): number | null {
  const providerEntry = getConfigProvider(config, provider)
  if (!providerEntry) return null

  const models = providerEntry.models
  const isModelMap = !!models && typeof models === 'object' && !Array.isArray(models)
  // Own keys only, matching the membership rule used by providerHasModel.
  if (isModelMap && Object.prototype.hasOwnProperty.call(models, modelName)) {
    const modelEntry = (models as Record<string, { context_length?: number } | string>)[modelName]
    if (modelEntry && typeof modelEntry === 'object') {
      const modelCtx = getPositiveNumber(modelEntry.context_length)
      if (modelCtx !== null) return modelCtx
    }
  }

  // The provider-wide value applies only to models the provider declares.
  if (!providerHasModel(providerEntry, modelName)) return null
  return getPositiveNumber(providerEntry.context_length)
}
|
||||||
|
|
||||||
/**
 * Normalize a custom provider name: trim surrounding whitespace, lower-case
 * it, and replace each space character with a hyphen.
 *
 * @param name - Provider name as written.
 * @returns The normalized name; consecutive spaces yield consecutive hyphens.
 */
function normalizeCustomProviderName(name: string): string {
  const trimmedLower = name.trim().toLowerCase()
  return trimmedLower.replace(/ /g, '-')
}
|
||||||
@@ -333,10 +382,13 @@ function lookupContextFromCache(config: any, modelName: string, provider: string
|
|||||||
/**
|
/**
|
||||||
* Get the context length for the current profile's default model.
|
* Get the context length for the current profile's default model.
|
||||||
* Resolution order:
|
* Resolution order:
|
||||||
* 1. config.yaml model.context_length (highest priority, user override)
|
* 1. model_context database override
|
||||||
* 2. custom_providers models.<model>.context_length
|
* 2. provider/model-specific providers.<provider>.models.<model>.context_length
|
||||||
* 3. models_dev_cache.json, scoped to model.provider when configured
|
* 3. provider-level providers.<provider>.context_length when the model belongs to that provider
|
||||||
* 4. DEFAULT_CONTEXT_LENGTH (200K hardcoded fallback)
|
* 4. custom_providers models.<model>.context_length
|
||||||
|
* 5. top-level model.context_length fallback
|
||||||
|
* 6. models_dev_cache.json, scoped to model.provider when configured
|
||||||
|
* 7. DEFAULT_CONTEXT_LENGTH
|
||||||
*/
|
*/
|
||||||
/**
|
/**
|
||||||
* 从数据库 model_context 表查找上下文长度(最高优先级)
|
* 从数据库 model_context 表查找上下文长度(最高优先级)
|
||||||
@@ -375,18 +427,22 @@ export function getModelContextLength(input?: string | ModelContextLengthOptions
|
|||||||
const dbCtx = lookupContextFromDatabase(model, provider)
|
const dbCtx = lookupContextFromDatabase(model, provider)
|
||||||
if (dbCtx && dbCtx > 0) return dbCtx
|
if (dbCtx && dbCtx > 0) return dbCtx
|
||||||
|
|
||||||
// 1. Global context_length override in config.yaml
|
// 1. Provider-specific context_length in config.yaml
|
||||||
const configCtx = getConfigContextLength(config)
|
const providerConfigCtx = lookupProviderConfigContextLength(config, model, provider)
|
||||||
if (configCtx && configCtx > 0) return configCtx
|
if (providerConfigCtx && providerConfigCtx > 0) return providerConfigCtx
|
||||||
|
|
||||||
// 2. Custom provider context_length
|
// 2. Custom provider context_length
|
||||||
const customCtx = lookupCustomProviderContextLength(config, model, provider)
|
const customCtx = lookupCustomProviderContextLength(config, model, provider)
|
||||||
if (customCtx && customCtx > 0) return customCtx
|
if (customCtx && customCtx > 0) return customCtx
|
||||||
|
|
||||||
// 3. models_dev_cache.json
|
// 3. Global context_length fallback in config.yaml
|
||||||
|
const configCtx = getConfigContextLength(config)
|
||||||
|
if (configCtx && configCtx > 0) return configCtx
|
||||||
|
|
||||||
|
// 4. models_dev_cache.json
|
||||||
const cached = lookupContextFromCache(config, model, provider)
|
const cached = lookupContextFromCache(config, model, provider)
|
||||||
if (cached) return cached
|
if (cached) return cached
|
||||||
|
|
||||||
// 4. Fallback
|
// 5. Fallback
|
||||||
return DEFAULT_CONTEXT_LENGTH
|
return DEFAULT_CONTEXT_LENGTH
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -109,6 +109,22 @@ describe('getModelContextLength', () => {
|
|||||||
expect(getModelContextLength()).toBe(1_050_000)
|
expect(getModelContextLength()).toBe(1_050_000)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
it('prefers requested provider model context_length over top-level default context_length', async () => {
|
||||||
|
writeConfig(`model:\n default: gpt-5.5\n provider: openai-codex\n context_length: 272000\n\nproviders:\n qwen:\n name: Qwen\n default_model: qwen3.6-plus\n models:\n qwen3.6-plus:\n context_length: 1048576\n`)
|
||||||
|
|
||||||
|
const { getModelContextLength } = await loadModelContext()
|
||||||
|
|
||||||
|
expect(getModelContextLength({ provider: 'qwen', model: 'qwen3.6-plus' })).toBe(1_048_576)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('uses provider-level context_length when the requested model belongs to that provider', async () => {
|
||||||
|
writeConfig(`model:\n default: gpt-5.5\n provider: openai-codex\n context_length: 272000\n\nproviders:\n qwen:\n name: Qwen\n default_model: qwen3.6-plus\n models:\n - qwen3.6-plus\n context_length: 1048576\n`)
|
||||||
|
|
||||||
|
const { getModelContextLength } = await loadModelContext()
|
||||||
|
|
||||||
|
expect(getModelContextLength({ provider: 'qwen', model: 'qwen3.6-plus' })).toBe(1_048_576)
|
||||||
|
})
|
||||||
|
|
||||||
it('keeps legacy model-name cache lookup when no provider is configured', async () => {
|
it('keeps legacy model-name cache lookup when no provider is configured', async () => {
|
||||||
writeConfig(`model:\n default: gpt-5.5\n`)
|
writeConfig(`model:\n default: gpt-5.5\n`)
|
||||||
writeModelsCache({
|
writeModelsCache({
|
||||||
|
|||||||
Reference in New Issue
Block a user