Mirror of https://github.com/hwchase17/langchain.git, synced 2025-07-16 17:53:37 +00:00.
Add pagination for Vertex AI embeddings (#5325)
Fixes #5316 --------- Co-authored-by: Justin Flick <jflick@homesite.com> Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
This commit is contained in:
parent
3e16468423
commit
c09f8e4ddc
@ -22,17 +22,25 @@ class VertexAIEmbeddings(_VertexAICommon, Embeddings):
|
||||
values["client"] = TextEmbeddingModel.from_pretrained(values["model_name"])
|
||||
return values
|
||||
|
||||
def embed_documents(
    self, texts: List[str], batch_size: int = 5
) -> List[List[float]]:
    """Embed a list of strings, paginating requests to the model.

    Vertex AI currently caps a single embedding request at 5 strings,
    so the input is sliced into chunks of ``batch_size`` and each chunk
    is embedded with its own call.

    Args:
        texts: The list of strings to embed.
        batch_size: Number of strings sent to the model per request.

    Returns:
        List of embeddings, one for each input text.
    """
    results: List[List[float]] = []
    for start in range(0, len(texts), batch_size):
        chunk = texts[start : start + batch_size]
        response = self.client.get_embeddings(chunk)
        results.extend(embedding.values for embedding in response)
    return results
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
"""Embed a text.
|
||||
|
@ -23,3 +23,22 @@ def test_embedding_query() -> None:
|
||||
model = VertexAIEmbeddings()
|
||||
output = model.embed_query(document)
|
||||
assert len(output) == 768
|
||||
|
||||
|
||||
def test_paginated_texts() -> None:
    """Embedding more documents than one batch holds paginates correctly."""
    corpus = [
        "foo bar",
        "foo baz",
        "bar foo",
        "baz foo",
        "bar bar",
        "foo foo",
        "baz baz",
        "baz bar",
    ]
    model = VertexAIEmbeddings()
    embeddings = model.embed_documents(corpus)
    # One 768-dimensional vector per input document.
    assert len(embeddings) == 8
    assert len(embeddings[0]) == 768
    assert model._llm_type == "vertexai"
    assert model.model_name == model.client._model_id
||||
|
Loading…
Reference in New Issue
Block a user