mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-21 12:01:47 +00:00
Masking of API Key for GooseAI LLM (#12496)
Description: Add masking of API Key for GooseAI LLM when printed. Issue: https://github.com/langchain-ai/langchain/issues/12165 Dependencies: None Tag maintainer: @eyurtsev --------- Co-authored-by: Samad Koita <>
This commit is contained in:
parent
64c4a698a8
commit
d1fdcd4fcb
@ -1,14 +1,21 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import Any, Dict, List, Mapping, Optional
|
from typing import Any, Dict, List, Mapping, Optional, Union
|
||||||
|
|
||||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||||
from langchain.llms.base import LLM
|
from langchain.llms.base import LLM
|
||||||
from langchain.pydantic_v1 import Extra, Field, root_validator
|
from langchain.pydantic_v1 import Extra, Field, SecretStr, root_validator
|
||||||
from langchain.utils import get_from_dict_or_env
|
from langchain.utils import get_from_dict_or_env
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _to_secret(value: Union[SecretStr, str]) -> SecretStr:
|
||||||
|
"""Convert a string to a SecretStr if needed."""
|
||||||
|
if isinstance(value, SecretStr):
|
||||||
|
return value
|
||||||
|
return SecretStr(value)
|
||||||
|
|
||||||
|
|
||||||
class GooseAI(LLM):
|
class GooseAI(LLM):
|
||||||
"""GooseAI large language models.
|
"""GooseAI large language models.
|
||||||
|
|
||||||
@ -60,7 +67,7 @@ class GooseAI(LLM):
|
|||||||
logit_bias: Optional[Dict[str, float]] = Field(default_factory=dict)
|
logit_bias: Optional[Dict[str, float]] = Field(default_factory=dict)
|
||||||
"""Adjust the probability of specific tokens being generated."""
|
"""Adjust the probability of specific tokens being generated."""
|
||||||
|
|
||||||
gooseai_api_key: Optional[str] = None
|
gooseai_api_key: Optional[SecretStr] = None
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
"""Configuration for this pydantic config."""
|
"""Configuration for this pydantic config."""
|
||||||
@ -89,13 +96,14 @@ class GooseAI(LLM):
|
|||||||
@root_validator()
|
@root_validator()
|
||||||
def validate_environment(cls, values: Dict) -> Dict:
|
def validate_environment(cls, values: Dict) -> Dict:
|
||||||
"""Validate that api key and python package exists in environment."""
|
"""Validate that api key and python package exists in environment."""
|
||||||
gooseai_api_key = get_from_dict_or_env(
|
gooseai_api_key = _to_secret(
|
||||||
values, "gooseai_api_key", "GOOSEAI_API_KEY"
|
get_from_dict_or_env(values, "gooseai_api_key", "GOOSEAI_API_KEY")
|
||||||
)
|
)
|
||||||
|
values["gooseai_api_key"] = gooseai_api_key
|
||||||
try:
|
try:
|
||||||
import openai
|
import openai
|
||||||
|
|
||||||
openai.api_key = gooseai_api_key
|
openai.api_key = gooseai_api_key.get_secret_value()
|
||||||
openai.api_base = "https://api.goose.ai/v1"
|
openai.api_base = "https://api.goose.ai/v1"
|
||||||
values["client"] = openai.Completion
|
values["client"] = openai.Completion
|
||||||
except ImportError:
|
except ImportError:
|
||||||
|
32
libs/langchain/tests/unit_tests/llms/test_gooseai.py
Normal file
32
libs/langchain/tests/unit_tests/llms/test_gooseai.py
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
"""Test GooseAI"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from pytest import MonkeyPatch
|
||||||
|
|
||||||
|
from langchain.llms.gooseai import GooseAI
|
||||||
|
from langchain.pydantic_v1 import SecretStr
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.requires("openai")
|
||||||
|
def test_api_key_is_secret_string() -> None:
|
||||||
|
llm = GooseAI(gooseai_api_key="secret-api-key")
|
||||||
|
assert isinstance(llm.gooseai_api_key, SecretStr)
|
||||||
|
assert llm.gooseai_api_key.get_secret_value() == "secret-api-key"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.requires("openai")
|
||||||
|
def test_api_key_masked_when_passed_via_constructor() -> None:
|
||||||
|
llm = GooseAI(gooseai_api_key="secret-api-key")
|
||||||
|
assert str(llm.gooseai_api_key) == "**********"
|
||||||
|
assert "secret-api-key" not in repr(llm.gooseai_api_key)
|
||||||
|
assert "secret-api-key" not in repr(llm)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.requires("openai")
|
||||||
|
def test_api_key_masked_when_passed_from_env() -> None:
|
||||||
|
with MonkeyPatch.context() as mp:
|
||||||
|
mp.setenv("GOOSEAI_API_KEY", "secret-api-key")
|
||||||
|
llm = GooseAI()
|
||||||
|
assert str(llm.gooseai_api_key) == "**********"
|
||||||
|
assert "secret-api-key" not in repr(llm.gooseai_api_key)
|
||||||
|
assert "secret-api-key" not in repr(llm)
|
Loading…
Reference in New Issue
Block a user