diff --git a/docs/modules/indexes/retrievers/examples/pinecone_hybrid_search.ipynb b/docs/modules/indexes/retrievers/examples/pinecone_hybrid_search.ipynb index 1e07b4a894a..9d3fa491c41 100644 --- a/docs/modules/indexes/retrievers/examples/pinecone_hybrid_search.ipynb +++ b/docs/modules/indexes/retrievers/examples/pinecone_hybrid_search.ipynb @@ -24,7 +24,7 @@ "metadata": {}, "outputs": [], "source": [ - "#!pip install pinecone-client" + "#!pip install pinecone-client pinecone-text" ] }, { diff --git a/langchain/retrievers/pinecone_hybrid_search.py b/langchain/retrievers/pinecone_hybrid_search.py index c4ad39e5f2b..bd04a296a7a 100644 --- a/langchain/retrievers/pinecone_hybrid_search.py +++ b/langchain/retrievers/pinecone_hybrid_search.py @@ -18,6 +18,7 @@ def create_index( embeddings: Embeddings, sparse_encoder: Any, ids: Optional[List[str]] = None, + metadatas: Optional[List[dict]] = None, ) -> None: batch_size = 32 _iterator = range(0, len(contexts), batch_size) @@ -38,8 +39,15 @@ def create_index( # extract batch context_batch = contexts[i:i_end] batch_ids = ids[i:i_end] + metadata_batch = ( + metadatas[i:i_end] if metadatas else [{} for _ in context_batch] + ) # add context passages as metadata - meta = [{"context": context} for context in context_batch] + meta = [ + {"context": context, **metadata} + for context, metadata in zip(context_batch, metadata_batch) + ] + # create dense vectors dense_embeds = embeddings.embed_documents(context_batch) # create sparse vectors @@ -78,8 +86,20 @@ class PineconeHybridSearchRetriever(BaseRetriever, BaseModel): extra = Extra.forbid arbitrary_types_allowed = True - def add_texts(self, texts: List[str], ids: Optional[List[str]] = None) -> None: - create_index(texts, self.index, self.embeddings, self.sparse_encoder, ids=ids) + def add_texts( + self, + texts: List[str], + ids: Optional[List[str]] = None, + metadatas: Optional[List[dict]] = None, + ) -> None: + create_index( + texts, + self.index, + self.embeddings, + self.sparse_encoder, + ids=ids, + metadatas=metadatas, + ) @root_validator() def validate_environment(cls, values: Dict) -> Dict: @@ -114,7 +134,10 @@ class PineconeHybridSearchRetriever(BaseRetriever, BaseModel): ) final_result = [] for res in result["matches"]: - final_result.append(Document(page_content=res["metadata"]["context"])) + context = res["metadata"].pop("context") + final_result.append( + Document(page_content=context, metadata=res["metadata"]) + ) # return search results as json return final_result