openai[minor]: change to secretstr (#16803)

This commit is contained in:
Erick Friis
2024-01-30 15:49:56 -08:00
committed by GitHub
parent bf9068516e
commit bb3b6bde33
10 changed files with 162 additions and 63 deletions

View File

@@ -22,8 +22,18 @@ import numpy as np
import openai
import tiktoken
from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel, Extra, Field, root_validator
from langchain_core.utils import get_from_dict_or_env, get_pydantic_field_names
from langchain_core.pydantic_v1 import (
BaseModel,
Extra,
Field,
SecretStr,
root_validator,
)
from langchain_core.utils import (
convert_to_secret_str,
get_from_dict_or_env,
get_pydantic_field_names,
)
logger = logging.getLogger(__name__)
@@ -70,7 +80,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
openai_proxy: Optional[str] = None
embedding_ctx_length: int = 8191
"""The maximum number of tokens to embed at once."""
openai_api_key: Optional[str] = Field(default=None, alias="api_key")
openai_api_key: Optional[SecretStr] = Field(default=None, alias="api_key")
"""Automatically inferred from env var `OPENAI_API_KEY` if not provided."""
openai_organization: Optional[str] = Field(default=None, alias="organization")
"""Automatically inferred from env var `OPENAI_ORG_ID` if not provided."""
@@ -152,9 +162,12 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
values["openai_api_key"] = get_from_dict_or_env(
openai_api_key = get_from_dict_or_env(
values, "openai_api_key", "OPENAI_API_KEY"
)
values["openai_api_key"] = (
convert_to_secret_str(openai_api_key) if openai_api_key else None
)
values["openai_api_base"] = values["openai_api_base"] or os.getenv(
"OPENAI_API_BASE"
)
@@ -196,7 +209,9 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
"please use the `AzureOpenAIEmbeddings` class."
)
client_params = {
"api_key": values["openai_api_key"],
"api_key": values["openai_api_key"].get_secret_value()
if values["openai_api_key"]
else None,
"organization": values["openai_organization"],
"base_url": values["openai_api_base"],
"timeout": values["request_timeout"],