fix: fix the bug where GPT models were not recognized
This commit is contained in:
@@ -9,6 +9,82 @@ from litellm import acompletion, stream_chunk_builder
|
|||||||
streaming_queue_var = contextvars.ContextVar("streaming_queue", default=None)
|
streaming_queue_var = contextvars.ContextVar("streaming_queue", default=None)
|
||||||
|
|
||||||
class StreamingLiteLLMProvider(LiteLLMProvider):
|
class StreamingLiteLLMProvider(LiteLLMProvider):
|
||||||
|
def __init__(self, *args, **kwargs):
    """Remember an explicit ``provider_name`` keyword, then run the parent initializer.

    Only the keyword form is inspected; a provider name passed positionally
    is not captured here. All arguments are forwarded unchanged.
    """
    # Stored before delegating so the attribute already exists while the
    # parent constructor runs.
    # NOTE(review): presumably the parent __init__ calls overridable hooks
    # that read this attribute -- confirm.
    override = kwargs.get("provider_name")
    self._provider_name_override = override
    super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
def _get_active_spec(self, model: str):
    """Return the registry spec to use for *model*.

    The explicitly configured provider name (if any) is looked up first;
    when it is unset or the lookup yields nothing truthy, the spec is
    looked up by model name instead.
    """
    from nanobot.providers.registry import find_by_model, find_by_name

    override = self._provider_name_override
    by_name = find_by_name(override) if override else None
    # Fall back to model-based detection when the name lookup gave nothing.
    return by_name or find_by_model(model)
