mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-01 19:03:25 +00:00
- This pull request makes various changes to add a `user_agent` parameter to Azure OpenAI, Azure Search and Whisper in the Community and Partner packages. This helps identify the source of API requests, so we can track usage more accurately and better support the community. I will also be adding the user_agent to the new `langchain-azure` repo. - No linked issue; no dependency updates. - Uses existing tests and docs. --------- Co-authored-by: Erick Friis <erick@langchain.dev>
217 lines
8.2 KiB
Python
217 lines
8.2 KiB
Python
from __future__ import annotations
|
|
|
|
import logging
|
|
from typing import Any, Awaitable, Callable, Dict, List, Mapping, Optional, Union
|
|
|
|
import openai
|
|
from langchain_core.language_models import LangSmithParams
|
|
from langchain_core.utils import from_env, secret_from_env
|
|
from pydantic import Field, SecretStr, model_validator
|
|
from typing_extensions import Self, cast
|
|
|
|
from langchain_openai.llms.base import BaseOpenAI
|
|
|
|
# Module-scoped logger, named after this module's import path.
logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
class AzureOpenAI(BaseOpenAI):
    """Azure-specific OpenAI large language models.

    To use, you should have the ``openai`` python package installed, and the
    environment variable ``AZURE_OPENAI_API_KEY`` (or, for backwards
    compatibility, ``OPENAI_API_KEY``) set with your API key. Instead of an API
    key you may also supply ``azure_ad_token`` or ``azure_ad_token_provider``.

    Any parameters that are valid to be passed to the openai.create call can be passed
    in, even if not explicitly saved on this class.

    Example:
        .. code-block:: python

            from langchain_openai import AzureOpenAI

            openai = AzureOpenAI(model_name="gpt-3.5-turbo-instruct")
    """

    azure_endpoint: Optional[str] = Field(
        default_factory=from_env("AZURE_OPENAI_ENDPOINT", default=None)
    )
    """Your Azure endpoint, including the resource.

    Automatically inferred from env var `AZURE_OPENAI_ENDPOINT` if not provided.

    Example: `https://example-resource.azure.openai.com/`
    """
    deployment_name: Union[str, None] = Field(default=None, alias="azure_deployment")
    """A model deployment.

    If given sets the base client URL to include `/deployments/{azure_deployment}`.
    Note: this means you won't be able to use non-deployment endpoints.
    """
    openai_api_version: Optional[str] = Field(
        alias="api_version",
        default_factory=from_env("OPENAI_API_VERSION", default=None),
    )
    """Automatically inferred from env var `OPENAI_API_VERSION` if not provided."""
    # Env vars are checked in the order listed: AZURE_OPENAI_API_KEY first, then
    # OPENAI_API_KEY for backwards compatibility.
    # TODO: Remove OPENAI_API_KEY support to avoid possible conflict when using
    # other forms of azure credentials.
    openai_api_key: Optional[SecretStr] = Field(
        alias="api_key",
        default_factory=secret_from_env(
            ["AZURE_OPENAI_API_KEY", "OPENAI_API_KEY"], default=None
        ),
    )
    azure_ad_token: Optional[SecretStr] = Field(
        default_factory=secret_from_env("AZURE_OPENAI_AD_TOKEN", default=None)
    )
    """Your Azure Active Directory token.

    Automatically inferred from env var `AZURE_OPENAI_AD_TOKEN` if not provided.

    For more:
    https://www.microsoft.com/en-us/security/business/identity-access/microsoft-entra-id.
    """
    azure_ad_token_provider: Union[Callable[[], str], None] = None
    """A function that returns an Azure Active Directory token.

    Will be invoked on every sync request. For async requests,
    will be invoked if `azure_ad_async_token_provider` is not provided.
    """
    azure_ad_async_token_provider: Union[Callable[[], Awaitable[str]], None] = None
    """A function that returns an Azure Active Directory token.

    Will be invoked on every async request.
    """
    openai_api_type: Optional[str] = Field(
        default_factory=from_env("OPENAI_API_TYPE", default="azure")
    )
    """Legacy, for openai<1.0.0 support."""
    validate_base_url: bool = True
    """For backwards compatibility. If legacy val openai_api_base is passed in,
    check that it looks like a base_url (contains ``/openai``) rather than an
    azure_endpoint, and that it is not combined with ``deployment_name``; raise
    a descriptive error otherwise. Set to False to skip this check.
    """

    @classmethod
    def get_lc_namespace(cls) -> List[str]:
        """Get the namespace of the langchain object."""
        return ["langchain", "llms", "openai"]

    @property
    def lc_secrets(self) -> Dict[str, str]:
        """Map secret fields to the env vars used when (de)serializing."""
        return {
            "openai_api_key": "AZURE_OPENAI_API_KEY",
            "azure_ad_token": "AZURE_OPENAI_AD_TOKEN",
        }

    @classmethod
    def is_lc_serializable(cls) -> bool:
        """Return whether this model can be serialized by Langchain."""
        return True

    @model_validator(mode="after")
    def validate_environment(self) -> Self:
        """Validate parameters and build the sync and async Azure OpenAI clients.

        A client is only created when the corresponding ``client`` /
        ``async_client`` field was not already supplied.

        Raises:
            ValueError: If ``n < 1``; if ``streaming`` is combined with ``n > 1``
                or ``best_of > 1``; or, when ``validate_base_url`` is True and a
                legacy ``openai_api_base`` is set, if it does not contain
                ``/openai`` or is combined with ``deployment_name``.
        """
        if self.n < 1:
            raise ValueError("n must be at least 1.")
        if self.streaming and self.n > 1:
            raise ValueError("Cannot stream results when n > 1.")
        if self.streaming and self.best_of > 1:
            raise ValueError("Cannot stream results when best_of > 1.")
        # For backwards compatibility. Before openai v1, no distinction was made
        # between azure_endpoint and base_url (openai_api_base). Such legacy
        # configurations are rejected outright; raising aborts construction,
        # so no fields are rewritten beforehand.
        openai_api_base = self.openai_api_base
        if openai_api_base and self.validate_base_url:
            if "/openai" not in openai_api_base:
                raise ValueError(
                    "As of openai>=1.0.0, Azure endpoints should be specified via "
                    "the `azure_endpoint` param not `openai_api_base` "
                    "(or alias `base_url`)."
                )
            if self.deployment_name:
                raise ValueError(
                    "As of openai>=1.0.0, if `deployment_name` (or alias "
                    "`azure_deployment`) is specified then "
                    "`openai_api_base` (or alias `base_url`) should not be. "
                    "Instead use `deployment_name` (or alias `azure_deployment`) "
                    "and `azure_endpoint`."
                )
        client_params: dict = {
            "api_version": self.openai_api_version,
            "azure_endpoint": self.azure_endpoint,
            "azure_deployment": self.deployment_name,
            "api_key": self.openai_api_key.get_secret_value()
            if self.openai_api_key
            else None,
            "azure_ad_token": self.azure_ad_token.get_secret_value()
            if self.azure_ad_token
            else None,
            "azure_ad_token_provider": self.azure_ad_token_provider,
            "organization": self.openai_organization,
            "base_url": self.openai_api_base,
            "timeout": self.request_timeout,
            "max_retries": self.max_retries,
            # The fixed User-Agent is applied last, so it overrides any
            # User-Agent present in the caller's default_headers.
            "default_headers": {
                **(self.default_headers or {}),
                "User-Agent": "langchain-partner-python-azure-openai",
            },
            "default_query": self.default_query,
        }
        if not self.client:
            sync_specific = {"http_client": self.http_client}
            self.client = openai.AzureOpenAI(
                **client_params,
                **sync_specific,  # type: ignore[arg-type]
            ).completions
        if not self.async_client:
            async_specific = {"http_client": self.http_async_client}
            # Work on a copy so the shared params are never mutated. A dedicated
            # async token provider takes precedence; otherwise the sync provider
            # already present in client_params is reused.
            async_params = dict(client_params)
            if self.azure_ad_async_token_provider:
                async_params["azure_ad_token_provider"] = (
                    self.azure_ad_async_token_provider
                )
            self.async_client = openai.AsyncAzureOpenAI(
                **async_params,
                **async_specific,  # type: ignore[arg-type]
            ).completions

        return self

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Identifying params: the deployment name plus the base class params."""
        return {
            "deployment_name": self.deployment_name,
            **super()._identifying_params,
        }

    @property
    def _invocation_params(self) -> Dict[str, Any]:
        """Request params, using the deployment name as ``model``.

        Values from the base class take precedence over the ``model`` entry set
        here, because they are merged last.
        """
        openai_params = {"model": self.deployment_name}
        return {**openai_params, **super()._invocation_params}

    def _get_ls_params(
        self, stop: Optional[List[str]] = None, **kwargs: Any
    ) -> LangSmithParams:
        """Get standard params for tracing."""
        params = super()._get_ls_params(stop=stop, **kwargs)
        invocation_params = self._invocation_params
        params["ls_provider"] = "azure"
        if model_name := invocation_params.get("model"):
            params["ls_model_name"] = model_name
        return params

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "azure"

    @property
    def lc_attributes(self) -> Dict[str, Any]:
        """Extra attributes included when serializing this object."""
        return {
            "openai_api_type": self.openai_api_type,
            "openai_api_version": self.openai_api_version,
        }
|