mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-26 00:23:25 +00:00
community[minor]: New model parameters and dynamic batching for VertexAIEmbeddings (#13999)
- **Description:** VertexAIEmbeddings performance improvements - **Twitter handle:** @vladkol ## Improvements - Dynamic batch size, starting from 250, lowering down to 5. Batch size varies across regions. Some regions support larger batches, and it significantly improves performance. When running large batches of texts in `us-central1`, performance gain can be up to 3.5x. The dynamic batching also makes sure every batch is below 20K token limit. - New model parameter `embeddings_type` that translates to `task_type` parameter of the API. Newer model versions support [different embeddings task types](https://cloud.google.com/vertex-ai/docs/generative-ai/embeddings/get-text-embeddings#api_changes_to_models_released_on_or_after_august_2023).
This commit is contained in:
parent
2e6a9e6381
commit
11fda490ca
@ -1,16 +1,29 @@
|
||||
from typing import Dict, List
|
||||
import logging
|
||||
import re
|
||||
import string
|
||||
import threading
|
||||
from concurrent.futures import ThreadPoolExecutor, wait
|
||||
from typing import Any, Dict, List, Literal, Optional, Tuple
|
||||
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.language_models.llms import create_base_retry_decorator
|
||||
from langchain_core.pydantic_v1 import root_validator
|
||||
|
||||
from langchain_community.llms.vertexai import _VertexAICommon
|
||||
from langchain_community.utilities.vertexai import raise_vertex_import_error
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_MAX_TOKENS_PER_BATCH = 20000
|
||||
_MAX_BATCH_SIZE = 250
|
||||
_MIN_BATCH_SIZE = 5
|
||||
|
||||
|
||||
class VertexAIEmbeddings(_VertexAICommon, Embeddings):
|
||||
"""Google Cloud VertexAI embedding models."""
|
||||
|
||||
model_name: str = "textembedding-gecko"
|
||||
# Instance context
|
||||
instance: Dict[str, Any] = {} #: :meta private:
|
||||
|
||||
@root_validator()
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
@ -18,31 +31,294 @@ class VertexAIEmbeddings(_VertexAICommon, Embeddings):
|
||||
cls._try_init_vertexai(values)
|
||||
try:
|
||||
from vertexai.language_models import TextEmbeddingModel
|
||||
|
||||
values["client"] = TextEmbeddingModel.from_pretrained(values["model_name"])
|
||||
except ImportError:
|
||||
raise_vertex_import_error()
|
||||
values["client"] = TextEmbeddingModel.from_pretrained(values["model_name"])
|
||||
return values
|
||||
|
||||
def embed_documents(
|
||||
self, texts: List[str], batch_size: int = 5
|
||||
def __init__(
|
||||
self,
|
||||
project: Optional[str] = None,
|
||||
location: str = "us-central1",
|
||||
request_parallelism: int = 5,
|
||||
max_retries: int = 6,
|
||||
model_name: str = "textembedding-gecko",
|
||||
credentials: Optional[Any] = None,
|
||||
**kwargs: Any,
|
||||
):
|
||||
"""Initialize the sentence_transformer."""
|
||||
super().__init__(
|
||||
project=project,
|
||||
location=location,
|
||||
credentials=credentials,
|
||||
request_parallelism=request_parallelism,
|
||||
max_retries=max_retries,
|
||||
model_name=model_name,
|
||||
**kwargs,
|
||||
)
|
||||
self.instance["max_batch_size"] = kwargs.get("max_batch_size", _MAX_BATCH_SIZE)
|
||||
self.instance["batch_size"] = self.instance["max_batch_size"]
|
||||
self.instance["min_batch_size"] = kwargs.get("min_batch_size", _MIN_BATCH_SIZE)
|
||||
self.instance["min_good_batch_size"] = self.instance["min_batch_size"]
|
||||
self.instance["lock"] = threading.Lock()
|
||||
self.instance["batch_size_validated"] = False
|
||||
self.instance["task_executor"] = ThreadPoolExecutor(
|
||||
max_workers=request_parallelism
|
||||
)
|
||||
self.instance[
|
||||
"embeddings_task_type_supported"
|
||||
] = not self.client._endpoint_name.endswith("/textembedding-gecko@001")
|
||||
|
||||
@staticmethod
|
||||
def _split_by_punctuation(text: str) -> List[str]:
|
||||
"""Splits a string by punctuation and whitespace characters."""
|
||||
split_by = string.punctuation + "\t\n "
|
||||
pattern = f"([{split_by}])"
|
||||
# Using re.split to split the text based on the pattern
|
||||
return [segment for segment in re.split(pattern, text) if segment]
|
||||
|
||||
@staticmethod
|
||||
def _prepare_batches(texts: List[str], batch_size: int) -> List[List[str]]:
|
||||
"""Splits texts in batches based on current maximum batch size
|
||||
and maximum tokens per request.
|
||||
"""
|
||||
text_index = 0
|
||||
texts_len = len(texts)
|
||||
batch_token_len = 0
|
||||
batches: List[List[str]] = []
|
||||
current_batch: List[str] = []
|
||||
if texts_len == 0:
|
||||
return []
|
||||
while text_index < texts_len:
|
||||
current_text = texts[text_index]
|
||||
# Number of tokens per a text is conservatively estimated
|
||||
# as 2 times number of words, punctuation and whitespace characters.
|
||||
# Using `count_tokens` API will make batching too expensive.
|
||||
# Utilizing a tokenizer, would add a dependency that would not
|
||||
# necessarily be reused by the application using this class.
|
||||
current_text_token_cnt = (
|
||||
len(VertexAIEmbeddings._split_by_punctuation(current_text)) * 2
|
||||
)
|
||||
end_of_batch = False
|
||||
if current_text_token_cnt > _MAX_TOKENS_PER_BATCH:
|
||||
# Current text is too big even for a single batch.
|
||||
# Such request will fail, but we still make a batch
|
||||
# so that the app can get the error from the API.
|
||||
if len(current_batch) > 0:
|
||||
# Adding current batch if not empty.
|
||||
batches.append(current_batch)
|
||||
current_batch = [current_text]
|
||||
text_index += 1
|
||||
end_of_batch = True
|
||||
elif (
|
||||
batch_token_len + current_text_token_cnt > _MAX_TOKENS_PER_BATCH
|
||||
or len(current_batch) == batch_size
|
||||
):
|
||||
end_of_batch = True
|
||||
else:
|
||||
if text_index == texts_len - 1:
|
||||
# Last element - even though the batch may be not big,
|
||||
# we still need to make it.
|
||||
end_of_batch = True
|
||||
batch_token_len += current_text_token_cnt
|
||||
current_batch.append(current_text)
|
||||
text_index += 1
|
||||
if end_of_batch:
|
||||
batches.append(current_batch)
|
||||
current_batch = []
|
||||
batch_token_len = 0
|
||||
return batches
|
||||
|
||||
def _get_embeddings_with_retry(
|
||||
self, texts: List[str], embeddings_type: Optional[str] = None
|
||||
) -> List[List[float]]:
|
||||
"""Embed a list of strings. Vertex AI currently
|
||||
sets a max batch size of 5 strings.
|
||||
"""Makes a Vertex AI model request with retry logic."""
|
||||
from google.api_core.exceptions import (
|
||||
Aborted,
|
||||
DeadlineExceeded,
|
||||
ResourceExhausted,
|
||||
ServiceUnavailable,
|
||||
)
|
||||
|
||||
errors = [
|
||||
ResourceExhausted,
|
||||
ServiceUnavailable,
|
||||
Aborted,
|
||||
DeadlineExceeded,
|
||||
]
|
||||
retry_decorator = create_base_retry_decorator(
|
||||
error_types=errors, max_retries=self.max_retries
|
||||
)
|
||||
|
||||
@retry_decorator
|
||||
def _completion_with_retry(texts_to_process: List[str]) -> Any:
|
||||
if embeddings_type and self.instance["embeddings_task_type_supported"]:
|
||||
from vertexai.language_models import TextEmbeddingInput
|
||||
|
||||
requests = [
|
||||
TextEmbeddingInput(text=t, task_type=embeddings_type)
|
||||
for t in texts_to_process
|
||||
]
|
||||
else:
|
||||
requests = texts_to_process
|
||||
embeddings = self.client.get_embeddings(requests)
|
||||
return [embs.values for embs in embeddings]
|
||||
|
||||
return _completion_with_retry(texts)
|
||||
|
||||
def _prepare_and_validate_batches(
|
||||
self, texts: List[str], embeddings_type: Optional[str] = None
|
||||
) -> Tuple[List[List[float]], List[List[str]]]:
|
||||
"""Prepares text batches with one-time validation of batch size.
|
||||
Batch size varies between GCP regions and individual project quotas.
|
||||
# Returns embeddings of the first text batch that went through,
|
||||
# and text batches for the rest of the texts.
|
||||
"""
|
||||
from google.api_core.exceptions import InvalidArgument
|
||||
|
||||
batches = VertexAIEmbeddings._prepare_batches(
|
||||
texts, self.instance["batch_size"]
|
||||
)
|
||||
# If batch size if less or equal to one that went through before,
|
||||
# then keep batches as they are.
|
||||
if len(batches[0]) <= self.instance["min_good_batch_size"]:
|
||||
return [], batches
|
||||
with self.instance["lock"]:
|
||||
# If largest possible batch size was validated
|
||||
# while waiting for the lock, then check for rebuilding
|
||||
# our batches, and return.
|
||||
if self.instance["batch_size_validated"]:
|
||||
if len(batches[0]) <= self.instance["batch_size"]:
|
||||
return [], batches
|
||||
else:
|
||||
return [], VertexAIEmbeddings._prepare_batches(
|
||||
texts, self.instance["batch_size"]
|
||||
)
|
||||
# Figure out largest possible batch size by trying to push
|
||||
# batches and lowering their size in half after every failure.
|
||||
first_batch = batches[0]
|
||||
first_result = []
|
||||
had_failure = False
|
||||
while True:
|
||||
try:
|
||||
first_result = self._get_embeddings_with_retry(
|
||||
first_batch, embeddings_type
|
||||
)
|
||||
break
|
||||
except InvalidArgument:
|
||||
had_failure = True
|
||||
first_batch_len = len(first_batch)
|
||||
if first_batch_len == self.instance["min_batch_size"]:
|
||||
raise
|
||||
first_batch_len = max(
|
||||
self.instance["min_batch_size"], int(first_batch_len / 2)
|
||||
)
|
||||
first_batch = first_batch[:first_batch_len]
|
||||
first_batch_len = len(first_batch)
|
||||
self.instance["min_good_batch_size"] = max(
|
||||
self.instance["min_good_batch_size"], first_batch_len
|
||||
)
|
||||
# If had a failure and recovered
|
||||
# or went through with the max size, then it's a legit batch size.
|
||||
if had_failure or first_batch_len == self.instance["max_batch_size"]:
|
||||
self.instance["batch_size"] = first_batch_len
|
||||
self.instance["batch_size_validated"] = True
|
||||
# If batch size was updated,
|
||||
# rebuild batches with the new batch size
|
||||
# (texts that went through are excluded here).
|
||||
if first_batch_len != self.instance["max_batch_size"]:
|
||||
batches = VertexAIEmbeddings._prepare_batches(
|
||||
texts[first_batch_len:], self.instance["batch_size"]
|
||||
)
|
||||
else:
|
||||
# Still figuring out max batch size.
|
||||
batches = batches[1:]
|
||||
# Returning embeddings of the first text batch that went through,
|
||||
# and text batches for the rest of texts.
|
||||
return first_result, batches
|
||||
|
||||
def embed(
|
||||
self,
|
||||
texts: List[str],
|
||||
batch_size: int = 0,
|
||||
embeddings_task_type: Optional[
|
||||
Literal[
|
||||
"RETRIEVAL_QUERY",
|
||||
"RETRIEVAL_DOCUMENT",
|
||||
"SEMANTIC_SIMILARITY",
|
||||
"CLASSIFICATION",
|
||||
"CLUSTERING",
|
||||
]
|
||||
] = None,
|
||||
) -> List[List[float]]:
|
||||
"""Embed a list of strings.
|
||||
|
||||
Args:
|
||||
texts: List[str] The list of strings to embed.
|
||||
batch_size: [int] The batch size of embeddings to send to the model
|
||||
batch_size: [int] The batch size of embeddings to send to the model.
|
||||
If zero, then the largest batch size will be detected dynamically
|
||||
at the first request, starting from 250, down to 5.
|
||||
embeddings_task_type: [str] optional embeddings task type,
|
||||
one of the following
|
||||
RETRIEVAL_QUERY - Text is a query
|
||||
in a search/retrieval setting.
|
||||
RETRIEVAL_DOCUMENT - Text is a document
|
||||
in a search/retrieval setting.
|
||||
SEMANTIC_SIMILARITY - Embeddings will be used
|
||||
for Semantic Textual Similarity (STS).
|
||||
CLASSIFICATION - Embeddings will be used for classification.
|
||||
CLUSTERING - Embeddings will be used for clustering.
|
||||
|
||||
Returns:
|
||||
List of embeddings, one for each text.
|
||||
"""
|
||||
embeddings = []
|
||||
for batch in range(0, len(texts), batch_size):
|
||||
text_batch = texts[batch : batch + batch_size]
|
||||
embeddings_batch = self.client.get_embeddings(text_batch)
|
||||
embeddings.extend([el.values for el in embeddings_batch])
|
||||
if len(texts) == 0:
|
||||
return []
|
||||
embeddings: List[List[float]] = []
|
||||
first_batch_result: List[List[float]] = []
|
||||
if batch_size > 0:
|
||||
# Fixed batch size.
|
||||
batches = VertexAIEmbeddings._prepare_batches(texts, batch_size)
|
||||
else:
|
||||
# Dynamic batch size, starting from 250 at the first call.
|
||||
first_batch_result, batches = self._prepare_and_validate_batches(
|
||||
texts, embeddings_task_type
|
||||
)
|
||||
# First batch result may have some embeddings already.
|
||||
# In such case, batches have texts that were not processed yet.
|
||||
embeddings.extend(first_batch_result)
|
||||
tasks = []
|
||||
for batch in batches:
|
||||
tasks.append(
|
||||
self.instance["task_executor"].submit(
|
||||
self._get_embeddings_with_retry,
|
||||
texts=batch,
|
||||
embeddings_type=embeddings_task_type,
|
||||
)
|
||||
)
|
||||
if len(tasks) > 0:
|
||||
wait(tasks)
|
||||
for t in tasks:
|
||||
embeddings.extend(t.result())
|
||||
return embeddings
|
||||
|
||||
def embed_documents(
|
||||
self, texts: List[str], batch_size: int = 0
|
||||
) -> List[List[float]]:
|
||||
"""Embed a list of documents.
|
||||
|
||||
Args:
|
||||
texts: List[str] The list of texts to embed.
|
||||
batch_size: [int] The batch size of embeddings to send to the model.
|
||||
If zero, then the largest batch size will be detected dynamically
|
||||
at the first request, starting from 250, down to 5.
|
||||
|
||||
Returns:
|
||||
List of embeddings, one for each text.
|
||||
"""
|
||||
return self.embed(texts, batch_size, "RETRIEVAL_DOCUMENT")
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
"""Embed a text.
|
||||
|
||||
@ -52,5 +328,5 @@ class VertexAIEmbeddings(_VertexAICommon, Embeddings):
|
||||
Returns:
|
||||
Embedding for the text.
|
||||
"""
|
||||
embeddings = self.client.get_embeddings([text])
|
||||
return embeddings[0].values
|
||||
embeddings = self.embed([text], 1, "RETRIEVAL_QUERY")
|
||||
return embeddings[0]
|
||||
|
@ -1,8 +1,8 @@
|
||||
"""Test Vertex AI API wrapper.
|
||||
In order to run this test, you need to install VertexAI SDK
|
||||
In order to run this test, you need to install VertexAI SDK
|
||||
pip install google-cloud-aiplatform>=1.35.0
|
||||
|
||||
Your end-user credentials would be used to make the calls (make sure you've run
|
||||
Your end-user credentials would be used to make the calls (make sure you've run
|
||||
`gcloud auth login` first).
|
||||
"""
|
||||
from langchain_community.embeddings import VertexAIEmbeddings
|
||||
@ -24,6 +24,16 @@ def test_embedding_query() -> None:
|
||||
assert len(output) == 768
|
||||
|
||||
|
||||
def test_large_batches() -> None:
|
||||
documents = ["foo bar" for _ in range(0, 251)]
|
||||
model_uscentral1 = VertexAIEmbeddings(location="us-central1")
|
||||
model_asianortheast1 = VertexAIEmbeddings(location="asia-northeast1")
|
||||
model_uscentral1.embed_documents(documents)
|
||||
model_asianortheast1.embed_documents(documents)
|
||||
assert model_uscentral1.instance["batch_size"] >= 250
|
||||
assert model_asianortheast1.instance["batch_size"] < 50
|
||||
|
||||
|
||||
def test_paginated_texts() -> None:
|
||||
documents = [
|
||||
"foo bar",
|
||||
|
63
libs/community/tests/unit_tests/embeddings/test_vertexai.py
Normal file
63
libs/community/tests/unit_tests/embeddings/test_vertexai.py
Normal file
@ -0,0 +1,63 @@
|
||||
"""Test Vertex AI embeddings API wrapper.
|
||||
"""
|
||||
|
||||
from langchain_community.embeddings import VertexAIEmbeddings
|
||||
|
||||
|
||||
def test_split_by_punctuation() -> None:
|
||||
parts = VertexAIEmbeddings._split_by_punctuation(
|
||||
"Hello, my friend!\nHow are you?\nI have 2 news:\n\n\t- Good,\n\t- Bad."
|
||||
)
|
||||
assert parts == [
|
||||
"Hello",
|
||||
",",
|
||||
" ",
|
||||
"my",
|
||||
" ",
|
||||
"friend",
|
||||
"!",
|
||||
"\n",
|
||||
"How",
|
||||
" ",
|
||||
"are",
|
||||
" ",
|
||||
"you",
|
||||
"?",
|
||||
"\n",
|
||||
"I",
|
||||
" ",
|
||||
"have",
|
||||
" ",
|
||||
"2",
|
||||
" ",
|
||||
"news",
|
||||
":",
|
||||
"\n",
|
||||
"\n",
|
||||
"\t",
|
||||
"-",
|
||||
" ",
|
||||
"Good",
|
||||
",",
|
||||
"\n",
|
||||
"\t",
|
||||
"-",
|
||||
" ",
|
||||
"Bad",
|
||||
".",
|
||||
]
|
||||
|
||||
|
||||
def test_batching() -> None:
|
||||
long_text = "foo " * 500 # 1000 words, 2000 tokens
|
||||
long_texts = [long_text for _ in range(0, 250)]
|
||||
documents251 = ["foo bar" for _ in range(0, 251)]
|
||||
five_elem = VertexAIEmbeddings._prepare_batches(long_texts, 5)
|
||||
default250_elem = VertexAIEmbeddings._prepare_batches(long_texts, 250)
|
||||
batches251 = VertexAIEmbeddings._prepare_batches(documents251, 250)
|
||||
assert len(five_elem) == 50 # 250/5 items
|
||||
assert len(five_elem[0]) == 5 # 5 items per batch
|
||||
assert len(default250_elem[0]) == 10 # Should not be more than 20K tokens
|
||||
assert len(default250_elem) == 25
|
||||
assert len(batches251[0]) == 250
|
||||
assert len(batches251[1]) == 1
|
Loading…
Reference in New Issue
Block a user