langchain/libs/partners/pinecone/langchain_pinecone/embeddings.py
Cory Waddingham ce9e9f9314
pinecone: Review pinecone tests (#29073)
Title: langchain-pinecone: improve test structure and async handling

Description: This PR improves the test infrastructure for the
langchain-pinecone package by:
1. Implementing LangChain's standard test patterns for embeddings
2. Adding comprehensive configuration testing
3. Improving async test coverage
4. Fixing integration test issues with namespaces and async markers

The changes make the tests more robust, maintainable, and aligned with
LangChain's testing standards while ensuring proper async behavior in
the embeddings implementation.

Key improvements:
- Added standard EmbeddingsTests implementation (an illustrative sketch of this pattern follows this list)
- Split custom configuration tests into a separate test class
- Added proper async test coverage with pytest-asyncio
- Fixed namespace handling in vector store integration tests
- Improved test organization and documentation
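
As an illustrative sketch of the standard pattern referenced above (the base class and required property names are assumed from the separate `langchain-tests` package, not taken from this PR):

```python
# Hypothetical sketch of a standard embeddings test class; the base class and
# required properties are assumptions based on the langchain-tests package.
from typing import Type

from langchain_tests.unit_tests import EmbeddingsUnitTests

from langchain_pinecone import PineconeEmbeddings


class TestPineconeEmbeddingsStandard(EmbeddingsUnitTests):
    @property
    def embeddings_class(self) -> Type[PineconeEmbeddings]:
        # The embeddings implementation under test.
        return PineconeEmbeddings

    @property
    def embedding_model_params(self) -> dict:
        # Constructor kwargs the shared suite uses to build the model.
        return {"model": "multilingual-e5-large", "api_key": "test-key"}
```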

Dependencies: None (uses existing test dependencies)

Tests and Documentation:
- Added standard test implementation following LangChain's patterns
- Added comprehensive unit tests for configuration and async behavior (an illustrative async test sketch follows this list)
- All tests passing locally
- No documentation changes needed (internal test improvements only)
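
As an illustrative sketch of the async unit-test style referenced above (the test name, the monkeypatched helper, and the fake response shape are assumptions, not the PR's actual tests):

```python
# Illustrative async test; assumes pytest-asyncio and a dummy API key.
import pytest

from langchain_pinecone import PineconeEmbeddings


@pytest.mark.asyncio
async def test_aembed_query_returns_vector(monkeypatch) -> None:
    embeddings = PineconeEmbeddings(
        model="multilingual-e5-large", api_key="test-key"
    )

    async def fake_aembed_texts(texts, model, parameters):
        # Mimic the {"data": [{"values": [...]}]} shape returned by the REST call.
        return {"data": [{"values": [0.1, 0.2, 0.3]} for _ in texts]}

    # Patch the private REST helper so no network request is made.
    monkeypatch.setattr(embeddings, "_aembed_texts", fake_aembed_texts)

    vector = await embeddings.aembed_query("hello")
    assert vector == [0.1, 0.2, 0.3]
```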

Twitter handle: N/A

---------

Co-authored-by: Erick Friis <erick@langchain.dev>
2025-01-07 11:46:30 -08:00


import logging
from typing import Any, Dict, Iterable, List, Optional

import aiohttp
from langchain_core.embeddings import Embeddings
from langchain_core.utils import secret_from_env
from pinecone import Pinecone as PineconeClient  # type: ignore[import-untyped]
from pydantic import (
    BaseModel,
    ConfigDict,
    Field,
    PrivateAttr,
    SecretStr,
    model_validator,
)
from typing_extensions import Self

logger = logging.getLogger(__name__)

DEFAULT_BATCH_SIZE = 64

class PineconeEmbeddings(BaseModel, Embeddings):
    """PineconeEmbeddings embedding model.

    Example:
        .. code-block:: python

            from langchain_pinecone import PineconeEmbeddings

            model = PineconeEmbeddings(model="multilingual-e5-large")
    """

    # Clients
    _client: PineconeClient = PrivateAttr(default=None)
    _async_client: Optional[aiohttp.ClientSession] = PrivateAttr(default=None)

    model: str
    """Model to use, for example 'multilingual-e5-large'."""

    # Config
    batch_size: Optional[int] = None
    """Batch size for embedding documents."""
    query_params: Dict = Field(default_factory=dict)
    """Parameters for embedding query."""
    document_params: Dict = Field(default_factory=dict)
    """Parameters for embedding document."""

    dimension: Optional[int] = None
    show_progress_bar: bool = False
    pinecone_api_key: SecretStr = Field(
        default_factory=secret_from_env(
            "PINECONE_API_KEY",
            error_message="Pinecone API key not found. Please set the PINECONE_API_KEY "
            "environment variable or pass it via `pinecone_api_key`.",
        ),
        alias="api_key",
    )
    """Pinecone API key.

    If not provided, will look for the PINECONE_API_KEY environment variable."""

    model_config = ConfigDict(
        extra="forbid",
        populate_by_name=True,
        protected_namespaces=(),
    )
    @property
    def async_client(self) -> aiohttp.ClientSession:
        """Lazily initialize the async client."""
        if self._async_client is None:
            self._async_client = aiohttp.ClientSession(
                headers={
                    "Api-Key": self.pinecone_api_key.get_secret_value(),
                    "Content-Type": "application/json",
                    "X-Pinecone-API-Version": "2024-10",
                }
            )
        return self._async_client

    @model_validator(mode="before")
    @classmethod
    def set_default_config(cls, values: dict) -> Any:
        """Set default configuration based on model."""
        default_config_map = {
            "multilingual-e5-large": {
                "batch_size": 96,
                "query_params": {"input_type": "query", "truncation": "END"},
                "document_params": {"input_type": "passage", "truncation": "END"},
                "dimension": 1024,
            }
        }
        model = values.get("model")
        if model in default_config_map:
            config = default_config_map[model]
            for key, value in config.items():
                if key not in values:
                    values[key] = value
        return values

    @model_validator(mode="after")
    def validate_environment(self) -> Self:
        """Validate that Pinecone version and credentials exist in environment."""
        api_key_str = self.pinecone_api_key.get_secret_value()
        client = PineconeClient(api_key=api_key_str, source_tag="langchain")
        self._client = client
        # The aiohttp-based async client is created lazily via the
        # `async_client` property the first time it is needed.
        return self
    def _get_batch_iterator(self, texts: List[str]) -> Iterable:
        """Return an iterable of batch start offsets, optionally wrapped in tqdm."""
        if self.batch_size is None:
            batch_size = DEFAULT_BATCH_SIZE
        else:
            batch_size = self.batch_size

        if self.show_progress_bar:
            try:
                from tqdm.auto import tqdm  # type: ignore
            except ImportError as e:
                raise ImportError(
                    "Must have tqdm installed if `show_progress_bar` is set to True. "
                    "Please install with `pip install tqdm`."
                ) from e

            _iter = tqdm(range(0, len(texts), batch_size))
        else:
            _iter = range(0, len(texts), batch_size)

        return _iter
    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed search docs."""
        embeddings: List[List[float]] = []
        # Fall back to the module default so the slice below never sees a
        # None batch size when no explicit batch_size was configured.
        batch_size = self.batch_size or DEFAULT_BATCH_SIZE

        _iter = self._get_batch_iterator(texts)
        for i in _iter:
            response = self._client.inference.embed(
                model=self.model,
                parameters=self.document_params,
                inputs=texts[i : i + batch_size],
            )
            embeddings.extend([r["values"] for r in response])

        return embeddings

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        """Asynchronously embed search docs."""
        embeddings: List[List[float]] = []
        batch_size = self.batch_size or DEFAULT_BATCH_SIZE

        _iter = self._get_batch_iterator(texts)
        for i in _iter:
            response = await self._aembed_texts(
                model=self.model,
                parameters=self.document_params,
                texts=texts[i : i + batch_size],
            )
            embeddings.extend([r["values"] for r in response["data"]])

        return embeddings
    def embed_query(self, text: str) -> List[float]:
        """Embed query text."""
        return self._client.inference.embed(
            model=self.model, parameters=self.query_params, inputs=[text]
        )[0]["values"]

    async def aembed_query(self, text: str) -> List[float]:
        """Asynchronously embed query text."""
        # Use the query parameters so the async path matches embed_query above.
        response = await self._aembed_texts(
            model=self.model,
            parameters=self.query_params,
            texts=[text],
        )
        return response["data"][0]["values"]
    async def _aembed_texts(
        self, texts: List[str], model: str, parameters: dict
    ) -> Dict:
        """Embed texts via Pinecone's inference REST API."""
        data = {
            "model": model,
            "inputs": [{"text": text} for text in texts],
            "parameters": parameters,
        }
        async with self.async_client.post(
            "https://api.pinecone.io/embed", json=data
        ) as response:
            response_data = await response.json(content_type=None)
            return response_data