"""Azure OpenAI provider implementation with API version 2024-10-21.""" from __future__ import annotations import json import uuid from collections.abc import Awaitable, Callable from typing import Any from urllib.parse import urljoin import httpx import json_repair from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest _AZURE_MSG_KEYS = frozenset({"role", "content", "tool_calls", "tool_call_id", "name"}) class AzureOpenAIProvider(LLMProvider): """ Azure OpenAI provider with API version 2024-10-21 compliance. Features: - Hardcoded API version 2024-10-21 - Uses model field as Azure deployment name in URL path - Uses api-key header instead of Authorization Bearer - Uses max_completion_tokens instead of max_tokens - Direct HTTP calls, bypasses LiteLLM """ def __init__( self, api_key: str = "", api_base: str = "", default_model: str = "gpt-5.2-chat", ): super().__init__(api_key, api_base) self.default_model = default_model self.api_version = "2024-10-21" # Validate required parameters if not api_key: raise ValueError("Azure OpenAI api_key is required") if not api_base: raise ValueError("Azure OpenAI api_base is required") # Ensure api_base ends with / if not api_base.endswith('/'): api_base += '/' self.api_base = api_base def _build_chat_url(self, deployment_name: str) -> str: """Build the Azure OpenAI chat completions URL.""" # Azure OpenAI URL format: # https://{resource}.openai.azure.com/openai/deployments/{deployment}/chat/completions?api-version={version} base_url = self.api_base if not base_url.endswith('/'): base_url += '/' url = urljoin( base_url, f"openai/deployments/{deployment_name}/chat/completions" ) return f"{url}?api-version={self.api_version}" def _build_headers(self) -> dict[str, str]: """Build headers for Azure OpenAI API with api-key header.""" return { "Content-Type": "application/json", "api-key": self.api_key, # Azure OpenAI uses api-key header, not Authorization "x-session-affinity": uuid.uuid4().hex, # For cache locality } @staticmethod def _supports_temperature( deployment_name: str, reasoning_effort: str | None = None, ) -> bool: """Return True when temperature is likely supported for this deployment.""" if reasoning_effort: return False name = deployment_name.lower() return not any(token in name for token in ("gpt-5", "o1", "o3", "o4")) def _prepare_request_payload( self, deployment_name: str, messages: list[dict[str, Any]], tools: list[dict[str, Any]] | None = None, max_tokens: int = 4096, temperature: float = 0.7, reasoning_effort: str | None = None, tool_choice: str | dict[str, Any] | None = None, ) -> dict[str, Any]: """Prepare the request payload with Azure OpenAI 2024-10-21 compliance.""" payload: dict[str, Any] = { "messages": self._sanitize_request_messages( self._sanitize_empty_content(messages), _AZURE_MSG_KEYS, ), "max_completion_tokens": max(1, max_tokens), # Azure API 2024-10-21 uses max_completion_tokens } if self._supports_temperature(deployment_name, reasoning_effort): payload["temperature"] = temperature if reasoning_effort: payload["reasoning_effort"] = reasoning_effort if tools: payload["tools"] = tools payload["tool_choice"] = tool_choice or "auto" return payload async def chat( self, messages: list[dict[str, Any]], tools: list[dict[str, Any]] | None = None, model: str | None = None, max_tokens: int = 4096, temperature: float = 0.7, reasoning_effort: str | None = None, tool_choice: str | dict[str, Any] | None = None, ) -> LLMResponse: """ Send a chat completion request to Azure OpenAI. Args: messages: List of message dicts with 'role' and 'content'. tools: Optional list of tool definitions in OpenAI format. model: Model identifier (used as deployment name). max_tokens: Maximum tokens in response (mapped to max_completion_tokens). temperature: Sampling temperature. reasoning_effort: Optional reasoning effort parameter. Returns: LLMResponse with content and/or tool calls. """ deployment_name = model or self.default_model url = self._build_chat_url(deployment_name) headers = self._build_headers() payload = self._prepare_request_payload( deployment_name, messages, tools, max_tokens, temperature, reasoning_effort, tool_choice=tool_choice, ) try: async with httpx.AsyncClient(timeout=60.0, verify=True) as client: response = await client.post(url, headers=headers, json=payload) if response.status_code != 200: return LLMResponse( content=f"Azure OpenAI API Error {response.status_code}: {response.text}", finish_reason="error", ) response_data = response.json() return self._parse_response(response_data) except Exception as e: return LLMResponse( content=f"Error calling Azure OpenAI: {repr(e)}", finish_reason="error", ) def _parse_response(self, response: dict[str, Any]) -> LLMResponse: """Parse Azure OpenAI response into our standard format.""" try: choice = response["choices"][0] message = choice["message"] tool_calls = [] if message.get("tool_calls"): for tc in message["tool_calls"]: # Parse arguments from JSON string if needed args = tc["function"]["arguments"] if isinstance(args, str): args = json_repair.loads(args) tool_calls.append( ToolCallRequest( id=tc["id"], name=tc["function"]["name"], arguments=args, ) ) usage = {} if response.get("usage"): usage_data = response["usage"] usage = { "prompt_tokens": usage_data.get("prompt_tokens", 0), "completion_tokens": usage_data.get("completion_tokens", 0), "total_tokens": usage_data.get("total_tokens", 0), } reasoning_content = message.get("reasoning_content") or None return LLMResponse( content=message.get("content"), tool_calls=tool_calls, finish_reason=choice.get("finish_reason", "stop"), usage=usage, reasoning_content=reasoning_content, ) except (KeyError, IndexError) as e: return LLMResponse( content=f"Error parsing Azure OpenAI response: {str(e)}", finish_reason="error", ) async def chat_stream( self, messages: list[dict[str, Any]], tools: list[dict[str, Any]] | None = None, model: str | None = None, max_tokens: int = 4096, temperature: float = 0.7, reasoning_effort: str | None = None, tool_choice: str | dict[str, Any] | None = None, on_content_delta: Callable[[str], Awaitable[None]] | None = None, ) -> LLMResponse: """Stream a chat completion via Azure OpenAI SSE.""" deployment_name = model or self.default_model url = self._build_chat_url(deployment_name) headers = self._build_headers() payload = self._prepare_request_payload( deployment_name, messages, tools, max_tokens, temperature, reasoning_effort, tool_choice=tool_choice, ) payload["stream"] = True try: async with httpx.AsyncClient(timeout=60.0, verify=True) as client: async with client.stream("POST", url, headers=headers, json=payload) as response: if response.status_code != 200: text = await response.aread() return LLMResponse( content=f"Azure OpenAI API Error {response.status_code}: {text.decode('utf-8', 'ignore')}", finish_reason="error", ) return await self._consume_stream(response, on_content_delta) except Exception as e: return LLMResponse(content=f"Error calling Azure OpenAI: {repr(e)}", finish_reason="error") async def _consume_stream( self, response: httpx.Response, on_content_delta: Callable[[str], Awaitable[None]] | None, ) -> LLMResponse: """Parse Azure OpenAI SSE stream into an LLMResponse.""" content_parts: list[str] = [] tool_call_buffers: dict[int, dict[str, str]] = {} finish_reason = "stop" async for line in response.aiter_lines(): if not line.startswith("data: "): continue data = line[6:].strip() if data == "[DONE]": break try: chunk = json.loads(data) except Exception: continue choices = chunk.get("choices") or [] if not choices: continue choice = choices[0] if choice.get("finish_reason"): finish_reason = choice["finish_reason"] delta = choice.get("delta") or {} text = delta.get("content") if text: content_parts.append(text) if on_content_delta: await on_content_delta(text) for tc in delta.get("tool_calls") or []: idx = tc.get("index", 0) buf = tool_call_buffers.setdefault(idx, {"id": "", "name": "", "arguments": ""}) if tc.get("id"): buf["id"] = tc["id"] fn = tc.get("function") or {} if fn.get("name"): buf["name"] = fn["name"] if fn.get("arguments"): buf["arguments"] += fn["arguments"] tool_calls = [ ToolCallRequest( id=buf["id"], name=buf["name"], arguments=json_repair.loads(buf["arguments"]) if buf["arguments"] else {}, ) for buf in tool_call_buffers.values() ] return LLMResponse( content="".join(content_parts) or None, tool_calls=tool_calls, finish_reason=finish_reason, ) def get_default_model(self) -> str: """Get the default model (also used as default deployment name).""" return self.default_model