590 lines
22 KiB
Python
590 lines
22 KiB
Python
|
|
"""OpenAI-compatible provider for all non-Anthropic LLM APIs."""
|
||
|
|
|
||
|
|
from __future__ import annotations
|
||
|
|
|
||
|
|
import hashlib
|
||
|
|
import os
|
||
|
|
import secrets
|
||
|
|
import string
|
||
|
|
import uuid
|
||
|
|
from collections.abc import Awaitable, Callable
|
||
|
|
from typing import TYPE_CHECKING, Any
|
||
|
|
|
||
|
|
import json_repair
|
||
|
|
from openai import AsyncOpenAI
|
||
|
|
|
||
|
|
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
|
||
|
|
|
||
|
|
if TYPE_CHECKING:
|
||
|
|
from nanobot.providers.registry import ProviderSpec
|
||
|
|
|
||
|
|
_ALLOWED_MSG_KEYS = frozenset({
|
||
|
|
"role", "content", "tool_calls", "tool_call_id", "name",
|
||
|
|
"reasoning_content", "extra_content",
|
||
|
|
})
|
||
|
|
_ALNUM = string.ascii_letters + string.digits
|
||
|
|
|
||
|
|
_STANDARD_TC_KEYS = frozenset({"id", "type", "index", "function"})
|
||
|
|
_STANDARD_FN_KEYS = frozenset({"name", "arguments"})
|
||
|
|
_DEFAULT_OPENROUTER_HEADERS = {
|
||
|
|
"HTTP-Referer": "https://github.com/HKUDS/nanobot",
|
||
|
|
"X-OpenRouter-Title": "nanobot",
|
||
|
|
"X-OpenRouter-Categories": "cli-agent,personal-agent",
|
||
|
|
}
|
||
|
|
|
||
|
|
|
||
|
|
def _short_tool_id() -> str:
    """Return a random 9-character ID built from ASCII letters and digits.

    The short alphanumeric form is compatible with all providers
    (incl. Mistral).
    """
    picked = [secrets.choice(_ALNUM) for _ in range(9)]
    return "".join(picked)
|
||
|
|
|
||
|
|
|
||
|
|
def _get(obj: Any, key: str) -> Any:
|
||
|
|
"""Get a value from dict or object attribute, returning None if absent."""
|
||
|
|
if isinstance(obj, dict):
|
||
|
|
return obj.get(key)
|
||
|
|
return getattr(obj, key, None)
|
||
|
|
|
||
|
|
|
||
|
|
def _coerce_dict(value: Any) -> dict[str, Any] | None:
|
||
|
|
"""Try to coerce *value* to a dict; return None if not possible or empty."""
|
||
|
|
if value is None:
|
||
|
|
return None
|
||
|
|
if isinstance(value, dict):
|
||
|
|
return value if value else None
|
||
|
|
model_dump = getattr(value, "model_dump", None)
|
||
|
|
if callable(model_dump):
|
||
|
|
dumped = model_dump()
|
||
|
|
if isinstance(dumped, dict) and dumped:
|
||
|
|
return dumped
|
||
|
|
return None
|
||
|
|
|
||
|
|
|
||
|
|
def _extract_tc_extras(tc: Any) -> tuple[
    dict[str, Any] | None,
    dict[str, Any] | None,
    dict[str, Any] | None,
]:
    """Return (extra_content, provider_specific_fields, fn_provider_specific_fields).

    Accepts both SDK objects and plain dicts. Gemini ``extra_content`` is
    captured verbatim. Any non-standard, non-None key found on the tool call
    or on its function is gathered as a provider-specific field. Each element
    of the returned tuple is None when nothing was found.
    """
    extra_content = _coerce_dict(_get(tc, "extra_content"))
    as_dict = _coerce_dict(tc)

    if as_dict is None:
        # No dict view of the tool call: fall back to explicit
        # ``provider_specific_fields`` attributes on the call and its function.
        tc_fields = _coerce_dict(_get(tc, "provider_specific_fields"))
        fn_fields = None
        function = _get(tc, "function")
        if function is not None:
            fn_fields = _coerce_dict(_get(function, "provider_specific_fields"))
        return extra_content, tc_fields, fn_fields

    # Dict view available: everything outside the standard key sets is extra.
    tc_fields = {
        key: val
        for key, val in as_dict.items()
        if key not in _STANDARD_TC_KEYS and key != "extra_content" and val is not None
    } or None

    fn_fields = None
    function = _coerce_dict(as_dict.get("function"))
    if function is not None:
        fn_fields = {
            key: val
            for key, val in function.items()
            if key not in _STANDARD_FN_KEYS and val is not None
        } or None

    return extra_content, tc_fields, fn_fields
