From 6c2474b22057156dbb5b299373a3f90089bac0ff Mon Sep 17 00:00:00 2001
From: Chester Curme
Date: Fri, 21 Feb 2025 15:25:37 -0500
Subject: [PATCH] update azure

---
 .../langchain_openai/chat_models/azure.py | 55 +++++++++++++------
 1 file changed, 37 insertions(+), 18 deletions(-)

diff --git a/libs/partners/openai/langchain_openai/chat_models/azure.py b/libs/partners/openai/langchain_openai/chat_models/azure.py
index 04c5e4e5c1d..1c5653ac3a2 100644
--- a/libs/partners/openai/langchain_openai/chat_models/azure.py
+++ b/libs/partners/openai/langchain_openai/chat_models/azure.py
@@ -4,6 +4,7 @@ from __future__ import annotations
 
 import logging
 import os
+from functools import cached_property
 from typing import (
     Any,
     Awaitable,
@@ -629,7 +630,7 @@ class AzureChatOpenAI(BaseChatOpenAI):
                 "Or you can equivalently specify:\n\n"
                 'base_url="https://xxx.openai.azure.com/openai/deployments/my-deployment"'
             )
-        client_params: dict = {
+        self._client_params: dict = {
             "api_version": self.openai_api_version,
             "azure_endpoint": self.azure_endpoint,
             "azure_deployment": self.deployment_name,
@@ -650,27 +651,45 @@ class AzureChatOpenAI(BaseChatOpenAI):
             "default_query": self.default_query,
         }
         if self.max_retries is not None:
-            client_params["max_retries"] = self.max_retries
+            self._client_params["max_retries"] = self.max_retries
 
-        if not self.client:
-            sync_specific = {"http_client": self.http_client}
-            self.root_client = openai.AzureOpenAI(**client_params, **sync_specific)  # type: ignore[arg-type]
-            self.client = self.root_client.chat.completions
-        if not self.async_client:
-            async_specific = {"http_client": self.http_async_client}
-
-            if self.azure_ad_async_token_provider:
-                client_params["azure_ad_token_provider"] = (
-                    self.azure_ad_async_token_provider
-                )
-
-            self.root_async_client = openai.AsyncAzureOpenAI(
-                **client_params,
-                **async_specific,  # type: ignore[arg-type]
+        if self.azure_ad_async_token_provider:
+            self._client_params["azure_ad_token_provider"] = (
+                self.azure_ad_async_token_provider
             )
-            self.async_client = self.root_async_client.chat.completions
+
         return self
 
+    @cached_property
+    def _root_client(self) -> openai.AzureOpenAI:
+        if self.root_client is not None:
+            return self.root_client
+        sync_specific = {"http_client": self._http_client}
+        return openai.AzureOpenAI(**self._client_params, **sync_specific)  # type: ignore[call-overload]
+
+    @cached_property
+    def _root_async_client(self) -> openai.AsyncAzureOpenAI:
+        if self.root_async_client is not None:
+            return self.root_async_client
+        async_specific = {"http_client": self._http_async_client}
+
+        return openai.AsyncAzureOpenAI(
+            **self._client_params,
+            **async_specific,  # type: ignore[call-overload]
+        )
+
+    @cached_property
+    def _client(self) -> Any:
+        if self.client is not None:
+            return self.client
+        return self._root_client.chat.completions
+
+    @cached_property
+    def _async_client(self) -> Any:
+        if self.async_client is not None:
+            return self.async_client
+        return self._root_async_client.chat.completions
+
     @property
     def _identifying_params(self) -> Dict[str, Any]:
         """Get the identifying parameters."""
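
Illustrative sketch of the pattern this diff introduces (hypothetical stand-ins, not the library's API): the patch replaces eager client construction in validate_environment with functools.cached_property accessors that prefer any client the user supplied explicitly and otherwise build one lazily on first access. A minimal, self-contained version of that shape, with ExpensiveClient and LazyModel standing in for openai.AzureOpenAI and AzureChatOpenAI:

    # Minimal sketch of the lazy, cached client pattern in the patch above.
    # ExpensiveClient and LazyModel are hypothetical stand-ins; only the
    # caching/fallback shape is shown, not the real constructor parameters.
    from functools import cached_property
    from typing import Any, Optional

    class ExpensiveClient:
        def __init__(self, **params: Any) -> None:
            # Assume construction is costly (network setup, auth, etc.).
            self.params = params

    class LazyModel:
        def __init__(self, root_client: Optional[ExpensiveClient] = None) -> None:
            # Optional user-provided client, analogous to root_client above.
            self.root_client = root_client
            # Parameters gathered up front, analogous to _client_params.
            self._client_params: dict = {"endpoint": "https://example.invalid"}

        @cached_property
        def _root_client(self) -> ExpensiveClient:
            # Prefer an explicitly supplied client; otherwise build one on
            # first access. cached_property memoizes the result per instance,
            # so the expensive constructor runs at most once.
            if self.root_client is not None:
                return self.root_client
            return ExpensiveClient(**self._client_params)

    model = LazyModel()
    assert model._root_client is model._root_client  # built once, then cached

In the patch the same check-then-build shape is repeated for the sync and async root clients and for their chat.completions handles.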