|
||||||
|
|
||||||
|
def _setup_env(self, api_key: str, api_base: str | None, model: str) -> None:
|
||||||
|
"""Set environment variables based on detected provider."""
|
||||||
|
import os
|
||||||
|
spec = self._gateway or self._get_active_spec(model)
|
||||||
|
if not spec:
|
||||||
|
return
|
||||||
|
if not spec.env_key:
|
||||||
|
return
|
||||||
|
|
||||||
|
if self._gateway:
|
||||||
|
os.environ[spec.env_key] = api_key
|
||||||
|
else:
|
||||||
|
os.environ.setdefault(spec.env_key, api_key)
|
||||||
|
|
||||||
|
effective_base = api_base or spec.default_api_base
|
||||||
|
for env_name, env_val in spec.env_extras:
|
||||||
|
resolved = env_val.replace("{api_key}", api_key)
|
||||||
|
resolved = resolved.replace("{api_base}", effective_base)
|
||||||
|
os.environ.setdefault(env_name, resolved)
|
||||||
|
|
||||||
|
def _resolve_model(self, model: str) -> str:
|
||||||
|
"""Resolve model name by applying provider/gateway prefixes, using override if available."""
|
||||||
|
if self._gateway:
|
||||||
|
prefix = self._gateway.litellm_prefix
|
||||||
|
if self._gateway.strip_model_prefix:
|
||||||
|
model = model.split("/")[-1]
|
||||||
|
if prefix and not model.startswith(f"{prefix}/"):
|
||||||
|
model = f"{prefix}/{model}"
|
||||||
|
return model
|
||||||
|
|
||||||
|
spec = self._get_active_spec(model)
|
||||||
|
if spec and spec.litellm_prefix:
|
||||||
|
model = self._canonicalize_explicit_prefix(model, spec.name, spec.litellm_prefix)
|
||||||
|
if not any(model.startswith(s) for s in spec.skip_prefixes):
|
||||||
|
model = f"{spec.litellm_prefix}/{model}"
|
||||||
|
elif spec and not spec.litellm_prefix and "/" not in model:
|
||||||
|
# For standard providers like openai, anthropic, litellm requires the prefix for unknown models
|
||||||
|
# but registry sets litellm_prefix="" to rely on native matching.
|
||||||
|
# If native matching fails (e.g. non-standard model name), we should force prefix.
|
||||||
|
# We only force prefix if provider was explicitly set and model has no prefix.
|
||||||
|
if self._provider_name_override:
|
||||||
|
model = f"{spec.name}/{model}"
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
def _apply_model_overrides(self, model: str, kwargs: dict[str, Any]) -> None:
|
||||||
|
"""Apply model-specific parameter overrides from the registry."""
|
||||||
|
model_lower = model.lower()
|
||||||
|
spec = self._get_active_spec(model)
|
||||||
|
if spec:
|
||||||
|
for pattern, overrides in spec.model_overrides:
|
||||||
|
if pattern in model_lower:
|
||||||
|
kwargs.update(overrides)
|
||||||
|
return
|
||||||
|
|
||||||
|
def _extra_msg_keys(self, original_model: str, resolved_model: str) -> frozenset[str]:
|
||||||
|
"""Return provider-specific extra keys to preserve in request messages."""
|
||||||
|
spec = self._get_active_spec(original_model) or self._get_active_spec(resolved_model)
|
||||||
|
if (spec and spec.name == "anthropic") or "claude" in original_model.lower() or resolved_model.startswith("anthropic/"):
|
||||||
|
# _ANTHROPIC_EXTRA_KEYS is defined in nanobot.providers.litellm_provider, let's just use the string
|
||||||
|
return frozenset({"thinking_blocks"})
|
||||||
|
return frozenset()
|
||||||
|
|
||||||
async def chat(
|
async def chat(
|
||||||
self,
|
self,
|
||||||
messages: List[Dict[str, Any]],
|
messages: List[Dict[str, Any]],
|
||||||
@@ -22,15 +98,21 @@ class StreamingLiteLLMProvider(LiteLLMProvider):
|
|||||||
) -> LLMResponse:
|
) -> LLMResponse:
|
||||||
original_model = model or self.default_model
|
original_model = model or self.default_model
|
||||||
model_name = self._resolve_model(original_model)
|
model_name = self._resolve_model(original_model)
|
||||||
|
extra_msg_keys = self._extra_msg_keys(original_model, model_name)
|
||||||
|
|
||||||
|
if self._supports_cache_control(original_model):
|
||||||
|
messages, tools = self._apply_cache_control(messages, tools)
|
||||||
|
|
||||||
kwargs: Dict[str, Any] = {
|
kwargs: Dict[str, Any] = {
|
||||||
"model": model_name,
|
"model": model_name,
|
||||||
"messages": messages,
|
"messages": self._sanitize_messages(self._sanitize_empty_content(messages), extra_keys=extra_msg_keys),
|
||||||
"temperature": temperature,
|
"temperature": temperature,
|
||||||
"max_tokens": max_tokens,
|
"max_tokens": max(1, max_tokens),
|
||||||
"stream": True, # 强制开启流式
|
"stream": True, # 强制开启流式
|
||||||
}
|
}
|
||||||
|
|
||||||
|
self._apply_model_overrides(model_name, kwargs)
|
||||||
|
|
||||||
if self.api_key and self.api_key != "no-key":
|
if self.api_key and self.api_key != "no-key":
|
||||||
kwargs["api_key"] = self.api_key
|
kwargs["api_key"] = self.api_key
|
||||||
if self.api_base:
|
if self.api_base:
|
||||||
@@ -39,13 +121,15 @@ class StreamingLiteLLMProvider(LiteLLMProvider):
|
|||||||
kwargs["extra_headers"] = self.extra_headers
|
kwargs["extra_headers"] = self.extra_headers
|
||||||
if tools:
|
if tools:
|
||||||
kwargs["tools"] = tools
|
kwargs["tools"] = tools
|
||||||
|
kwargs["tool_choice"] = "auto"
|
||||||
if request_timeout is not None:
|
if request_timeout is not None:
|
||||||
kwargs["timeout"] = request_timeout
|
kwargs["timeout"] = request_timeout
|
||||||
if num_retries is not None:
|
if num_retries is not None:
|
||||||
kwargs["num_retries"] = max(0, int(num_retries))
|
kwargs["num_retries"] = max(0, int(num_retries))
|
||||||
|
|
||||||
if reasoning_effort and self._supports_reasoning_effort(model_name):
|
if reasoning_effort:
|
||||||
kwargs["reasoning_effort"] = reasoning_effort
|
kwargs["reasoning_effort"] = reasoning_effort
|
||||||
|
kwargs["drop_params"] = True
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response_stream = await acompletion(**kwargs)
|
response_stream = await acompletion(**kwargs)
|
||||||
|
|||||||
Reference in New Issue
Block a user