Compare commits

...

3 Commits

Author SHA1 Message Date
Harrison Chase
2d2c4e8dd0 cr 2023-04-11 21:08:30 -07:00
Harrison Chase
251b38e3f2 cr 2023-04-11 21:07:29 -07:00
Olivier Lacroix
4dc49a8699 Batch embedding in pinecone add_texts method (#2657)
Hello there,

I noticed the `add_texts` method took a fair while for the Pinecone
VectorStore. This happened because embeddings were computed one by one.
This PR fixes it by calling `Embeddings.embed_documents` on all texts at
once.

I also took the liberty of harmonizing how embeddings are initialized: the
vector store now takes an `Embeddings` object rather than an embedding function.

Cheers,

Olivier
2023-04-11 20:53:53 -07:00

View File

@@ -2,7 +2,8 @@
from __future__ import annotations
import uuid
from typing import Any, Callable, Iterable, List, Optional, Tuple
import warnings
from typing import Any, Callable, Iterable, List, Optional, Tuple, Union
from langchain.docstore.document import Document
from langchain.embeddings.base import Embeddings
@@ -26,13 +27,13 @@ class Pinecone(VectorStore):
pinecone.init(api_key="***", environment="...")
index = pinecone.Index("langchain-demo")
embeddings = OpenAIEmbeddings()
vectorstore = Pinecone(index, embeddings.embed_query, "text")
vectorstore = Pinecone(index, embeddings, "text")
"""
def __init__(
self,
index: Any,
embedding_function: Callable,
embeddings: Union[Embeddings, Callable],
text_key: str,
namespace: Optional[str] = None,
):
@@ -50,7 +51,17 @@ class Pinecone(VectorStore):
f"got {type(index)}"
)
self._index = index
self._embedding_function = embedding_function
if isinstance(embeddings, Embeddings):
self._embeddings = embeddings
else:
# This is for backwards compatibility issues. Previously,
# embeddings.embed_query was passed in, not the whole class
warnings.warn(
"passing a function as embeddings is deprecated, "
"you should pass an Embedding object directly. "
"If this throws an error, that is why."
)
self._embeddings = embeddings.__self__ # type: ignore
self._text_key = text_key
self._namespace = namespace
@@ -78,13 +89,14 @@ class Pinecone(VectorStore):
if namespace is None:
namespace = self._namespace
# Embed and create the documents
docs = []
ids = ids or [str(uuid.uuid4()) for _ in texts]
for i, text in enumerate(texts):
embedding = self._embedding_function(text)
metadata = metadatas[i] if metadatas else {}
_texts = list(texts)
ids = ids or [str(uuid.uuid4()) for _ in _texts]
embeddings = self._embeddings.embed_documents(_texts)
metadatas = metadatas or [{}] * len(_texts)
for metadata, text in zip(metadatas, _texts):
metadata[self._text_key] = text
docs.append((ids[i], embedding, metadata))
docs = list(zip(ids, embeddings, metadatas))
# upsert to Pinecone
self._index.upsert(vectors=docs, namespace=namespace, batch_size=batch_size)
return ids
@@ -109,7 +121,7 @@ class Pinecone(VectorStore):
"""
if namespace is None:
namespace = self._namespace
query_obj = self._embedding_function(query)
query_obj = self._embeddings.embed_query(query)
docs = []
results = self._index.query(
[query_obj],
@@ -145,7 +157,7 @@ class Pinecone(VectorStore):
"""
if namespace is None:
namespace = self._namespace
query_obj = self._embedding_function(query)
query_obj = self._embeddings.embed_query(query)
docs = []
results = self._index.query(
[query_obj],
@@ -244,7 +256,7 @@ class Pinecone(VectorStore):
# upsert to Pinecone
index.upsert(vectors=list(to_upsert), namespace=namespace)
return cls(index, embedding.embed_query, text_key, namespace)
return cls(index, embedding, text_key, namespace)
@classmethod
def from_existing_index(
@@ -263,6 +275,4 @@ class Pinecone(VectorStore):
"Please install it with `pip install pinecone-client`."
)
return cls(
pinecone.Index(index_name), embedding.embed_query, text_key, namespace
)
return cls(pinecone.Index(index_name), embedding, text_key, namespace)