openai[minor]: change to secretstr (#16803)

This commit is contained in:
Erick Friis
2024-01-30 15:49:56 -08:00
committed by GitHub
parent bf9068516e
commit bb3b6bde33
10 changed files with 162 additions and 63 deletions

View File

@@ -2,18 +2,11 @@ from __future__ import annotations
import logging
import os
from typing import (
Any,
Callable,
Dict,
List,
Mapping,
Union,
)
from typing import Any, Callable, Dict, List, Mapping, Optional, Union
import openai
from langchain_core.pydantic_v1 import Field, root_validator
from langchain_core.utils import get_from_dict_or_env
from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
from langchain_openai.llms.base import BaseOpenAI
@@ -52,9 +45,9 @@ class AzureOpenAI(BaseOpenAI):
"""
openai_api_version: str = Field(default="", alias="api_version")
"""Automatically inferred from env var `OPENAI_API_VERSION` if not provided."""
openai_api_key: Union[str, None] = Field(default=None, alias="api_key")
openai_api_key: Optional[SecretStr] = Field(default=None, alias="api_key")
"""Automatically inferred from env var `AZURE_OPENAI_API_KEY` if not provided."""
azure_ad_token: Union[str, None] = None
azure_ad_token: Optional[SecretStr] = None
"""Your Azure Active Directory token.
Automatically inferred from env var `AZURE_OPENAI_AD_TOKEN` if not provided.
@@ -92,17 +85,21 @@ class AzureOpenAI(BaseOpenAI):
# Check OPENAI_KEY for backwards compatibility.
# TODO: Remove OPENAI_API_KEY support to avoid possible conflict when using
# other forms of azure credentials.
values["openai_api_key"] = (
openai_api_key = (
values["openai_api_key"]
or os.getenv("AZURE_OPENAI_API_KEY")
or os.getenv("OPENAI_API_KEY")
)
values["openai_api_key"] = (
convert_to_secret_str(openai_api_key) if openai_api_key else None
)
values["azure_endpoint"] = values["azure_endpoint"] or os.getenv(
"AZURE_OPENAI_ENDPOINT"
)
values["azure_ad_token"] = values["azure_ad_token"] or os.getenv(
"AZURE_OPENAI_AD_TOKEN"
azure_ad_token = values["azure_ad_token"] or os.getenv("AZURE_OPENAI_AD_TOKEN")
values["azure_ad_token"] = (
convert_to_secret_str(azure_ad_token) if azure_ad_token else None
)
values["openai_api_base"] = values["openai_api_base"] or os.getenv(
"OPENAI_API_BASE"
@@ -150,8 +147,12 @@ class AzureOpenAI(BaseOpenAI):
"api_version": values["openai_api_version"],
"azure_endpoint": values["azure_endpoint"],
"azure_deployment": values["deployment_name"],
"api_key": values["openai_api_key"],
"azure_ad_token": values["azure_ad_token"],
"api_key": values["openai_api_key"].get_secret_value()
if values["openai_api_key"]
else None,
"azure_ad_token": values["azure_ad_token"].get_secret_value()
if values["azure_ad_token"]
else None,
"azure_ad_token_provider": values["azure_ad_token_provider"],
"organization": values["openai_organization"],
"base_url": values["openai_api_base"],

View File

@@ -27,8 +27,12 @@ from langchain_core.callbacks import (
)
from langchain_core.language_models.llms import BaseLLM
from langchain_core.outputs import Generation, GenerationChunk, LLMResult
from langchain_core.pydantic_v1 import Field, root_validator
from langchain_core.utils import get_from_dict_or_env, get_pydantic_field_names
from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
from langchain_core.utils import (
convert_to_secret_str,
get_from_dict_or_env,
get_pydantic_field_names,
)
from langchain_core.utils.utils import build_extra_kwargs
logger = logging.getLogger(__name__)
@@ -104,10 +108,7 @@ class BaseOpenAI(BaseLLM):
"""Generates best_of completions server-side and returns the "best"."""
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
"""Holds any model parameters valid for `create` call not explicitly specified."""
# When updating this to use a SecretStr
# Check for classes that derive from this class (as some of them
# may assume openai_api_key is a str)
openai_api_key: Optional[str] = Field(default=None, alias="api_key")
openai_api_key: Optional[SecretStr] = Field(default=None, alias="api_key")
"""Automatically inferred from env var `OPENAI_API_KEY` if not provided."""
openai_api_base: Optional[str] = Field(default=None, alias="base_url")
"""Base URL path for API requests, leave blank if not using a proxy or service
@@ -175,9 +176,12 @@ class BaseOpenAI(BaseLLM):
if values["streaming"] and values["best_of"] > 1:
raise ValueError("Cannot stream results when best_of > 1.")
values["openai_api_key"] = get_from_dict_or_env(
openai_api_key = get_from_dict_or_env(
values, "openai_api_key", "OPENAI_API_KEY"
)
values["openai_api_key"] = (
convert_to_secret_str(openai_api_key) if openai_api_key else None
)
values["openai_api_base"] = values["openai_api_base"] or os.getenv(
"OPENAI_API_BASE"
)
@@ -194,7 +198,9 @@ class BaseOpenAI(BaseLLM):
)
client_params = {
"api_key": values["openai_api_key"],
"api_key": values["openai_api_key"].get_secret_value()
if values["openai_api_key"]
else None,
"organization": values["openai_organization"],
"base_url": values["openai_api_base"],
"timeout": values["request_timeout"],