diff --git a/libs/partners/pinecone/langchain_pinecone/embeddings.py b/libs/partners/pinecone/langchain_pinecone/embeddings.py index cad6b18d495..c0adecd1e86 100644 --- a/libs/partners/pinecone/langchain_pinecone/embeddings.py +++ b/libs/partners/pinecone/langchain_pinecone/embeddings.py @@ -1,5 +1,4 @@ import logging -import os from typing import Dict, Iterable, List, Optional import aiohttp @@ -10,7 +9,7 @@ from langchain_core.pydantic_v1 import ( SecretStr, root_validator, ) -from langchain_core.utils import convert_to_secret_str +from langchain_core.utils import secret_from_env from pinecone import Pinecone as PineconeClient # type: ignore logger = logging.getLogger(__name__) @@ -45,10 +44,21 @@ class PineconeEmbeddings(BaseModel, Embeddings): dimension: Optional[int] = None # show_progress_bar: bool = False - pinecone_api_key: Optional[SecretStr] = None + pinecone_api_key: Optional[SecretStr] = Field( + default_factory=secret_from_env( + "PINECONE_API_KEY", + error_message="Pinecone API key not found. Please set the PINECONE_API_KEY " + "environment variable or pass it via `pinecone_api_key`.", + ), + alias="api_key", + ) + """Pinecone API key. + + If not provided, will look for the PINECONE_API_KEY environment variable.""" class Config: extra = "forbid" + allow_population_by_field_name = True @root_validator(pre=True) def set_default_config(cls, values: dict) -> dict: @@ -69,25 +79,10 @@ class PineconeEmbeddings(BaseModel, Embeddings): values[key] = value return values - @root_validator() + @root_validator(pre=False, skip_on_failure=True) def validate_environment(cls, values: dict) -> dict: """Validate that Pinecone version and credentials exist in environment.""" - - pinecone_api_key = values.get("pinecone_api_key") or os.getenv( - "PINECONE_API_KEY", None - ) - if pinecone_api_key: - api_key_secretstr = convert_to_secret_str(pinecone_api_key) - values["pinecone_api_key"] = api_key_secretstr - - api_key_str = api_key_secretstr.get_secret_value() - else: - api_key_str = None - if api_key_str is None: - raise ValueError( - "Pinecone API key not found. Please set the PINECONE_API_KEY " - "environment variable or pass it via `pinecone_api_key`." - ) + api_key_str = values["pinecone_api_key"].get_secret_value() client = PineconeClient(api_key=api_key_str, source_tag="langchain") values["_client"] = client diff --git a/libs/partners/pinecone/tests/unit_tests/test_embeddings.py b/libs/partners/pinecone/tests/unit_tests/test_embeddings.py index 23d7b1df4c7..924b4e79662 100644 --- a/libs/partners/pinecone/tests/unit_tests/test_embeddings.py +++ b/libs/partners/pinecone/tests/unit_tests/test_embeddings.py @@ -7,10 +7,22 @@ MODEL_NAME = "multilingual-e5-large" def test_default_config() -> None: - e = PineconeEmbeddings(pinecone_api_key=API_KEY, model=MODEL_NAME) + e = PineconeEmbeddings( + pinecone_api_key=API_KEY, # type: ignore[call-arg] + model=MODEL_NAME, + ) + assert e.batch_size == 96 + + +def test_default_config_with_api_key() -> None: + e = PineconeEmbeddings(api_key=API_KEY, model=MODEL_NAME) assert e.batch_size == 96 def test_custom_config() -> None: - e = PineconeEmbeddings(pinecone_api_key=API_KEY, model=MODEL_NAME, batch_size=128) + e = PineconeEmbeddings( + pinecone_api_key=API_KEY, # type: ignore[call-arg] + model=MODEL_NAME, + batch_size=128, + ) assert e.batch_size == 128