update azure

This commit is contained in:
Chester Curme 2025-02-21 15:25:37 -05:00
parent c289fc9ba9
commit 6c2474b220

View File

@ -4,6 +4,7 @@ from __future__ import annotations
import logging
import os
from functools import cached_property
from typing import (
Any,
Awaitable,
@ -629,7 +630,7 @@ class AzureChatOpenAI(BaseChatOpenAI):
"Or you can equivalently specify:\n\n"
'base_url="https://xxx.openai.azure.com/openai/deployments/my-deployment"'
)
client_params: dict = {
self._client_params: dict = {
"api_version": self.openai_api_version,
"azure_endpoint": self.azure_endpoint,
"azure_deployment": self.deployment_name,
@ -650,27 +651,45 @@ class AzureChatOpenAI(BaseChatOpenAI):
"default_query": self.default_query,
}
if self.max_retries is not None:
client_params["max_retries"] = self.max_retries
if not self.client:
sync_specific = {"http_client": self.http_client}
self.root_client = openai.AzureOpenAI(**client_params, **sync_specific) # type: ignore[arg-type]
self.client = self.root_client.chat.completions
if not self.async_client:
async_specific = {"http_client": self.http_async_client}
self._client_params["max_retries"] = self.max_retries
if self.azure_ad_async_token_provider:
client_params["azure_ad_token_provider"] = (
self._client_params["azure_ad_token_provider"] = (
self.azure_ad_async_token_provider
)
self.root_async_client = openai.AsyncAzureOpenAI(
**client_params,
**async_specific, # type: ignore[arg-type]
)
self.async_client = self.root_async_client.chat.completions
return self
@cached_property
def _root_client(self) -> openai.AzureOpenAI:
if self.root_client is not None:
return self.root_client
sync_specific = {"http_client": self._http_client}
return openai.AzureOpenAI(**self._client_params, **sync_specific) # type: ignore[call-overload]
@cached_property
def _root_async_client(self) -> openai.AsyncAzureOpenAI:
if self.root_async_client is not None:
return self.root_async_client
async_specific = {"http_client": self._http_async_client}
return openai.AsyncAzureOpenAI(
**self._client_params,
**async_specific, # type: ignore[call-overload]
)
@cached_property
def _client(self) -> Any:
if self.client is not None:
return self.client
return self._root_client.chat.completions
@cached_property
def _async_client(self) -> Any:
if self.async_client is not None:
return self.async_client
return self._root_async_client.chat.completions
@property
def _identifying_params(self) -> Dict[str, Any]:
"""Get the identifying parameters."""