diff --git a/libs/partners/together/langchain_together/chat_models.py b/libs/partners/together/langchain_together/chat_models.py index ee3473eba40..cb749fbaaee 100644 --- a/libs/partners/together/langchain_together/chat_models.py +++ b/libs/partners/together/langchain_together/chat_models.py @@ -1,6 +1,5 @@ """Wrapper around Together AI's Chat Completions API.""" -import os from typing import ( Any, Dict, @@ -12,8 +11,8 @@ import openai from langchain_core.language_models.chat_models import LangSmithParams from langchain_core.pydantic_v1 import Field, SecretStr, root_validator from langchain_core.utils import ( - convert_to_secret_str, - get_from_dict_or_env, + from_env, + secret_from_env, ) from langchain_openai.chat_models.base import BaseChatOpenAI @@ -311,13 +310,27 @@ class ChatTogether(BaseChatOpenAI): model_name: str = Field(default="meta-llama/Llama-3-8b-chat-hf", alias="model") """Model name to use.""" - together_api_key: Optional[SecretStr] = Field(default=None, alias="api_key") - """Automatically inferred from env are `TOGETHER_API_KEY` if not provided.""" - together_api_base: Optional[str] = Field( - default="https://api.together.ai/v1/", alias="base_url" + together_api_key: Optional[SecretStr] = Field( + alias="api_key", + default_factory=secret_from_env("TOGETHER_API_KEY", default=None), + ) + """Together AI API key. + + Automatically read from env variable `TOGETHER_API_KEY` if not provided. + """ + together_api_base: str = Field( + default_factory=from_env( + "TOGETHER_API_BASE", default="https://api.together.ai/v1/" + ), + alias="base_url", ) - @root_validator() + class Config: + """Pydantic config.""" + + allow_population_by_field_name = True + + @root_validator(pre=False, skip_on_failure=True) def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" if values["n"] < 1: @@ -325,13 +338,6 @@ class ChatTogether(BaseChatOpenAI): if values["n"] > 1 and values["streaming"]: raise ValueError("n must be 1 when streaming.") - values["together_api_key"] = convert_to_secret_str( - get_from_dict_or_env(values, "together_api_key", "TOGETHER_API_KEY") - ) - values["together_api_base"] = values["together_api_base"] or os.getenv( - "TOGETHER_API_BASE" - ) - client_params = { "api_key": ( values["together_api_key"].get_secret_value() diff --git a/libs/partners/together/langchain_together/embeddings.py b/libs/partners/together/langchain_together/embeddings.py index 209537e241b..3757480d139 100644 --- a/libs/partners/together/langchain_together/embeddings.py +++ b/libs/partners/together/langchain_together/embeddings.py @@ -1,7 +1,6 @@ """Wrapper around Together AI's Embeddings API.""" import logging -import os import warnings from typing import ( Any, @@ -25,9 +24,9 @@ from langchain_core.pydantic_v1 import ( root_validator, ) from langchain_core.utils import ( - convert_to_secret_str, - get_from_dict_or_env, + from_env, get_pydantic_field_names, + secret_from_env, ) logger = logging.getLogger(__name__) @@ -115,10 +114,19 @@ class TogetherEmbeddings(BaseModel, Embeddings): Not yet supported. """ - together_api_key: Optional[SecretStr] = Field(default=None, alias="api_key") - """API Key for Solar API.""" + together_api_key: Optional[SecretStr] = Field( + alias="api_key", + default_factory=secret_from_env("TOGETHER_API_KEY", default=None), + ) + """Together AI API key. + + Automatically read from env variable `TOGETHER_API_KEY` if not provided. + """ together_api_base: str = Field( - default="https://api.together.ai/v1/", alias="base_url" + default_factory=from_env( + "TOGETHER_API_BASE", default="https://api.together.ai/v1/" + ), + alias="base_url", ) """Endpoint URL to use.""" embedding_ctx_length: int = 4096 @@ -198,18 +206,9 @@ class TogetherEmbeddings(BaseModel, Embeddings): values["model_kwargs"] = extra return values - @root_validator() - def validate_environment(cls, values: Dict) -> Dict: - """Validate that api key and python package exists in environment.""" - together_api_key = get_from_dict_or_env( - values, "together_api_key", "TOGETHER_API_KEY" - ) - values["together_api_key"] = ( - convert_to_secret_str(together_api_key) if together_api_key else None - ) - values["together_api_base"] = values["together_api_base"] or os.getenv( - "TOGETHER_API_BASE" - ) + @root_validator(pre=False, skip_on_failure=True) + def post_init(cls, values: Dict) -> Dict: + """Logic that will post Pydantic initialization.""" client_params = { "api_key": ( values["together_api_key"].get_secret_value() diff --git a/libs/partners/together/langchain_together/llms.py b/libs/partners/together/langchain_together/llms.py index b19df54e545..4d78149d573 100644 --- a/libs/partners/together/langchain_together/llms.py +++ b/libs/partners/together/langchain_together/llms.py @@ -11,8 +11,10 @@ from langchain_core.callbacks import ( CallbackManagerForLLMRun, ) from langchain_core.language_models.llms import LLM -from langchain_core.pydantic_v1 import SecretStr, root_validator -from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env +from langchain_core.pydantic_v1 import Field, SecretStr, root_validator +from langchain_core.utils import ( + secret_from_env, +) logger = logging.getLogger(__name__) @@ -36,8 +38,14 @@ class Together(LLM): base_url: str = "https://api.together.ai/v1/completions" """Base completions API URL.""" - together_api_key: SecretStr - """Together AI API key. Get it here: https://api.together.ai/settings/api-keys""" + together_api_key: SecretStr = Field( + alias="api_key", + default_factory=secret_from_env("TOGETHER_API_KEY"), + ) + """Together AI API key. + + Automatically read from env variable `TOGETHER_API_KEY` if not provided. + """ model: str """Model name. Available models listed here: Base Models: https://docs.together.ai/docs/inference-models#language-models @@ -74,21 +82,11 @@ class Together(LLM): """Configuration for this pydantic object.""" extra = "forbid" + allow_population_by_field_name = True @root_validator(pre=True) def validate_environment(cls, values: Dict) -> Dict: """Validate that api key exists in environment.""" - values["together_api_key"] = convert_to_secret_str( - get_from_dict_or_env(values, "together_api_key", "TOGETHER_API_KEY") - ) - return values - - @root_validator() - def validate_max_tokens(cls, values: Dict) -> Dict: - """The v1 completions endpoint, has max_tokens as required parameter. - - Set a default value and warn if the parameter is missing. - """ if values.get("max_tokens") is None: warnings.warn( "The completions endpoint, has 'max_tokens' as required argument. " diff --git a/libs/partners/together/tests/unit_tests/test_llms.py b/libs/partners/together/tests/unit_tests/test_llms.py index 79857be5560..e5caac22792 100644 --- a/libs/partners/together/tests/unit_tests/test_llms.py +++ b/libs/partners/together/tests/unit_tests/test_llms.py @@ -9,7 +9,7 @@ from langchain_together import Together def test_together_api_key_is_secret_string() -> None: """Test that the API key is stored as a SecretStr.""" llm = Together( - together_api_key="secret-api-key", # type: ignore[arg-type] + together_api_key="secret-api-key", # type: ignore[call-arg] model="togethercomputer/RedPajama-INCITE-7B-Base", temperature=0.2, max_tokens=250, @@ -38,7 +38,7 @@ def test_together_api_key_masked_when_passed_via_constructor( ) -> None: """Test that the API key is masked when passed via the constructor.""" llm = Together( - together_api_key="secret-api-key", # type: ignore[arg-type] + together_api_key="secret-api-key", # type: ignore[call-arg] model="togethercomputer/RedPajama-INCITE-7B-Base", temperature=0.2, max_tokens=250, @@ -52,7 +52,18 @@ def test_together_api_key_masked_when_passed_via_constructor( def test_together_uses_actual_secret_value_from_secretstr() -> None: """Test that the actual secret value is correctly retrieved.""" llm = Together( - together_api_key="secret-api-key", # type: ignore[arg-type] + together_api_key="secret-api-key", # type: ignore[call-arg] + model="togethercomputer/RedPajama-INCITE-7B-Base", + temperature=0.2, + max_tokens=250, + ) + assert cast(SecretStr, llm.together_api_key).get_secret_value() == "secret-api-key" + + +def test_together_uses_actual_secret_value_from_secretstr_api_key() -> None: + """Test that the actual secret value is correctly retrieved.""" + llm = Together( + api_key="secret-api-key", # type: ignore[arg-type] model="togethercomputer/RedPajama-INCITE-7B-Base", temperature=0.2, max_tokens=250,