mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-09 22:45:49 +00:00
update azure
This commit is contained in:
parent
c289fc9ba9
commit
6c2474b220
@ -4,6 +4,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
from functools import cached_property
|
||||||
from typing import (
|
from typing import (
|
||||||
Any,
|
Any,
|
||||||
Awaitable,
|
Awaitable,
|
||||||
@ -629,7 +630,7 @@ class AzureChatOpenAI(BaseChatOpenAI):
|
|||||||
"Or you can equivalently specify:\n\n"
|
"Or you can equivalently specify:\n\n"
|
||||||
'base_url="https://xxx.openai.azure.com/openai/deployments/my-deployment"'
|
'base_url="https://xxx.openai.azure.com/openai/deployments/my-deployment"'
|
||||||
)
|
)
|
||||||
client_params: dict = {
|
self._client_params: dict = {
|
||||||
"api_version": self.openai_api_version,
|
"api_version": self.openai_api_version,
|
||||||
"azure_endpoint": self.azure_endpoint,
|
"azure_endpoint": self.azure_endpoint,
|
||||||
"azure_deployment": self.deployment_name,
|
"azure_deployment": self.deployment_name,
|
||||||
@ -650,27 +651,45 @@ class AzureChatOpenAI(BaseChatOpenAI):
|
|||||||
"default_query": self.default_query,
|
"default_query": self.default_query,
|
||||||
}
|
}
|
||||||
if self.max_retries is not None:
|
if self.max_retries is not None:
|
||||||
client_params["max_retries"] = self.max_retries
|
self._client_params["max_retries"] = self.max_retries
|
||||||
|
|
||||||
if not self.client:
|
if self.azure_ad_async_token_provider:
|
||||||
sync_specific = {"http_client": self.http_client}
|
self._client_params["azure_ad_token_provider"] = (
|
||||||
self.root_client = openai.AzureOpenAI(**client_params, **sync_specific) # type: ignore[arg-type]
|
self.azure_ad_async_token_provider
|
||||||
self.client = self.root_client.chat.completions
|
|
||||||
if not self.async_client:
|
|
||||||
async_specific = {"http_client": self.http_async_client}
|
|
||||||
|
|
||||||
if self.azure_ad_async_token_provider:
|
|
||||||
client_params["azure_ad_token_provider"] = (
|
|
||||||
self.azure_ad_async_token_provider
|
|
||||||
)
|
|
||||||
|
|
||||||
self.root_async_client = openai.AsyncAzureOpenAI(
|
|
||||||
**client_params,
|
|
||||||
**async_specific, # type: ignore[arg-type]
|
|
||||||
)
|
)
|
||||||
self.async_client = self.root_async_client.chat.completions
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def _root_client(self) -> openai.AzureOpenAI:
|
||||||
|
if self.root_client is not None:
|
||||||
|
return self.root_client
|
||||||
|
sync_specific = {"http_client": self._http_client}
|
||||||
|
return openai.AzureOpenAI(**self._client_params, **sync_specific) # type: ignore[call-overload]
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def _root_async_client(self) -> openai.AsyncAzureOpenAI:
|
||||||
|
if self.root_async_client is not None:
|
||||||
|
return self.root_async_client
|
||||||
|
async_specific = {"http_client": self._http_async_client}
|
||||||
|
|
||||||
|
return openai.AsyncAzureOpenAI(
|
||||||
|
**self._client_params,
|
||||||
|
**async_specific, # type: ignore[call-overload]
|
||||||
|
)
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def _client(self) -> Any:
|
||||||
|
if self.client is not None:
|
||||||
|
return self.client
|
||||||
|
return self._root_client.chat.completions
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def _async_client(self) -> Any:
|
||||||
|
if self.async_client is not None:
|
||||||
|
return self.async_client
|
||||||
|
return self._root_async_client.chat.completions
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _identifying_params(self) -> Dict[str, Any]:
|
def _identifying_params(self) -> Dict[str, Any]:
|
||||||
"""Get the identifying parameters."""
|
"""Get the identifying parameters."""
|
||||||
|
Loading…
Reference in New Issue
Block a user