Update 2026-05-13 16:43:53
This commit is contained in:
@@ -0,0 +1,309 @@
|
||||
"""Azure OpenAI provider implementation with API version 2024-10-21."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Any
|
||||
from urllib.parse import urljoin
|
||||
|
||||
import httpx
|
||||
import json_repair
|
||||
|
||||
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
|
||||
|
||||
_AZURE_MSG_KEYS = frozenset({"role", "content", "tool_calls", "tool_call_id", "name"})
|
||||
|
||||
|
||||
class AzureOpenAIProvider(LLMProvider):
|
||||
"""
|
||||
Azure OpenAI provider with API version 2024-10-21 compliance.
|
||||
|
||||
Features:
|
||||
- Hardcoded API version 2024-10-21
|
||||
- Uses model field as Azure deployment name in URL path
|
||||
- Uses api-key header instead of Authorization Bearer
|
||||
- Uses max_completion_tokens instead of max_tokens
|
||||
- Direct HTTP calls, bypasses LiteLLM
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str = "",
|
||||
api_base: str = "",
|
||||
default_model: str = "gpt-5.2-chat",
|
||||
):
|
||||
super().__init__(api_key, api_base)
|
||||
self.default_model = default_model
|
||||
self.api_version = "2024-10-21"
|
||||
|
||||
# Validate required parameters
|
||||
if not api_key:
|
||||
raise ValueError("Azure OpenAI api_key is required")
|
||||
if not api_base:
|
||||
raise ValueError("Azure OpenAI api_base is required")
|
||||
|
||||
# Ensure api_base ends with /
|
||||
if not api_base.endswith('/'):
|
||||
api_base += '/'
|
||||
self.api_base = api_base
|
||||
|
||||
def _build_chat_url(self, deployment_name: str) -> str:
|
||||
"""Build the Azure OpenAI chat completions URL."""
|
||||
# Azure OpenAI URL format:
|
||||
# https://{resource}.openai.azure.com/openai/deployments/{deployment}/chat/completions?api-version={version}
|
||||
base_url = self.api_base
|
||||
if not base_url.endswith('/'):
|
||||
base_url += '/'
|
||||
|
||||
url = urljoin(
|
||||
base_url,
|
||||
f"openai/deployments/{deployment_name}/chat/completions"
|
||||
)
|
||||
return f"{url}?api-version={self.api_version}"
|
||||
|
||||
def _build_headers(self) -> dict[str, str]:
|
||||
"""Build headers for Azure OpenAI API with api-key header."""
|
||||
return {
|
||||
"Content-Type": "application/json",
|
||||
"api-key": self.api_key, # Azure OpenAI uses api-key header, not Authorization
|
||||
"x-session-affinity": uuid.uuid4().hex, # For cache locality
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _supports_temperature(
|
||||
deployment_name: str,
|
||||
reasoning_effort: str | None = None,
|
||||
) -> bool:
|
||||
"""Return True when temperature is likely supported for this deployment."""
|
||||
if reasoning_effort:
|
||||
return False
|
||||
name = deployment_name.lower()
|
||||
return not any(token in name for token in ("gpt-5", "o1", "o3", "o4"))
|
||||
|
||||
def _prepare_request_payload(
|
||||
self,
|
||||
deployment_name: str,
|
||||
messages: list[dict[str, Any]],
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
max_tokens: int = 4096,
|
||||
temperature: float = 0.7,
|
||||
reasoning_effort: str | None = None,
|
||||
tool_choice: str | dict[str, Any] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Prepare the request payload with Azure OpenAI 2024-10-21 compliance."""
|
||||
payload: dict[str, Any] = {
|
||||
"messages": self._sanitize_request_messages(
|
||||
self._sanitize_empty_content(messages),
|
||||
_AZURE_MSG_KEYS,
|
||||
),
|
||||
"max_completion_tokens": max(1, max_tokens), # Azure API 2024-10-21 uses max_completion_tokens
|
||||
}
|
||||
|
||||
if self._supports_temperature(deployment_name, reasoning_effort):
|
||||
payload["temperature"] = temperature
|
||||
|
||||
if reasoning_effort:
|
||||
payload["reasoning_effort"] = reasoning_effort
|
||||
|
||||
if tools:
|
||||
payload["tools"] = tools
|
||||
payload["tool_choice"] = tool_choice or "auto"
|
||||
|
||||
return payload
|
||||
|
||||
async def chat(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
model: str | None = None,
|
||||
max_tokens: int = 4096,
|
||||
temperature: float = 0.7,
|
||||
reasoning_effort: str | None = None,
|
||||
tool_choice: str | dict[str, Any] | None = None,
|
||||
) -> LLMResponse:
|
||||
"""
|
||||
Send a chat completion request to Azure OpenAI.
|
||||
|
||||
Args:
|
||||
messages: List of message dicts with 'role' and 'content'.
|
||||
tools: Optional list of tool definitions in OpenAI format.
|
||||
model: Model identifier (used as deployment name).
|
||||
max_tokens: Maximum tokens in response (mapped to max_completion_tokens).
|
||||
temperature: Sampling temperature.
|
||||
reasoning_effort: Optional reasoning effort parameter.
|
||||
|
||||
Returns:
|
||||
LLMResponse with content and/or tool calls.
|
||||
"""
|
||||
deployment_name = model or self.default_model
|
||||
url = self._build_chat_url(deployment_name)
|
||||
headers = self._build_headers()
|
||||
payload = self._prepare_request_payload(
|
||||
deployment_name, messages, tools, max_tokens, temperature, reasoning_effort,
|
||||
tool_choice=tool_choice,
|
||||
)
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=60.0, verify=True) as client:
|
||||
response = await client.post(url, headers=headers, json=payload)
|
||||
if response.status_code != 200:
|
||||
return LLMResponse(
|
||||
content=f"Azure OpenAI API Error {response.status_code}: {response.text}",
|
||||
finish_reason="error",
|
||||
)
|
||||
|
||||
response_data = response.json()
|
||||
return self._parse_response(response_data)
|
||||
|
||||
except Exception as e:
|
||||
return LLMResponse(
|
||||
content=f"Error calling Azure OpenAI: {repr(e)}",
|
||||
finish_reason="error",
|
||||
)
|
||||
|
||||
def _parse_response(self, response: dict[str, Any]) -> LLMResponse:
|
||||
"""Parse Azure OpenAI response into our standard format."""
|
||||
try:
|
||||
choice = response["choices"][0]
|
||||
message = choice["message"]
|
||||
|
||||
tool_calls = []
|
||||
if message.get("tool_calls"):
|
||||
for tc in message["tool_calls"]:
|
||||
# Parse arguments from JSON string if needed
|
||||
args = tc["function"]["arguments"]
|
||||
if isinstance(args, str):
|
||||
args = json_repair.loads(args)
|
||||
|
||||
tool_calls.append(
|
||||
ToolCallRequest(
|
||||
id=tc["id"],
|
||||
name=tc["function"]["name"],
|
||||
arguments=args,
|
||||
)
|
||||
)
|
||||
|
||||
usage = {}
|
||||
if response.get("usage"):
|
||||
usage_data = response["usage"]
|
||||
usage = {
|
||||
"prompt_tokens": usage_data.get("prompt_tokens", 0),
|
||||
"completion_tokens": usage_data.get("completion_tokens", 0),
|
||||
"total_tokens": usage_data.get("total_tokens", 0),
|
||||
}
|
||||
|
||||
reasoning_content = message.get("reasoning_content") or None
|
||||
|
||||
return LLMResponse(
|
||||
content=message.get("content"),
|
||||
tool_calls=tool_calls,
|
||||
finish_reason=choice.get("finish_reason", "stop"),
|
||||
usage=usage,
|
||||
reasoning_content=reasoning_content,
|
||||
)
|
||||
|
||||
except (KeyError, IndexError) as e:
|
||||
return LLMResponse(
|
||||
content=f"Error parsing Azure OpenAI response: {str(e)}",
|
||||
finish_reason="error",
|
||||
)
|
||||
|
||||
async def chat_stream(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
model: str | None = None,
|
||||
max_tokens: int = 4096,
|
||||
temperature: float = 0.7,
|
||||
reasoning_effort: str | None = None,
|
||||
tool_choice: str | dict[str, Any] | None = None,
|
||||
on_content_delta: Callable[[str], Awaitable[None]] | None = None,
|
||||
) -> LLMResponse:
|
||||
"""Stream a chat completion via Azure OpenAI SSE."""
|
||||
deployment_name = model or self.default_model
|
||||
url = self._build_chat_url(deployment_name)
|
||||
headers = self._build_headers()
|
||||
payload = self._prepare_request_payload(
|
||||
deployment_name, messages, tools, max_tokens, temperature,
|
||||
reasoning_effort, tool_choice=tool_choice,
|
||||
)
|
||||
payload["stream"] = True
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=60.0, verify=True) as client:
|
||||
async with client.stream("POST", url, headers=headers, json=payload) as response:
|
||||
if response.status_code != 200:
|
||||
text = await response.aread()
|
||||
return LLMResponse(
|
||||
content=f"Azure OpenAI API Error {response.status_code}: {text.decode('utf-8', 'ignore')}",
|
||||
finish_reason="error",
|
||||
)
|
||||
return await self._consume_stream(response, on_content_delta)
|
||||
except Exception as e:
|
||||
return LLMResponse(content=f"Error calling Azure OpenAI: {repr(e)}", finish_reason="error")
|
||||
|
||||
async def _consume_stream(
|
||||
self,
|
||||
response: httpx.Response,
|
||||
on_content_delta: Callable[[str], Awaitable[None]] | None,
|
||||
) -> LLMResponse:
|
||||
"""Parse Azure OpenAI SSE stream into an LLMResponse."""
|
||||
content_parts: list[str] = []
|
||||
tool_call_buffers: dict[int, dict[str, str]] = {}
|
||||
finish_reason = "stop"
|
||||
|
||||
async for line in response.aiter_lines():
|
||||
if not line.startswith("data: "):
|
||||
continue
|
||||
data = line[6:].strip()
|
||||
if data == "[DONE]":
|
||||
break
|
||||
try:
|
||||
chunk = json.loads(data)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
choices = chunk.get("choices") or []
|
||||
if not choices:
|
||||
continue
|
||||
choice = choices[0]
|
||||
if choice.get("finish_reason"):
|
||||
finish_reason = choice["finish_reason"]
|
||||
delta = choice.get("delta") or {}
|
||||
|
||||
text = delta.get("content")
|
||||
if text:
|
||||
content_parts.append(text)
|
||||
if on_content_delta:
|
||||
await on_content_delta(text)
|
||||
|
||||
for tc in delta.get("tool_calls") or []:
|
||||
idx = tc.get("index", 0)
|
||||
buf = tool_call_buffers.setdefault(idx, {"id": "", "name": "", "arguments": ""})
|
||||
if tc.get("id"):
|
||||
buf["id"] = tc["id"]
|
||||
fn = tc.get("function") or {}
|
||||
if fn.get("name"):
|
||||
buf["name"] = fn["name"]
|
||||
if fn.get("arguments"):
|
||||
buf["arguments"] += fn["arguments"]
|
||||
|
||||
tool_calls = [
|
||||
ToolCallRequest(
|
||||
id=buf["id"], name=buf["name"],
|
||||
arguments=json_repair.loads(buf["arguments"]) if buf["arguments"] else {},
|
||||
)
|
||||
for buf in tool_call_buffers.values()
|
||||
]
|
||||
|
||||
return LLMResponse(
|
||||
content="".join(content_parts) or None,
|
||||
tool_calls=tool_calls,
|
||||
finish_reason=finish_reason,
|
||||
)
|
||||
|
||||
def get_default_model(self) -> str:
|
||||
"""Get the default model (also used as default deployment name)."""
|
||||
return self.default_model
|
||||
Reference in New Issue
Block a user