Use Embeddings in pinecone (#8982)

cc @eyurtsev @olivier-lacroix @jamescalam 

redo of #2741
This commit is contained in:
Bagatur 2023-08-10 14:22:41 -07:00 committed by GitHub
parent 8eea46ed0e
commit 16bd328aab
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -3,7 +3,8 @@ from __future__ import annotations
import logging import logging
import uuid import uuid
from typing import Any, Callable, Iterable, List, Optional, Tuple import warnings
from typing import Any, Callable, Iterable, List, Optional, Tuple, Union
import numpy as np import numpy as np
@ -38,7 +39,7 @@ class Pinecone(VectorStore):
def __init__( def __init__(
self, self,
index: Any, index: Any,
embedding_function: Callable, embedding: Union[Embeddings, Callable],
text_key: str, text_key: str,
namespace: Optional[str] = None, namespace: Optional[str] = None,
distance_strategy: Optional[DistanceStrategy] = DistanceStrategy.COSINE, distance_strategy: Optional[DistanceStrategy] = DistanceStrategy.COSINE,
@ -47,7 +48,7 @@ class Pinecone(VectorStore):
try: try:
import pinecone import pinecone
except ImportError: except ImportError:
raise ValueError( raise ImportError(
"Could not import pinecone python package. " "Could not import pinecone python package. "
"Please install it with `pip install pinecone-client`." "Please install it with `pip install pinecone-client`."
) )
@ -56,17 +57,36 @@ class Pinecone(VectorStore):
f"client should be an instance of pinecone.index.Index, " f"client should be an instance of pinecone.index.Index, "
f"got {type(index)}" f"got {type(index)}"
) )
if not isinstance(embedding, Embeddings):
warnings.warn(
"Passing in `embedding` as a Callable is deprecated. Please pass in an"
" Embeddings object instead."
)
self._index = index self._index = index
self._embedding_function = embedding_function self._embedding = embedding
self._text_key = text_key self._text_key = text_key
self._namespace = namespace self._namespace = namespace
self.distance_strategy = distance_strategy self.distance_strategy = distance_strategy
@property @property
def embeddings(self) -> Optional[Embeddings]: def embeddings(self) -> Optional[Embeddings]:
# TODO: Accept this object directly """Access the query embedding object if available."""
if isinstance(self._embedding, Embeddings):
return self._embedding
return None return None
def _embed_documents(self, texts: Iterable[str]) -> List[List[float]]:
    """Compute one embedding vector per input text.

    Args:
        texts: The texts to embed.

    Returns:
        A list of embedding vectors, one per text, in input order.
    """
    embedder = self._embedding
    if not isinstance(embedder, Embeddings):
        # Deprecated path: a plain callable that embeds a single text,
        # so it is invoked once per text.
        return [embedder(text) for text in texts]
    # Embeddings objects take the whole batch; materialise the iterable
    # into a list before handing it over.
    return embedder.embed_documents(list(texts))
def _embed_query(self, text: str) -> List[float]:
    """Compute the embedding vector for a single query string.

    Args:
        text: The query text to embed.

    Returns:
        The embedding produced for ``text``.
    """
    embedder = self._embedding
    if not isinstance(embedder, Embeddings):
        # Deprecated path: the stored object is a plain callable.
        return embedder(text)
    return embedder.embed_query(text)
def add_texts( def add_texts(
self, self,
texts: Iterable[str], texts: Iterable[str],
@ -93,8 +113,8 @@ class Pinecone(VectorStore):
# Embed and create the documents # Embed and create the documents
docs = [] docs = []
ids = ids or [str(uuid.uuid4()) for _ in texts] ids = ids or [str(uuid.uuid4()) for _ in texts]
for i, text in enumerate(texts): embeddings = self._embed_documents(texts)
embedding = self._embedding_function(text) for i, (text, embedding) in enumerate(zip(texts, embeddings)):
metadata = metadatas[i] if metadatas else {} metadata = metadatas[i] if metadatas else {}
metadata[self._text_key] = text metadata[self._text_key] = text
docs.append((ids[i], embedding, metadata)) docs.append((ids[i], embedding, metadata))
@ -124,7 +144,7 @@ class Pinecone(VectorStore):
""" """
if namespace is None: if namespace is None:
namespace = self._namespace namespace = self._namespace
query_obj = self._embedding_function(query) query_obj = self._embed_query(query)
docs = [] docs = []
results = self._index.query( results = self._index.query(
[query_obj], [query_obj],
@ -265,7 +285,7 @@ class Pinecone(VectorStore):
Returns: Returns:
List of Documents selected by maximal marginal relevance. List of Documents selected by maximal marginal relevance.
""" """
embedding = self._embedding_function(query) embedding = self._embed_query(query)
return self.max_marginal_relevance_search_by_vector( return self.max_marginal_relevance_search_by_vector(
embedding, k, fetch_k, lambda_mult, filter, namespace embedding, k, fetch_k, lambda_mult, filter, namespace
) )
@ -356,7 +376,7 @@ class Pinecone(VectorStore):
# upsert to Pinecone # upsert to Pinecone
_upsert_kwargs = upsert_kwargs or {} _upsert_kwargs = upsert_kwargs or {}
index.upsert(vectors=list(to_upsert), namespace=namespace, **_upsert_kwargs) index.upsert(vectors=list(to_upsert), namespace=namespace, **_upsert_kwargs)
return cls(index, embedding.embed_query, text_key, namespace, **kwargs) return cls(index, embedding, text_key, namespace, **kwargs)
@classmethod @classmethod
def from_existing_index( def from_existing_index(
@ -375,9 +395,7 @@ class Pinecone(VectorStore):
"Please install it with `pip install pinecone-client`." "Please install it with `pip install pinecone-client`."
) )
return cls( return cls(pinecone.Index(index_name), embedding, text_key, namespace)
pinecone.Index(index_name), embedding.embed_query, text_key, namespace
)
def delete( def delete(
self, self,