mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-16 09:48:04 +00:00
Add pagination for Vertex AI embeddings (#5325)
Fixes #5316 --------- Co-authored-by: Justin Flick <jflick@homesite.com> Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
This commit is contained in:
parent
3e16468423
commit
c09f8e4ddc
@ -22,17 +22,25 @@ class VertexAIEmbeddings(_VertexAICommon, Embeddings):
|
|||||||
values["client"] = TextEmbeddingModel.from_pretrained(values["model_name"])
|
values["client"] = TextEmbeddingModel.from_pretrained(values["model_name"])
|
||||||
return values
|
return values
|
||||||
|
|
||||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
def embed_documents(
|
||||||
"""Embed a list of strings.
|
self, texts: List[str], batch_size: int = 5
|
||||||
|
) -> List[List[float]]:
|
||||||
|
"""Embed a list of strings. Vertex AI currently
|
||||||
|
sets a max batch size of 5 strings.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
texts: List[str] The list of strings to embed.
|
texts: List[str] The list of strings to embed.
|
||||||
|
batch_size: [int] The batch size of embeddings to send to the model
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List of embeddings, one for each text.
|
List of embeddings, one for each text.
|
||||||
"""
|
"""
|
||||||
embeddings = self.client.get_embeddings(texts)
|
embeddings = []
|
||||||
return [el.values for el in embeddings]
|
for batch in range(0, len(texts), batch_size):
|
||||||
|
text_batch = texts[batch : batch + batch_size]
|
||||||
|
embeddings_batch = self.client.get_embeddings(text_batch)
|
||||||
|
embeddings.extend([el.values for el in embeddings_batch])
|
||||||
|
return embeddings
|
||||||
|
|
||||||
def embed_query(self, text: str) -> List[float]:
|
def embed_query(self, text: str) -> List[float]:
|
||||||
"""Embed a text.
|
"""Embed a text.
|
||||||
|
@ -23,3 +23,22 @@ def test_embedding_query() -> None:
|
|||||||
model = VertexAIEmbeddings()
|
model = VertexAIEmbeddings()
|
||||||
output = model.embed_query(document)
|
output = model.embed_query(document)
|
||||||
assert len(output) == 768
|
assert len(output) == 768
|
||||||
|
|
||||||
|
|
||||||
|
def test_paginated_texts() -> None:
|
||||||
|
documents = [
|
||||||
|
"foo bar",
|
||||||
|
"foo baz",
|
||||||
|
"bar foo",
|
||||||
|
"baz foo",
|
||||||
|
"bar bar",
|
||||||
|
"foo foo",
|
||||||
|
"baz baz",
|
||||||
|
"baz bar",
|
||||||
|
]
|
||||||
|
model = VertexAIEmbeddings()
|
||||||
|
output = model.embed_documents(documents)
|
||||||
|
assert len(output) == 8
|
||||||
|
assert len(output[0]) == 768
|
||||||
|
assert model._llm_type == "vertexai"
|
||||||
|
assert model.model_name == model.client._model_id
|
||||||
|
Loading…
Reference in New Issue
Block a user