From a578076aeaf7c5ddaa6f7a9f55c0897180d75443 Mon Sep 17 00:00:00 2001 From: David Norman Date: Tue, 28 Nov 2023 21:57:40 -0600 Subject: [PATCH] Mask api key for Together LLM (#13981) - **Description:** Add unit tests and mask api key for Together LLM - **Issue:** the issue https://github.com/langchain-ai/langchain/issues/12165 , - **Dependencies:** N/A - **Tag maintainer:** ?, - **Twitter handle:** N/A --------- Co-authored-by: Eugene Yurtsev --- libs/langchain/langchain/llms/together.py | 14 ++--- .../tests/unit_tests/llms/test_together.py | 61 +++++++++++++++++++ 2 files changed, 68 insertions(+), 7 deletions(-) create mode 100644 libs/langchain/tests/unit_tests/llms/test_together.py diff --git a/libs/langchain/langchain/llms/together.py b/libs/langchain/langchain/llms/together.py index 46ada2e9d30..e37b6e7df9a 100644 --- a/libs/langchain/langchain/llms/together.py +++ b/libs/langchain/langchain/llms/together.py @@ -3,7 +3,7 @@ import logging from typing import Any, Dict, List, Optional from aiohttp import ClientSession -from langchain_core.pydantic_v1 import Extra, root_validator +from langchain_core.pydantic_v1 import Extra, SecretStr, root_validator from langchain.callbacks.manager import ( AsyncCallbackManagerForLLMRun, @@ -11,7 +11,7 @@ from langchain.callbacks.manager import ( ) from langchain.llms.base import LLM from langchain.utilities.requests import Requests -from langchain.utils import get_from_dict_or_env +from langchain.utils import convert_to_secret_str, get_from_dict_or_env logger = logging.getLogger(__name__) @@ -28,7 +28,7 @@ class Together(LLM): base_url: str = "https://api.together.xyz/inference" """Base inference API URL.""" - together_api_key: str + together_api_key: SecretStr """Together AI API key. Get it here: https://api.together.xyz/settings/api-keys""" model: str """Model name. Available models listed here: @@ -69,8 +69,8 @@ class Together(LLM): @root_validator(pre=True) def validate_environment(cls, values: Dict) -> Dict: """Validate that api key exists in environment.""" - values["together_api_key"] = get_from_dict_or_env( - values, "together_api_key", "TOGETHER_API_KEY" + values["together_api_key"] = convert_to_secret_str( + get_from_dict_or_env(values, "together_api_key", "TOGETHER_API_KEY") ) return values @@ -116,7 +116,7 @@ class Together(LLM): """ headers = { - "Authorization": f"Bearer {self.together_api_key}", + "Authorization": f"Bearer {self.together_api_key.get_secret_value()}", "Content-Type": "application/json", } stop_to_use = stop[0] if stop and len(stop) == 1 else stop @@ -167,7 +167,7 @@ class Together(LLM): The string generated by the model. """ headers = { - "Authorization": f"Bearer {self.together_api_key}", + "Authorization": f"Bearer {self.together_api_key.get_secret_value()}", "Content-Type": "application/json", } stop_to_use = stop[0] if stop and len(stop) == 1 else stop diff --git a/libs/langchain/tests/unit_tests/llms/test_together.py b/libs/langchain/tests/unit_tests/llms/test_together.py new file mode 100644 index 00000000000..e8624464d78 --- /dev/null +++ b/libs/langchain/tests/unit_tests/llms/test_together.py @@ -0,0 +1,61 @@ +"""Test Together LLM""" +from typing import cast + +from langchain_core.pydantic_v1 import SecretStr +from pytest import CaptureFixture, MonkeyPatch + +from langchain.llms.together import Together + + +def test_together_api_key_is_secret_string() -> None: + """Test that the API key is stored as a SecretStr.""" + llm = Together( + together_api_key="secret-api-key", + model="togethercomputer/RedPajama-INCITE-7B-Base", + temperature=0.2, + max_tokens=250, + ) + assert isinstance(llm.together_api_key, SecretStr) + + +def test_together_api_key_masked_when_passed_from_env( + monkeypatch: MonkeyPatch, capsys: CaptureFixture +) -> None: + """Test that the API key is masked when passed from an environment variable.""" + monkeypatch.setenv("TOGETHER_API_KEY", "secret-api-key") + llm = Together( + model="togethercomputer/RedPajama-INCITE-7B-Base", + temperature=0.2, + max_tokens=250, + ) + print(llm.together_api_key, end="") + captured = capsys.readouterr() + + assert captured.out == "**********" + + +def test_together_api_key_masked_when_passed_via_constructor( + capsys: CaptureFixture, +) -> None: + """Test that the API key is masked when passed via the constructor.""" + llm = Together( + together_api_key="secret-api-key", + model="togethercomputer/RedPajama-INCITE-7B-Base", + temperature=0.2, + max_tokens=250, + ) + print(llm.together_api_key, end="") + captured = capsys.readouterr() + + assert captured.out == "**********" + + +def test_together_uses_actual_secret_value_from_secretstr() -> None: + """Test that the actual secret value is correctly retrieved.""" + llm = Together( + together_api_key="secret-api-key", + model="togethercomputer/RedPajama-INCITE-7B-Base", + temperature=0.2, + max_tokens=250, + ) + assert cast(SecretStr, llm.together_api_key).get_secret_value() == "secret-api-key"