|
||
|
|
|
||
|
|
|
||
|
|
def _uses_openrouter_attribution(spec: "ProviderSpec | None", api_base: str | None) -> bool:
|
||
|
|
"""Apply Nanobot attribution headers to OpenRouter requests by default."""
|
||
|
|
if spec and spec.name == "openrouter":
|
||
|
|
return True
|
||
|
|
return bool(api_base and "openrouter" in api_base.lower())
|
||
|
|
|
||
|
|
|
||
|
|
class OpenAICompatProvider(LLMProvider):
    """Unified provider for all OpenAI-compatible APIs.

    Receives a resolved ``ProviderSpec`` from the caller — no internal
    registry lookups needed.
    """

    def __init__(
        self,
        api_key: str | None = None,
        api_base: str | None = None,
        default_model: str = "gpt-4o",
        extra_headers: dict[str, str] | None = None,
        spec: ProviderSpec | None = None,
    ):
        """Create the provider and its ``AsyncOpenAI`` client.

        Args:
            api_key: API key; ``"no-key"`` is passed to the client when empty.
            api_base: Base URL; falls back to ``spec.default_api_base``.
            default_model: Model used when a call does not name one.
            extra_headers: Headers merged last into the client's defaults,
                so they win over the built-in ones.
            spec: Resolved provider spec driving env setup and request quirks.
        """
        super().__init__(api_key, api_base)
        self.default_model = default_model
        self.extra_headers = extra_headers or {}
        self._spec = spec

        if api_key and spec and spec.env_key:
            self._setup_env(api_key, api_base)

        # An empty default base collapses to None before reaching the client.
        effective_base = api_base or (spec.default_api_base if spec else None) or None
        default_headers = {"x-session-affinity": uuid.uuid4().hex}
        if _uses_openrouter_attribution(spec, effective_base):
            default_headers.update(_DEFAULT_OPENROUTER_HEADERS)
        if extra_headers:
            default_headers.update(extra_headers)

        self._client = AsyncOpenAI(
            api_key=api_key or "no-key",
            base_url=effective_base,
            default_headers=default_headers,
        )

    def _setup_env(self, api_key: str, api_base: str | None) -> None:
        """Set environment variables based on provider spec.

        Gateway specs overwrite ``spec.env_key``; other specs only set it when
        it is not already present. Each ``spec.env_extras`` entry has its
        ``{api_key}`` / ``{api_base}`` placeholders substituted and is set
        only if the variable is not already defined.
        """
        spec = self._spec
        if not spec or not spec.env_key:
            return
        if spec.is_gateway:
            os.environ[spec.env_key] = api_key
        else:
            os.environ.setdefault(spec.env_key, api_key)
        # Fall back to "" so str.replace never receives None when neither an
        # explicit nor a default base URL is configured.
        effective_base = api_base or spec.default_api_base or ""
        for env_name, env_val in spec.env_extras:
            resolved = env_val.replace("{api_key}", api_key).replace("{api_base}", effective_base)
            os.environ.setdefault(env_name, resolved)

    @staticmethod
    def _apply_cache_control(
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]] | None,
    ) -> tuple[list[dict[str, Any]], list[dict[str, Any]] | None]:
        """Inject cache_control markers for prompt caching.

        Markers are placed on the leading system message (if any), on the
        second-to-last message when there are at least three messages, and on
        the last tool definition. The input lists and dicts are not mutated;
        marked entries are replaced by shallow copies in new lists.
        """
        cache_marker = {"type": "ephemeral"}
        new_messages = list(messages)

        def _mark(msg: dict[str, Any]) -> dict[str, Any]:
            """Return a copy of *msg* with its (last) content part marked."""
            content = msg.get("content")
            if isinstance(content, str):
                # Plain text becomes a single text part carrying the marker.
                return {**msg, "content": [
                    {"type": "text", "text": content, "cache_control": cache_marker},
                ]}
            if isinstance(content, list) and content:
                nc = list(content)
                nc[-1] = {**nc[-1], "cache_control": cache_marker}
                return {**msg, "content": nc}
            return msg

        if new_messages and new_messages[0].get("role") == "system":
            new_messages[0] = _mark(new_messages[0])
        if len(new_messages) >= 3:
            new_messages[-2] = _mark(new_messages[-2])

        new_tools = tools
        if tools:
            new_tools = list(tools)
            new_tools[-1] = {**new_tools[-1], "cache_control": cache_marker}
        return new_messages, new_tools

    @staticmethod
    def _normalize_tool_call_id(tool_call_id: Any) -> Any:
        """Normalize to a provider-safe 9-char alphanumeric form.

        Non-string values and IDs that are already 9 alphanumeric characters
        are returned unchanged; any other string maps to the first 9 hex
        digits of its SHA-1 digest, so the mapping is deterministic.
        """
        if not isinstance(tool_call_id, str):
            return tool_call_id
        if len(tool_call_id) == 9 and tool_call_id.isalnum():
            return tool_call_id
        return hashlib.sha1(tool_call_id.encode()).hexdigest()[:9]

    def _sanitize_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
        """Strip non-standard keys, normalize tool_call IDs.

        The same original ID always maps to the same normalized ID, so
        assistant ``tool_calls`` entries and the matching ``tool_call_id``
        of tool messages stay paired.
        """
        sanitized = LLMProvider._sanitize_request_messages(messages, _ALLOWED_MSG_KEYS)
        id_map: dict[str, str] = {}

        def map_id(value: Any) -> Any:
            """Return the normalized form of *value*, cached per original ID."""
            if not isinstance(value, str):
                return value
            return id_map.setdefault(value, self._normalize_tool_call_id(value))

        for clean in sanitized:
            if isinstance(clean.get("tool_calls"), list):
                normalized = []
                for tc in clean["tool_calls"]:
                    if not isinstance(tc, dict):
                        normalized.append(tc)
                        continue
                    tc_clean = dict(tc)
                    tc_clean["id"] = map_id(tc_clean.get("id"))
                    normalized.append(tc_clean)
                clean["tool_calls"] = normalized
            if "tool_call_id" in clean and clean["tool_call_id"]:
                clean["tool_call_id"] = map_id(clean["tool_call_id"])
        return sanitized

    # ------------------------------------------------------------------
    # Build kwargs
    # ------------------------------------------------------------------

    def _build_kwargs(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]] | None,
        model: str | None,
        max_tokens: int,
        temperature: float,
        reasoning_effort: str | None,
        tool_choice: str | dict[str, Any] | None,
    ) -> dict[str, Any]:
        """Assemble the keyword arguments for ``chat.completions.create``.

        Applies the spec-driven adjustments visible here: prompt-cache
        markers, model-prefix stripping, ``max_completion_tokens`` vs
        ``max_tokens``, and the first matching entry of
        ``spec.model_overrides``. ``tool_choice`` defaults to ``"auto"`` and
        is only sent together with a non-empty ``tools`` list.
        """
        model_name = model or self.default_model
        spec = self._spec

        if spec and spec.supports_prompt_caching:
            messages, tools = self._apply_cache_control(messages, tools)

        if spec and spec.strip_model_prefix:
            model_name = model_name.split("/")[-1]

        kwargs: dict[str, Any] = {
            "model": model_name,
            "messages": self._sanitize_messages(self._sanitize_empty_content(messages)),
            "temperature": temperature,
        }

        # Token limit is clamped to at least 1 under either parameter name.
        if spec and getattr(spec, "supports_max_completion_tokens", False):
            kwargs["max_completion_tokens"] = max(1, max_tokens)
        else:
            kwargs["max_tokens"] = max(1, max_tokens)

        if spec:
            model_lower = model_name.lower()
            for pattern, overrides in spec.model_overrides:
                if pattern in model_lower:
                    # Only the first matching pattern is applied.
                    kwargs.update(overrides)
                    break

        if reasoning_effort:
            kwargs["reasoning_effort"] = reasoning_effort

        if tools:
            kwargs["tools"] = tools
            kwargs["tool_choice"] = tool_choice or "auto"

        return kwargs

    # ------------------------------------------------------------------
    # Response parsing
    # ------------------------------------------------------------------

    @staticmethod
    def _maybe_mapping(value: Any) -> dict[str, Any] | None:
        """Return *value* as a dict (directly or via ``model_dump``), else None.

        Unlike ``_coerce_dict``, an empty dict is returned as-is.
        """
        if isinstance(value, dict):
            return value
        model_dump = getattr(value, "model_dump", None)
        if callable(model_dump):
            dumped = model_dump()
            if isinstance(dumped, dict):
                return dumped
        return None

    @staticmethod
    def _coerce_tool_arguments(raw: Any) -> dict[str, Any]:
        """Return tool-call arguments as a dict, whatever shape they arrive in.

        String arguments are parsed with ``json_repair``; anything that does
        not end up as a dict (None, empty string, list, scalar) becomes ``{}``.
        """
        if isinstance(raw, str):
            raw = json_repair.loads(raw)
        return raw if isinstance(raw, dict) else {}

    @classmethod
    def _extract_text_content(cls, value: Any) -> str | None:
        """Flatten message content into a single string.

        None stays None and strings pass through. Lists are joined from the
        ``text`` of each part (mapping key or attribute) and from bare string
        items, yielding None when nothing was collected. Any other value is
        converted with ``str()``.
        """
        if value is None:
            return None
        if isinstance(value, str):
            return value
        if isinstance(value, list):
            parts: list[str] = []
            for item in value:
                item_map = cls._maybe_mapping(item)
                if item_map:
                    text = item_map.get("text")
                    if isinstance(text, str):
                        parts.append(text)
                        continue
                text = getattr(item, "text", None)
                if isinstance(text, str):
                    parts.append(text)
                    continue
                if isinstance(item, str):
                    parts.append(item)
            return "".join(parts) or None
        return str(value)

    @classmethod
    def _extract_usage(cls, response: Any) -> dict[str, int]:
        """Return prompt/completion/total token counts, or ``{}`` if absent.

        Missing or falsy individual counts are reported as 0.
        """
        usage_obj = None
        response_map = cls._maybe_mapping(response)
        if response_map is not None:
            usage_obj = response_map.get("usage")
        elif hasattr(response, "usage") and response.usage:
            usage_obj = response.usage

        usage_map = cls._maybe_mapping(usage_obj)
        if usage_map is not None:
            return {
                "prompt_tokens": int(usage_map.get("prompt_tokens") or 0),
                "completion_tokens": int(usage_map.get("completion_tokens") or 0),
                "total_tokens": int(usage_map.get("total_tokens") or 0),
            }

        # Usage object that offers no dict view: read its attributes directly.
        if usage_obj:
            return {
                "prompt_tokens": getattr(usage_obj, "prompt_tokens", 0) or 0,
                "completion_tokens": getattr(usage_obj, "completion_tokens", 0) or 0,
                "total_tokens": getattr(usage_obj, "total_tokens", 0) or 0,
            }
        return {}

    def _parse(self, response: Any) -> LLMResponse:
        """Convert a non-streaming completion into an ``LLMResponse``.

        Handles three shapes: a bare string, anything with a dict view (dict
        or object exposing ``model_dump``), and plain attribute-style objects.
        Tool calls are collected from every choice and always receive a fresh
        short ID. Their arguments are always delivered as a dict.
        """
        if isinstance(response, str):
            return LLMResponse(content=response, finish_reason="stop")

        response_map = self._maybe_mapping(response)
        if response_map is not None:
            choices = response_map.get("choices") or []
            if not choices:
                # Some backends return top-level text instead of choices.
                content = self._extract_text_content(
                    response_map.get("content") or response_map.get("output_text")
                )
                if content is not None:
                    return LLMResponse(
                        content=content,
                        finish_reason=str(response_map.get("finish_reason") or "stop"),
                        usage=self._extract_usage(response_map),
                    )
                return LLMResponse(content="Error: API returned empty choices.", finish_reason="error")

            choice0 = self._maybe_mapping(choices[0]) or {}
            msg0 = self._maybe_mapping(choice0.get("message")) or {}
            content = self._extract_text_content(msg0.get("content"))
            finish_reason = str(choice0.get("finish_reason") or "stop")

            raw_tool_calls: list[Any] = []
            reasoning_content = msg0.get("reasoning_content")
            for ch in choices:
                ch_map = self._maybe_mapping(ch) or {}
                m = self._maybe_mapping(ch_map.get("message")) or {}
                tool_calls = m.get("tool_calls")
                if isinstance(tool_calls, list) and tool_calls:
                    raw_tool_calls.extend(tool_calls)
                    if ch_map.get("finish_reason") in ("tool_calls", "stop"):
                        finish_reason = str(ch_map["finish_reason"])
                # First non-empty content / reasoning across choices wins.
                if not content:
                    content = self._extract_text_content(m.get("content"))
                if not reasoning_content:
                    reasoning_content = m.get("reasoning_content")

            parsed_tool_calls = []
            for tc in raw_tool_calls:
                tc_map = self._maybe_mapping(tc) or {}
                fn = self._maybe_mapping(tc_map.get("function")) or {}
                ec, prov, fn_prov = _extract_tc_extras(tc)
                parsed_tool_calls.append(ToolCallRequest(
                    id=_short_tool_id(),
                    name=str(fn.get("name") or ""),
                    arguments=self._coerce_tool_arguments(fn.get("arguments")),
                    extra_content=ec,
                    provider_specific_fields=prov,
                    function_provider_specific_fields=fn_prov,
                ))

            return LLMResponse(
                content=content,
                tool_calls=parsed_tool_calls,
                finish_reason=finish_reason,
                usage=self._extract_usage(response_map),
                reasoning_content=reasoning_content if isinstance(reasoning_content, str) else None,
            )

        # Attribute-style response object that offers no dict view.
        if not response.choices:
            return LLMResponse(content="Error: API returned empty choices.", finish_reason="error")

        choice = response.choices[0]
        msg = choice.message
        content = msg.content
        finish_reason = choice.finish_reason

        raw_tool_calls = []
        for ch in response.choices:
            m = ch.message
            if hasattr(m, "tool_calls") and m.tool_calls:
                raw_tool_calls.extend(m.tool_calls)
                if ch.finish_reason in ("tool_calls", "stop"):
                    finish_reason = ch.finish_reason
            if not content and m.content:
                content = m.content

        tool_calls = []
        for tc in raw_tool_calls:
            ec, prov, fn_prov = _extract_tc_extras(tc)
            tool_calls.append(ToolCallRequest(
                id=_short_tool_id(),
                name=tc.function.name,
                # Same dict guarantee as the mapping branch above.
                arguments=self._coerce_tool_arguments(tc.function.arguments),
                extra_content=ec,
                provider_specific_fields=prov,
                function_provider_specific_fields=fn_prov,
            ))

        return LLMResponse(
            content=content,
            tool_calls=tool_calls,
            finish_reason=finish_reason or "stop",
            usage=self._extract_usage(response),
            reasoning_content=getattr(msg, "reasoning_content", None) or None,
        )

    @classmethod
    def _parse_chunks(cls, chunks: list[Any]) -> LLMResponse:
        """Fold a list of streamed chunks into one ``LLMResponse``.

        Chunks may be bare strings, dict-viewable values, or attribute-style
        objects. Tool-call deltas are merged per index: ID and name take the
        latest non-empty value, while argument fragments are concatenated and
        parsed once at the end.
        """
        content_parts: list[str] = []
        tc_bufs: dict[int, dict[str, Any]] = {}
        finish_reason = "stop"
        usage: dict[str, int] = {}

        def _accum_tc(tc: Any, idx_hint: int) -> None:
            """Accumulate one streaming tool-call delta into *tc_bufs*."""
            explicit_index = _get(tc, "index")
            tc_index: int = explicit_index if explicit_index is not None else idx_hint
            buf = tc_bufs.setdefault(tc_index, {
                "id": "", "name": "", "arguments": "",
                "extra_content": None, "prov": None, "fn_prov": None,
            })
            tc_id = _get(tc, "id")
            if tc_id:
                buf["id"] = str(tc_id)
            fn = _get(tc, "function")
            if fn is not None:
                fn_name = _get(fn, "name")
                if fn_name:
                    buf["name"] = str(fn_name)
                fn_args = _get(fn, "arguments")
                if fn_args:
                    buf["arguments"] += str(fn_args)
            ec, prov, fn_prov = _extract_tc_extras(tc)
            if ec:
                buf["extra_content"] = ec
            if prov:
                buf["prov"] = prov
            if fn_prov:
                buf["fn_prov"] = fn_prov

        for chunk in chunks:
            if isinstance(chunk, str):
                content_parts.append(chunk)
                continue

            chunk_map = cls._maybe_mapping(chunk)
            if chunk_map is not None:
                choices = chunk_map.get("choices") or []
                if not choices:
                    # Choice-less chunk: may carry usage and/or top-level text.
                    usage = cls._extract_usage(chunk_map) or usage
                    text = cls._extract_text_content(
                        chunk_map.get("content") or chunk_map.get("output_text")
                    )
                    if text:
                        content_parts.append(text)
                    continue
                choice = cls._maybe_mapping(choices[0]) or {}
                if choice.get("finish_reason"):
                    finish_reason = str(choice["finish_reason"])
                delta = cls._maybe_mapping(choice.get("delta")) or {}
                text = cls._extract_text_content(delta.get("content"))
                if text:
                    content_parts.append(text)
                for idx, tc in enumerate(delta.get("tool_calls") or []):
                    _accum_tc(tc, idx)
                usage = cls._extract_usage(chunk_map) or usage
                continue

            # Attribute-style chunk that offers no dict view.
            if not chunk.choices:
                usage = cls._extract_usage(chunk) or usage
                continue
            choice = chunk.choices[0]
            if choice.finish_reason:
                finish_reason = choice.finish_reason
            delta = choice.delta
            if delta and delta.content:
                content_parts.append(delta.content)
            for tc in (delta.tool_calls or []) if delta else []:
                _accum_tc(tc, getattr(tc, "index", 0))

        return LLMResponse(
            content="".join(content_parts) or None,
            tool_calls=[
                ToolCallRequest(
                    id=b["id"] or _short_tool_id(),
                    name=b["name"],
                    # Non-dict parse results fall back to {} as in _parse.
                    arguments=cls._coerce_tool_arguments(b["arguments"]),
                    extra_content=b.get("extra_content"),
                    provider_specific_fields=b.get("prov"),
                    function_provider_specific_fields=b.get("fn_prov"),
                )
                for b in tc_bufs.values()
            ],
            finish_reason=finish_reason,
            usage=usage,
        )

    @classmethod
    def _chunk_delta_text(cls, chunk: Any) -> str | None:
        """Return the text delta of one streamed chunk, or None if it has none.

        Accepts the same chunk shapes as ``_parse_chunks`` (string, mapping,
        attribute-style object) so the live callback cannot fail on a chunk
        that the final aggregation would accept.
        """
        if isinstance(chunk, str):
            return chunk or None
        choices = _get(chunk, "choices")
        if not choices:
            return None
        delta = _get(choices[0], "delta")
        return cls._extract_text_content(_get(delta, "content"))

    @staticmethod
    def _handle_error(e: Exception) -> LLMResponse:
        """Turn an exception into an error ``LLMResponse``.

        Prefers the exception's ``doc`` attribute, then ``response.text``,
        stripped and truncated to 500 characters; otherwise falls back to the
        exception's own string form.
        """
        body = getattr(e, "doc", None) or getattr(getattr(e, "response", None), "text", None)
        msg = f"Error: {body.strip()[:500]}" if body and body.strip() else f"Error calling LLM: {e}"
        return LLMResponse(content=msg, finish_reason="error")

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    async def chat(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]] | None = None,
        model: str | None = None,
        max_tokens: int = 4096,
        temperature: float = 0.7,
        reasoning_effort: str | None = None,
        tool_choice: str | dict[str, Any] | None = None,
    ) -> LLMResponse:
        """Send one non-streaming chat completion request.

        Exceptions raised by the request or by response parsing are returned
        as an error ``LLMResponse`` instead of propagating.
        """
        kwargs = self._build_kwargs(
            messages, tools, model, max_tokens, temperature,
            reasoning_effort, tool_choice,
        )
        try:
            return self._parse(await self._client.chat.completions.create(**kwargs))
        except Exception as e:
            return self._handle_error(e)

    async def chat_stream(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]] | None = None,
        model: str | None = None,
        max_tokens: int = 4096,
        temperature: float = 0.7,
        reasoning_effort: str | None = None,
        tool_choice: str | dict[str, Any] | None = None,
        on_content_delta: Callable[[str], Awaitable[None]] | None = None,
    ) -> LLMResponse:
        """Send a streaming chat completion request and aggregate the result.

        Each non-empty text delta is awaited through *on_content_delta* as it
        arrives; the collected chunks are then folded by ``_parse_chunks``.
        Exceptions (including ones raised by the callback) are returned as an
        error ``LLMResponse``.
        """
        kwargs = self._build_kwargs(
            messages, tools, model, max_tokens, temperature,
            reasoning_effort, tool_choice,
        )
        kwargs["stream"] = True
        kwargs["stream_options"] = {"include_usage": True}
        try:
            stream = await self._client.chat.completions.create(**kwargs)
            chunks: list[Any] = []
            async for chunk in stream:
                chunks.append(chunk)
                if on_content_delta:
                    text = self._chunk_delta_text(chunk)
                    if text:
                        await on_content_delta(text)
            return self._parse_chunks(chunks)
        except Exception as e:
            return self._handle_error(e)

    def get_default_model(self) -> str:
        """Return the model name used when a call does not specify one."""
        return self.default_model
|