mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-08 06:00:41 +00:00
update azure
This commit is contained in:
parent
c289fc9ba9
commit
6c2474b220
@ -4,6 +4,7 @@ from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from functools import cached_property
|
||||
from typing import (
|
||||
Any,
|
||||
Awaitable,
|
||||
@ -629,7 +630,7 @@ class AzureChatOpenAI(BaseChatOpenAI):
|
||||
"Or you can equivalently specify:\n\n"
|
||||
'base_url="https://xxx.openai.azure.com/openai/deployments/my-deployment"'
|
||||
)
|
||||
client_params: dict = {
|
||||
self._client_params: dict = {
|
||||
"api_version": self.openai_api_version,
|
||||
"azure_endpoint": self.azure_endpoint,
|
||||
"azure_deployment": self.deployment_name,
|
||||
@ -650,27 +651,45 @@ class AzureChatOpenAI(BaseChatOpenAI):
|
||||
"default_query": self.default_query,
|
||||
}
|
||||
if self.max_retries is not None:
|
||||
client_params["max_retries"] = self.max_retries
|
||||
self._client_params["max_retries"] = self.max_retries
|
||||
|
||||
if not self.client:
|
||||
sync_specific = {"http_client": self.http_client}
|
||||
self.root_client = openai.AzureOpenAI(**client_params, **sync_specific) # type: ignore[arg-type]
|
||||
self.client = self.root_client.chat.completions
|
||||
if not self.async_client:
|
||||
async_specific = {"http_client": self.http_async_client}
|
||||
|
||||
if self.azure_ad_async_token_provider:
|
||||
client_params["azure_ad_token_provider"] = (
|
||||
self.azure_ad_async_token_provider
|
||||
)
|
||||
|
||||
self.root_async_client = openai.AsyncAzureOpenAI(
|
||||
**client_params,
|
||||
**async_specific, # type: ignore[arg-type]
|
||||
if self.azure_ad_async_token_provider:
|
||||
self._client_params["azure_ad_token_provider"] = (
|
||||
self.azure_ad_async_token_provider
|
||||
)
|
||||
self.async_client = self.root_async_client.chat.completions
|
||||
|
||||
return self
|
||||
|
||||
@cached_property
|
||||
def _root_client(self) -> openai.AzureOpenAI:
|
||||
if self.root_client is not None:
|
||||
return self.root_client
|
||||
sync_specific = {"http_client": self._http_client}
|
||||
return openai.AzureOpenAI(**self._client_params, **sync_specific) # type: ignore[call-overload]
|
||||
|
||||
@cached_property
|
||||
def _root_async_client(self) -> openai.AsyncAzureOpenAI:
|
||||
if self.root_async_client is not None:
|
||||
return self.root_async_client
|
||||
async_specific = {"http_client": self._http_async_client}
|
||||
|
||||
return openai.AsyncAzureOpenAI(
|
||||
**self._client_params,
|
||||
**async_specific, # type: ignore[call-overload]
|
||||
)
|
||||
|
||||
@cached_property
|
||||
def _client(self) -> Any:
|
||||
if self.client is not None:
|
||||
return self.client
|
||||
return self._root_client.chat.completions
|
||||
|
||||
@cached_property
|
||||
def _async_client(self) -> Any:
|
||||
if self.async_client is not None:
|
||||
return self.async_client
|
||||
return self._root_async_client.chat.completions
|
||||
|
||||
@property
|
||||
def _identifying_params(self) -> Dict[str, Any]:
|
||||
"""Get the identifying parameters."""
|
||||
|
Loading…
Reference in New Issue
Block a user