mirror of
https://github.com/hwchase17/langchain.git
synced 2025-04-28 11:55:21 +00:00
Using `pyupgrade` to get all `partners` code up to 3.9 standards (mostly, fixing old `typing` imports).
218 lines
8.2 KiB
Python
218 lines
8.2 KiB
Python
from __future__ import annotations
|
|
|
|
import logging
|
|
from collections.abc import Awaitable, Mapping
|
|
from typing import Any, Callable, Optional, Union, cast
|
|
|
|
import openai
|
|
from langchain_core.language_models import LangSmithParams
|
|
from langchain_core.utils import from_env, secret_from_env
|
|
from pydantic import Field, SecretStr, model_validator
|
|
from typing_extensions import Self
|
|
|
|
from langchain_openai.llms.base import BaseOpenAI
|
|
|
|
# Logger for this module; its name follows the module's import path.
logger: logging.Logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
class AzureOpenAI(BaseOpenAI):
    """Azure-specific OpenAI large language models.

    To use, you should have the ``openai`` python package installed and Azure
    credentials available: an API key via ``api_key`` or the environment variable
    ``AZURE_OPENAI_API_KEY`` (``OPENAI_API_KEY`` is read as a fallback), an Azure
    Active Directory token via ``azure_ad_token`` / ``AZURE_OPENAI_AD_TOKEN``, or
    an ``azure_ad_token_provider``.

    The endpoint and API version are read from ``AZURE_OPENAI_ENDPOINT`` and
    ``OPENAI_API_VERSION`` respectively when not passed explicitly.

    Any parameters that are valid to be passed to the openai.create call can be passed
    in, even if not explicitly saved on this class.

    Example:
        .. code-block:: python

            from langchain_openai import AzureOpenAI

            llm = AzureOpenAI(
                azure_deployment="my-deployment",
                api_version="2024-02-01",
                azure_endpoint="https://example-resource.azure.openai.com/",
            )
    """

    azure_endpoint: Optional[str] = Field(
        default_factory=from_env("AZURE_OPENAI_ENDPOINT", default=None)
    )
    """Your Azure endpoint, including the resource.

    Automatically inferred from env var `AZURE_OPENAI_ENDPOINT` if not provided.

    Example: `https://example-resource.azure.openai.com/`
    """
    deployment_name: Union[str, None] = Field(default=None, alias="azure_deployment")
    """A model deployment.

    If given sets the base client URL to include `/deployments/{azure_deployment}`.
    Note: this means you won't be able to use non-deployment endpoints.
    """
    openai_api_version: Optional[str] = Field(
        alias="api_version",
        default_factory=from_env("OPENAI_API_VERSION", default=None),
    )
    """Automatically inferred from env var `OPENAI_API_VERSION` if not provided."""
    # Check OPENAI_KEY for backwards compatibility.
    # TODO: Remove OPENAI_API_KEY support to avoid possible conflict when using
    # other forms of azure credentials.
    openai_api_key: Optional[SecretStr] = Field(
        alias="api_key",
        default_factory=secret_from_env(
            ["AZURE_OPENAI_API_KEY", "OPENAI_API_KEY"], default=None
        ),
    )
    """API key for the Azure OpenAI resource.

    Automatically inferred from env var `AZURE_OPENAI_API_KEY`, falling back to
    `OPENAI_API_KEY`, if not provided.
    """
    azure_ad_token: Optional[SecretStr] = Field(
        default_factory=secret_from_env("AZURE_OPENAI_AD_TOKEN", default=None)
    )
    """Your Azure Active Directory token.

    Automatically inferred from env var `AZURE_OPENAI_AD_TOKEN` if not provided.

    For more:
    https://www.microsoft.com/en-us/security/business/identity-access/microsoft-entra-id.
    """
    azure_ad_token_provider: Union[Callable[[], str], None] = None
    """A function that returns an Azure Active Directory token.

    Will be invoked on every sync request. For async requests,
    will be invoked if `azure_ad_async_token_provider` is not provided.
    """
    azure_ad_async_token_provider: Union[Callable[[], Awaitable[str]], None] = None
    """A function that returns an Azure Active Directory token.

    Will be invoked on every async request.
    """
    openai_api_type: Optional[str] = Field(
        default_factory=from_env("OPENAI_API_TYPE", default="azure")
    )
    """Legacy, for openai<1.0.0 support."""
    validate_base_url: bool = True
    """For backwards compatibility. If legacy val openai_api_base is passed in, try to
    infer if it is a base_url or azure_endpoint and update accordingly.
    """

    @classmethod
    def get_lc_namespace(cls) -> list[str]:
        """Get the namespace of the langchain object."""
        return ["langchain", "llms", "openai"]

    @property
    def lc_secrets(self) -> dict[str, str]:
        """Map secret fields to the env var names used when (de)serializing."""
        return {
            "openai_api_key": "AZURE_OPENAI_API_KEY",
            "azure_ad_token": "AZURE_OPENAI_AD_TOKEN",
        }

    @classmethod
    def is_lc_serializable(cls) -> bool:
        """Return whether this model can be serialized by Langchain."""
        return True

    @model_validator(mode="after")
    def validate_environment(self) -> Self:
        """Validate generation settings and build the sync/async Azure clients.

        Clients are only created when ``client`` / ``async_client`` were not
        supplied by the caller.

        Raises:
            ValueError: If ``n < 1``; if ``streaming`` is combined with ``n > 1``
                or ``best_of > 1``; or (when ``validate_base_url`` is set) if the
                legacy ``openai_api_base`` is used as an Azure endpoint or
                together with ``deployment_name``.
        """
        if self.n < 1:
            raise ValueError("n must be at least 1.")
        if self.streaming and self.n > 1:
            raise ValueError("Cannot stream results when n > 1.")
        if self.streaming and self.best_of > 1:
            raise ValueError("Cannot stream results when best_of > 1.")

        # For backwards compatibility. Before openai v1, no distinction was made
        # between azure_endpoint and base_url (openai_api_base).
        openai_api_base = self.openai_api_base
        if openai_api_base and self.validate_base_url:
            # A base URL without an "/openai" path segment is treated as a bare
            # resource endpoint, which must now be passed as `azure_endpoint`.
            if "/openai" not in openai_api_base:
                raise ValueError(
                    "As of openai>=1.0.0, Azure endpoints should be specified via "
                    "the `azure_endpoint` param not `openai_api_base` "
                    "(or alias `base_url`)."
                )
            # A full base URL and a deployment name would conflict with each
            # other, since the deployment is also encoded into the client URL.
            if self.deployment_name:
                raise ValueError(
                    "As of openai>=1.0.0, if `deployment_name` (or alias "
                    "`azure_deployment`) is specified then "
                    "`openai_api_base` (or alias `base_url`) should not be. "
                    "Instead use `deployment_name` (or alias `azure_deployment`) "
                    "and `azure_endpoint`."
                )

        client_params: dict[str, Any] = {
            "api_version": self.openai_api_version,
            "azure_endpoint": self.azure_endpoint,
            "azure_deployment": self.deployment_name,
            "api_key": (
                self.openai_api_key.get_secret_value() if self.openai_api_key else None
            ),
            "azure_ad_token": (
                self.azure_ad_token.get_secret_value() if self.azure_ad_token else None
            ),
            "azure_ad_token_provider": self.azure_ad_token_provider,
            "organization": self.openai_organization,
            "base_url": self.openai_api_base,
            "timeout": self.request_timeout,
            "max_retries": self.max_retries,
            # Caller headers are kept, but the User-Agent is always overridden.
            "default_headers": {
                **(self.default_headers or {}),
                "User-Agent": "langchain-partner-python-azure-openai",
            },
            "default_query": self.default_query,
        }
        if not self.client:
            sync_specific = {"http_client": self.http_client}
            self.client = openai.AzureOpenAI(
                **client_params,
                **sync_specific,  # type: ignore[arg-type]
            ).completions
        if not self.async_client:
            async_specific = {"http_client": self.http_async_client}

            # Async requests use the async token provider when one is given;
            # otherwise the sync provider already in client_params is reused.
            if self.azure_ad_async_token_provider:
                client_params["azure_ad_token_provider"] = (
                    self.azure_ad_async_token_provider
                )

            self.async_client = openai.AsyncAzureOpenAI(
                **client_params,
                **async_specific,  # type: ignore[arg-type]
            ).completions

        return self

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Identifying params: the deployment name plus the parent's params.

        Keys from the parent's params override ``deployment_name`` if present.
        """
        return {"deployment_name": self.deployment_name, **super()._identifying_params}

    @property
    def _invocation_params(self) -> dict[str, Any]:
        """Invocation params with the deployment name sent as ``model``.

        Keys from the parent's invocation params override ``model`` if present.
        """
        openai_params = {"model": self.deployment_name}
        return {**openai_params, **super()._invocation_params}

    def _get_ls_params(
        self, stop: Optional[list[str]] = None, **kwargs: Any
    ) -> LangSmithParams:
        """Get standard params for tracing."""
        params = super()._get_ls_params(stop=stop, **kwargs)
        invocation_params = self._invocation_params
        params["ls_provider"] = "azure"
        if model_name := invocation_params.get("model"):
            params["ls_model_name"] = model_name
        return params

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "azure"

    @property
    def lc_attributes(self) -> dict[str, Any]:
        """Extra attributes included when serializing this object."""
        return {
            "openai_api_type": self.openai_api_type,
            "openai_api_version": self.openai_api_version,
        }
|