pinecone: Add embedding Inference Support (#24515)
**Description:** Add support for Pinecone-hosted embedding models as `PineconeEmbeddings`. Replacement for #22890.

**Dependencies:** Add `aiohttp` to support async embedding calls made directly against the REST API.

- [x] **Add tests and docs**: Added `docs/docs/integrations/text_embedding/pinecone.ipynb`.
- [x] **Lint and test**: Ran `make format`, `make lint` and `make test` from the root of the modified package(s). See contribution guidelines for more: https://python.langchain.com/docs/contributing/

Twitter: `gdjdg17`

---------

Co-authored-by: Erick Friis <erick@langchain.dev>
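A minimal usage sketch of the new class, assuming `PINECONE_API_KEY` is set in the environment (it mirrors the docstring example in the file below):

```python
import asyncio

from langchain_pinecone import PineconeEmbeddings

# For multilingual-e5-large, the pre root_validator fills in defaults:
# batch_size=96, dimension=1024, and query/document input_type parameters.
embeddings = PineconeEmbeddings(model="multilingual-e5-large")

# Sync calls go through the Pinecone SDK's inference.embed endpoint.
doc_vectors = embeddings.embed_documents(["hello world", "goodbye world"])
query_vector = embeddings.embed_query("hello")

# Async calls POST to the Pinecone REST API directly via aiohttp.
async def main() -> None:
    vectors = await embeddings.aembed_documents(["hello world"])
    print(len(vectors[0]))  # 1024 for multilingual-e5-large

asyncio.run(main())
```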
libs/partners/pinecone/langchain_pinecone/embeddings.py (new file, 181 lines)
@@ -0,0 +1,181 @@
import logging
import os
from typing import Dict, Iterable, List, Optional

import aiohttp
from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import (
    BaseModel,
    Extra,
    Field,
    SecretStr,
    root_validator,
)
from langchain_core.utils import convert_to_secret_str
from pinecone import Pinecone as PineconeClient  # type: ignore

logger = logging.getLogger(__name__)

DEFAULT_BATCH_SIZE = 64


class PineconeEmbeddings(BaseModel, Embeddings):
    """PineconeEmbeddings embedding model.

    Example:
        .. code-block:: python

            from langchain_pinecone import PineconeEmbeddings

            model = PineconeEmbeddings(model="multilingual-e5-large")
    """

    # Clients
    _client: PineconeClient = Field(default=None, exclude=True)
    _async_client: aiohttp.ClientSession = Field(default=None, exclude=True)
    model: str
    """Model to use, for example 'multilingual-e5-large'."""
    # Config
    batch_size: Optional[int] = None
    """Batch size for embedding documents."""
    query_params: Dict = Field(default_factory=dict)
    """Parameters for embedding query."""
    document_params: Dict = Field(default_factory=dict)
    """Parameters for embedding document."""
    dimension: Optional[int] = None
    """Dimension of the embedding vectors produced by the model."""
    show_progress_bar: bool = False
    """Whether to show a tqdm progress bar when embedding documents."""
    pinecone_api_key: Optional[SecretStr] = None

    class Config:
        extra = Extra.forbid

    @root_validator(pre=True)
    def set_default_config(cls, values: dict) -> dict:
        """Set default configuration based on model."""
        default_config_map = {
            "multilingual-e5-large": {
                "batch_size": 96,
                "query_params": {"input_type": "query", "truncation": "END"},
                "document_params": {"input_type": "passage", "truncation": "END"},
                "dimension": 1024,
            }
        }
        model = values.get("model")
        if model in default_config_map:
            config = default_config_map[model]
            for key, value in config.items():
                if key not in values:
                    values[key] = value
        return values

    @root_validator()
    def validate_environment(cls, values: dict) -> dict:
        """Validate that Pinecone credentials exist in the environment."""
        pinecone_api_key = values.get("pinecone_api_key") or os.getenv(
            "PINECONE_API_KEY", None
        )
        if pinecone_api_key:
            api_key_secretstr = convert_to_secret_str(pinecone_api_key)
            values["pinecone_api_key"] = api_key_secretstr

            api_key_str = api_key_secretstr.get_secret_value()
        else:
            api_key_str = None
        if api_key_str is None:
            raise ValueError(
                "Pinecone API key not found. Please set the PINECONE_API_KEY "
                "environment variable or pass it via `pinecone_api_key`."
            )
        client = PineconeClient(api_key=api_key_str, source_tag="langchain")
        values["_client"] = client

        # Initialize the async aiohttp session used for direct REST calls.
        if not values.get("_async_client"):
            values["_async_client"] = aiohttp.ClientSession(
                headers={
                    "Api-Key": api_key_str,
                    "Content-Type": "application/json",
                    "X-Pinecone-API-Version": "2024-07",
                }
            )
        return values

    def _get_batch_iterator(self, texts: List[str]) -> Iterable:
        if self.batch_size is None:
            batch_size = DEFAULT_BATCH_SIZE
        else:
            batch_size = self.batch_size

        if self.show_progress_bar:
            try:
                from tqdm.auto import tqdm  # type: ignore
            except ImportError as e:
                raise ImportError(
                    "Must have tqdm installed if `show_progress_bar` is set to True. "
                    "Please install with `pip install tqdm`."
                ) from e

            _iter = tqdm(range(0, len(texts), batch_size))
        else:
            _iter = range(0, len(texts), batch_size)

        return _iter

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed search docs."""
        embeddings: List[List[float]] = []

        # Resolve the batch size the same way _get_batch_iterator does;
        # slicing with a bare self.batch_size would fail when it is None.
        batch_size = self.batch_size or DEFAULT_BATCH_SIZE
        _iter = self._get_batch_iterator(texts)
        for i in _iter:
            response = self._client.inference.embed(
                model=self.model,
                parameters=self.document_params,
                inputs=texts[i : i + batch_size],
            )
            embeddings.extend([r["values"] for r in response])

        return embeddings

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        """Asynchronously embed search docs."""
        embeddings: List[List[float]] = []
        batch_size = self.batch_size or DEFAULT_BATCH_SIZE
        _iter = self._get_batch_iterator(texts)
        for i in _iter:
            response = await self._aembed_texts(
                model=self.model,
                parameters=self.document_params,
                texts=texts[i : i + batch_size],
            )
            embeddings.extend([r["values"] for r in response["data"]])
        return embeddings

    def embed_query(self, text: str) -> List[float]:
        """Embed query text."""
        return self._client.inference.embed(
            model=self.model, parameters=self.query_params, inputs=[text]
        )[0]["values"]

    async def aembed_query(self, text: str) -> List[float]:
        """Asynchronously embed query text."""
        response = await self._aembed_texts(
            model=self.model,
            # Match the sync embed_query path: queries should use
            # query_params rather than document_params.
            parameters=self.query_params,
            texts=[text],
        )
        return response["data"][0]["values"]

    async def _aembed_texts(
        self, texts: List[str], model: str, parameters: dict
    ) -> Dict:
        """Embed texts via a direct POST to the Pinecone inference REST API."""
        data = {
            "model": model,
            "inputs": [{"text": text} for text in texts],
            "parameters": parameters,
        }
        async with self._async_client.post(
            "https://api.pinecone.io/embed", json=data
        ) as response:
            response_data = await response.json(content_type=None)
            return response_data
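For context, here is a sketch of the raw exchange `_aembed_texts` performs, inferred purely from the code above (the request body it constructs and the response shape its callers index into); this is not an official API reference:

```python
# Body POSTed to https://api.pinecone.io/embed by _aembed_texts:
request_body = {
    "model": "multilingual-e5-large",
    "inputs": [{"text": "hello world"}],
    "parameters": {"input_type": "passage", "truncation": "END"},
}

# Response shape the callers rely on: a top-level "data" list whose
# entries expose each embedding under "values".
example_response = {
    "data": [
        {"values": [0.012, -0.034]},  # truncated; 1024 floats for this model
    ],
}
```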