From 16bd328aab52e25f4a22283916d3a16ec8bccd4b Mon Sep 17 00:00:00 2001 From: Bagatur <22008038+baskaryan@users.noreply.github.com> Date: Thu, 10 Aug 2023 14:22:41 -0700 Subject: [PATCH] Use Embeddings in pinecone (#8982) cc @eyurtsev @olivier-lacroix @jamescalam redo of #2741 --- .../langchain/vectorstores/pinecone.py | 44 +++++++++++++------ 1 file changed, 31 insertions(+), 13 deletions(-) diff --git a/libs/langchain/langchain/vectorstores/pinecone.py b/libs/langchain/langchain/vectorstores/pinecone.py index f9cc2303a60..df5b8684445 100644 --- a/libs/langchain/langchain/vectorstores/pinecone.py +++ b/libs/langchain/langchain/vectorstores/pinecone.py @@ -3,7 +3,8 @@ from __future__ import annotations import logging import uuid -from typing import Any, Callable, Iterable, List, Optional, Tuple +import warnings +from typing import Any, Callable, Iterable, List, Optional, Tuple, Union import numpy as np @@ -38,7 +39,7 @@ class Pinecone(VectorStore): def __init__( self, index: Any, - embedding_function: Callable, + embedding: Union[Embeddings, Callable], text_key: str, namespace: Optional[str] = None, distance_strategy: Optional[DistanceStrategy] = DistanceStrategy.COSINE, @@ -47,7 +48,7 @@ class Pinecone(VectorStore): try: import pinecone except ImportError: - raise ValueError( + raise ImportError( "Could not import pinecone python package. " "Please install it with `pip install pinecone-client`." ) @@ -56,17 +57,36 @@ class Pinecone(VectorStore): f"client should be an instance of pinecone.index.Index, " f"got {type(index)}" ) + if not isinstance(embedding, Embeddings): + warnings.warn( + "Passing in `embedding` as a Callable is deprecated. Please pass in an" + " Embeddings object instead." + ) self._index = index - self._embedding_function = embedding_function + self._embedding = embedding self._text_key = text_key self._namespace = namespace self.distance_strategy = distance_strategy @property def embeddings(self) -> Optional[Embeddings]: - # TODO: Accept this object directly + """Access the query embedding object if available.""" + if isinstance(self._embedding, Embeddings): + return self._embedding return None + def _embed_documents(self, texts: Iterable[str]) -> List[List[float]]: + """Embed search docs.""" + if isinstance(self._embedding, Embeddings): + return self._embedding.embed_documents(list(texts)) + return [self._embedding(t) for t in texts] + + def _embed_query(self, text: str) -> List[float]: + """Embed query text.""" + if isinstance(self._embedding, Embeddings): + return self._embedding.embed_query(text) + return self._embedding(text) + def add_texts( self, texts: Iterable[str], @@ -93,8 +113,8 @@ class Pinecone(VectorStore): # Embed and create the documents docs = [] ids = ids or [str(uuid.uuid4()) for _ in texts] - for i, text in enumerate(texts): - embedding = self._embedding_function(text) + embeddings = self._embed_documents(texts) + for i, (text, embedding) in enumerate(zip(texts, embeddings)): metadata = metadatas[i] if metadatas else {} metadata[self._text_key] = text docs.append((ids[i], embedding, metadata)) @@ -124,7 +144,7 @@ class Pinecone(VectorStore): """ if namespace is None: namespace = self._namespace - query_obj = self._embedding_function(query) + query_obj = self._embed_query(query) docs = [] results = self._index.query( [query_obj], @@ -265,7 +285,7 @@ class Pinecone(VectorStore): Returns: List of Documents selected by maximal marginal relevance. """ - embedding = self._embedding_function(query) + embedding = self._embed_query(query) return self.max_marginal_relevance_search_by_vector( embedding, k, fetch_k, lambda_mult, filter, namespace ) @@ -356,7 +376,7 @@ class Pinecone(VectorStore): # upsert to Pinecone _upsert_kwargs = upsert_kwargs or {} index.upsert(vectors=list(to_upsert), namespace=namespace, **_upsert_kwargs) - return cls(index, embedding.embed_query, text_key, namespace, **kwargs) + return cls(index, embedding, text_key, namespace, **kwargs) @classmethod def from_existing_index( @@ -375,9 +395,7 @@ class Pinecone(VectorStore): "Please install it with `pip install pinecone-client`." ) - return cls( - pinecone.Index(index_name), embedding.embed_query, text_key, namespace - ) + return cls(pinecone.Index(index_name), embedding, text_key, namespace) def delete( self,