mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-09 04:50:37 +00:00
Backwards compatible change that converts pydantic extras to literals, which is consistent with pydantic 2 usage. - fireworks - voyage ai - mistralai - mistral ai - together ai - hugging face - pinecone
181 lines
6.0 KiB
Python
181 lines
6.0 KiB
Python
import logging
|
|
import os
|
|
from typing import Dict, Iterable, List, Optional
|
|
|
|
import aiohttp
|
|
from langchain_core.embeddings import Embeddings
|
|
from langchain_core.pydantic_v1 import (
|
|
BaseModel,
|
|
Field,
|
|
SecretStr,
|
|
root_validator,
|
|
)
|
|
from langchain_core.utils import convert_to_secret_str
|
|
from pinecone import Pinecone as PineconeClient # type: ignore
|
|
|
|
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
|
|
|
|
# Fallback number of texts per embedding request; `_get_batch_iterator` uses
# it when the model's `batch_size` is None.
DEFAULT_BATCH_SIZE = 64
|
|
|
|
|
|
class PineconeEmbeddings(BaseModel, Embeddings):
    """PineconeEmbeddings embedding model.

    Texts are embedded through the Pinecone inference API: synchronously via
    the Pinecone client, asynchronously via an ``aiohttp`` session that posts
    to the ``/embed`` endpoint. Both clients are created by
    ``validate_environment`` from the resolved API key.

    Example:
        .. code-block:: python

            from langchain_pinecone import PineconeEmbeddings

            model = PineconeEmbeddings(model="multilingual-e5-large")
    """

    # Clients (populated by `validate_environment`).
    _client: PineconeClient = Field(default=None, exclude=True)
    _async_client: aiohttp.ClientSession = Field(default=None, exclude=True)
    model: str
    """Model to use, for example 'multilingual-e5-large'."""
    # Config
    batch_size: Optional[int] = None
    """Batch size for embedding documents.

    When left as ``None`` (and the model has no built-in default),
    ``DEFAULT_BATCH_SIZE`` is used.
    """
    query_params: Dict = Field(default_factory=dict)
    """Parameters for embedding query."""
    document_params: Dict = Field(default_factory=dict)
    """Parameters for embedding document."""
    # Embedding dimension recorded for the model (1024 in the built-in
    # default for 'multilingual-e5-large'); not read by the methods below.
    dimension: Optional[int] = None
    # When True, document batches are iterated through a tqdm progress bar.
    show_progress_bar: bool = False
    # API key; falls back to the PINECONE_API_KEY environment variable.
    pinecone_api_key: Optional[SecretStr] = None

    class Config:
        extra = "forbid"

    @root_validator(pre=True)
    def set_default_config(cls, values: dict) -> dict:
        """Set default configuration based on model.

        For a model with a known configuration, fill in every setting the
        caller did not pass explicitly. Explicitly passed values always win.

        Args:
            values: Raw keyword arguments supplied to the constructor.

        Returns:
            The same ``values`` dict, with missing defaults added.
        """
        # Built fresh on every call so the nested dicts are never shared
        # between instances.
        default_config_map = {
            "multilingual-e5-large": {
                "batch_size": 96,
                "query_params": {"input_type": "query", "truncation": "END"},
                "document_params": {"input_type": "passage", "truncation": "END"},
                "dimension": 1024,
            }
        }
        model_defaults = default_config_map.get(values.get("model"), {})
        for key, value in model_defaults.items():
            values.setdefault(key, value)
        return values

    @root_validator()
    def validate_environment(cls, values: dict) -> dict:
        """Validate that Pinecone version and credentials exist in environment.

        Resolves the API key from ``pinecone_api_key`` or, failing that, the
        ``PINECONE_API_KEY`` environment variable, then builds the sync
        Pinecone client and (unless one is already present) the async HTTP
        session.

        Args:
            values: Field values of the model being validated.

        Returns:
            ``values`` with ``pinecone_api_key``, ``_client`` and
            ``_async_client`` populated.

        Raises:
            ValueError: If no API key is found in either place.
        """
        pinecone_api_key = values.get("pinecone_api_key") or os.getenv(
            "PINECONE_API_KEY", None
        )
        if not pinecone_api_key:
            raise ValueError(
                "Pinecone API key not found. Please set the PINECONE_API_KEY "
                "environment variable or pass it via `pinecone_api_key`."
            )

        api_key_secretstr = convert_to_secret_str(pinecone_api_key)
        values["pinecone_api_key"] = api_key_secretstr
        api_key_str = api_key_secretstr.get_secret_value()

        values["_client"] = PineconeClient(api_key=api_key_str, source_tag="langchain")

        # initialize async client
        if not values.get("_async_client"):
            values["_async_client"] = aiohttp.ClientSession(
                headers={
                    "Api-Key": api_key_str,
                    "Content-Type": "application/json",
                    "X-Pinecone-API-Version": "2024-07",
                }
            )
        return values

    def _resolve_batch_size(self) -> int:
        """Return the configured batch size, or ``DEFAULT_BATCH_SIZE`` if unset."""
        if self.batch_size is None:
            return DEFAULT_BATCH_SIZE
        return self.batch_size

    def _get_batch_iterator(self, texts: List[str]) -> Iterable:
        """Return an iterable of batch start offsets into ``texts``.

        Args:
            texts: The texts that will be embedded in batches.

        Returns:
            A ``range`` of start indices stepping by the effective batch size,
            wrapped in ``tqdm`` when ``show_progress_bar`` is set.

        Raises:
            ImportError: If ``show_progress_bar`` is True but tqdm is missing.
        """
        batch_starts = range(0, len(texts), self._resolve_batch_size())
        if not self.show_progress_bar:
            return batch_starts

        try:
            from tqdm.auto import tqdm  # type: ignore
        except ImportError as e:
            raise ImportError(
                "Must have tqdm installed if `show_progress_bar` is set to True. "
                "Please install with `pip install tqdm`."
            ) from e

        return tqdm(batch_starts)

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed search docs.

        Args:
            texts: Documents to embed.

        Returns:
            One embedding vector per input text, in input order.
        """
        embeddings: List[List[float]] = []

        # Must be the same step the iterator uses. Slicing with the raw
        # `self.batch_size` raised TypeError whenever it was None.
        batch_size = self._resolve_batch_size()
        for start in self._get_batch_iterator(texts):
            response = self._client.inference.embed(
                model=self.model,
                parameters=self.document_params,
                inputs=texts[start : start + batch_size],
            )
            embeddings.extend(r["values"] for r in response)

        return embeddings

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        """Asynchronously embed search docs.

        Args:
            texts: Documents to embed.

        Returns:
            One embedding vector per input text, in input order.
        """
        embeddings: List[List[float]] = []
        # Same effective batch size as the iterator step (see embed_documents).
        batch_size = self._resolve_batch_size()
        for start in self._get_batch_iterator(texts):
            response = await self._aembed_texts(
                model=self.model,
                parameters=self.document_params,
                texts=texts[start : start + batch_size],
            )
            embeddings.extend(r["values"] for r in response["data"])
        return embeddings

    def embed_query(self, text: str) -> List[float]:
        """Embed query text.

        Args:
            text: The query to embed.

        Returns:
            The embedding vector for ``text``.
        """
        response = self._client.inference.embed(
            model=self.model, parameters=self.query_params, inputs=[text]
        )
        return response[0]["values"]

    async def aembed_query(self, text: str) -> List[float]:
        """Asynchronously embed query text.

        Args:
            text: The query to embed.

        Returns:
            The embedding vector for ``text``.
        """
        response = await self._aembed_texts(
            model=self.model,
            # Queries must use the query parameters, matching `embed_query`;
            # this previously sent `document_params` by mistake.
            parameters=self.query_params,
            texts=[text],
        )
        return response["data"][0]["values"]

    async def _aembed_texts(
        self, texts: List[str], model: str, parameters: dict
    ) -> Dict:
        """POST ``texts`` to the Pinecone embed endpoint and return the JSON body.

        Args:
            texts: Texts to embed; each is sent as ``{"text": ...}``.
            model: Name of the embedding model.
            parameters: Model parameters forwarded in the request body.

        Returns:
            The decoded JSON response.
        """
        data = {
            "model": model,
            "inputs": [{"text": text} for text in texts],
            "parameters": parameters,
        }
        # NOTE(review): the HTTP status is not checked here, so an error
        # payload is handed back to the caller unchanged.
        async with self._async_client.post(
            "https://api.pinecone.io/embed", json=data
        ) as response:
            response_data = await response.json(content_type=None)
            return response_data
|