1
0
mirror of https://github.com/hwchase17/langchain.git synced 2025-05-04 22:58:42 +00:00
langchain/libs/partners/pinecone/tests/unit_tests/test_embeddings.py
Chester Curme e3185a2212 update
2025-01-10 14:39:15 -05:00

84 lines
2.8 KiB
Python

from typing import Any, Type
from unittest.mock import patch
import aiohttp
import pytest
from langchain_core.utils import convert_to_secret_str
from langchain_tests.unit_tests.embeddings import EmbeddingsTests
from langchain_pinecone import PineconeEmbeddings
# Dummy credential wrapped as a secret string; the literal is deliberately not
# a real Pinecone key.
API_KEY = convert_to_secret_str("NOT_A_VALID_KEY")
# Model name passed to every PineconeEmbeddings instance built in this module.
MODEL_NAME = "multilingual-e5-large"
@pytest.fixture(autouse=True)
def mock_pinecone() -> Any:
    """Patch ``PineconeClient`` in ``langchain_pinecone.embeddings`` for each test.

    The fixture is ``autouse``, so it applies to every test in this module
    without being requested. It yields the mock that stands in for the client
    class; the patch is undone when the ``with`` block exits at teardown.
    """
    target = "langchain_pinecone.embeddings.PineconeClient"
    with patch(target) as patched_client:
        yield patched_client
class TestPineconeEmbeddingsStandard(EmbeddingsTests):
    """Run the shared ``EmbeddingsTests`` suite against ``PineconeEmbeddings``."""

    @property
    def embeddings_class(self) -> Type[PineconeEmbeddings]:
        """Return the embeddings class that the inherited suite exercises."""
        return PineconeEmbeddings

    @property
    def embedding_model_params(self) -> dict:
        """Return the keyword arguments used to construct the embeddings model."""
        init_kwargs = {
            "model": MODEL_NAME,
            "pinecone_api_key": API_KEY,
        }
        return init_kwargs
class TestPineconeEmbeddingsConfig:
    """Additional configuration tests for PineconeEmbeddings."""

    def test_default_config(self) -> None:
        """Check the defaults applied when only model and API key are given."""
        embeddings = PineconeEmbeddings(model=MODEL_NAME, pinecone_api_key=API_KEY)  # type: ignore
        assert embeddings.batch_size == 96
        assert embeddings.query_params == {"input_type": "query", "truncation": "END"}
        assert embeddings.document_params == {
            "input_type": "passage",
            "truncation": "END",
        }
        assert embeddings.dimension == 1024

    def test_custom_config(self) -> None:
        """Check that explicitly passed settings replace the defaults."""
        embeddings = PineconeEmbeddings(
            model=MODEL_NAME,
            api_key=API_KEY,
            batch_size=128,
            query_params={"custom": "param"},
            document_params={"other": "param"},
        )
        assert embeddings.batch_size == 128
        assert embeddings.query_params == {"custom": "param"}
        assert embeddings.document_params == {"other": "param"}

    @pytest.mark.asyncio
    async def test_async_client_initialization(self) -> None:
        """Check the async client is created lazily with the expected headers.

        Before ``get_async_client()`` is awaited, ``_async_client`` must be
        ``None``. Afterwards the returned object must be an
        ``aiohttp.ClientSession`` whose default headers carry the API key,
        the JSON content type and the pinned Pinecone API version.
        """
        embeddings = PineconeEmbeddings(model=MODEL_NAME, api_key=API_KEY)
        assert embeddings._async_client is None

        # Obtaining the client is what triggers its creation.
        client = await embeddings.get_async_client()
        try:
            assert client is not None
            assert isinstance(client, aiohttp.ClientSession)

            # Ensure headers are set correctly
            expected_headers = {
                "Api-Key": API_KEY.get_secret_value(),
                "Content-Type": "application/json",
                "X-Pinecone-API-Version": "2024-10",
            }
            assert client._default_headers == expected_headers
        finally:
            # Nothing else in this test closes the session, so close it here
            # to avoid leaking it past the test. The isinstance guard keeps a
            # failed type assertion above from being masked by an
            # AttributeError raised during cleanup.
            if isinstance(client, aiohttp.ClientSession):
                await client.close()