refactor(openai): embedding utils and calculations (#33982)

`_tokenize` now returns (`_iter`, `tokens`, `indices`, `token_counts`). The
`token_counts` are calculated directly during tokenization, which is
more accurate and efficient than splitting strings later.
This commit is contained in:
Mason Daugherty
2025-11-14 19:18:37 -05:00
committed by GitHub
parent 2d4f00a451
commit 099c042395
3 changed files with 46 additions and 48 deletions

View File

@@ -421,28 +421,28 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
def _tokenize(
self, texts: list[str], chunk_size: int
) -> tuple[Iterable[int], list[list[int] | str], list[int]]:
"""Take the input `texts` and `chunk_size` and return 3 iterables as a tuple.
) -> tuple[Iterable[int], list[list[int] | str], list[int], list[int]]:
"""Tokenize and batch input texts.
We have `batches`, where batches are sets of individual texts
we want responses from the openai api. The length of a single batch is
`chunk_size` texts.
Splits texts based on `embedding_ctx_length` and groups them into batches
of size `chunk_size`.
Each individual text is also split into multiple texts based on the
`embedding_ctx_length` parameter (based on number of tokens).
Args:
texts: The list of texts to tokenize.
chunk_size: The maximum number of texts to include in a single batch.
This function returns a 3-tuple of the following:
_iter: An iterable of the starting index in `tokens` for each *batch*
tokens: A list of tokenized texts, where each text has already been split
into sub-texts based on the `embedding_ctx_length` parameter. In the
case of tiktoken, this is a list of token arrays. In the case of
HuggingFace transformers, this is a list of strings.
indices: An iterable of the same length as `tokens` that maps each token-array
to the index of the original text in `texts`.
Returns:
A tuple containing:
1. An iterable of starting indices in the token list for each batch.
2. A list of tokenized texts (token arrays for tiktoken, strings for
HuggingFace).
3. An iterable mapping each token array to the index of the original
text. Same length as the token list.
4. A list of token counts for each tokenized text.
"""
tokens: list[list[int] | str] = []
indices: list[int] = []
token_counts: list[int] = []
model_name = self.tiktoken_model_name or self.model
# If tiktoken flag set to False
@@ -474,6 +474,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
chunk_text: str = tokenizer.decode(token_chunk)
tokens.append(chunk_text)
indices.append(i)
token_counts.append(len(token_chunk))
else:
try:
encoding = tiktoken.encoding_for_model(model_name)
@@ -503,6 +504,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
for j in range(0, len(token), self.embedding_ctx_length):
tokens.append(token[j : j + self.embedding_ctx_length])
indices.append(i)
token_counts.append(len(token[j : j + self.embedding_ctx_length]))
if self.show_progress_bar:
try:
@@ -513,7 +515,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
_iter = range(0, len(tokens), chunk_size)
else:
_iter = range(0, len(tokens), chunk_size)
return _iter, tokens, indices
return _iter, tokens, indices, token_counts
# please refer to
# https://github.com/openai/openai-cookbook/blob/main/examples/Embedding_long_inputs.ipynb
@@ -527,12 +529,12 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
) -> list[list[float]]:
"""Generate length-safe embeddings for a list of texts.
This method handles tokenization and embedding generation, respecting the set
embedding context length and chunk size. It supports both tiktoken and
HuggingFace tokenizer based on the tiktoken_enabled flag.
This method handles tokenization and embedding generation, respecting the
`embedding_ctx_length` and `chunk_size`. Supports both `tiktoken` and
HuggingFace `transformers` based on the `tiktoken_enabled` flag.
Args:
texts: A list of texts to embed.
texts: The list of texts to embed.
engine: The engine or model to use for embeddings.
chunk_size: The size of chunks for processing embeddings.
@@ -541,12 +543,8 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
"""
_chunk_size = chunk_size or self.chunk_size
client_kwargs = {**self._invocation_params, **kwargs}
_iter, tokens, indices = self._tokenize(texts, _chunk_size)
_iter, tokens, indices, token_counts = self._tokenize(texts, _chunk_size)
batched_embeddings: list[list[float]] = []
# Calculate token counts per chunk
token_counts = [
len(t) if isinstance(t, list) else len(t.split()) for t in tokens
]
# Process in batches respecting the token limit
i = 0
@@ -603,12 +601,12 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
) -> list[list[float]]:
"""Asynchronously generate length-safe embeddings for a list of texts.
This method handles tokenization and asynchronous embedding generation,
respecting the set embedding context length and chunk size. It supports both
`tiktoken` and HuggingFace `tokenizer` based on the tiktoken_enabled flag.
This method handles tokenization and embedding generation, respecting the
`embedding_ctx_length` and `chunk_size`. Supports both `tiktoken` and
HuggingFace `transformers` based on the `tiktoken_enabled` flag.
Args:
texts: A list of texts to embed.
texts: The list of texts to embed.
engine: The engine or model to use for embeddings.
chunk_size: The size of chunks for processing embeddings.
@@ -617,14 +615,10 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
"""
_chunk_size = chunk_size or self.chunk_size
client_kwargs = {**self._invocation_params, **kwargs}
_iter, tokens, indices = await run_in_executor(
_iter, tokens, indices, token_counts = await run_in_executor(
None, self._tokenize, texts, _chunk_size
)
batched_embeddings: list[list[float]] = []
# Calculate token counts per chunk
token_counts = [
len(t) if isinstance(t, list) else len(t.split()) for t in tokens
]
# Process in batches respecting the token limit
i = 0
@@ -676,12 +670,13 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
def embed_documents(
self, texts: list[str], chunk_size: int | None = None, **kwargs: Any
) -> list[list[float]]:
"""Call out to OpenAI's embedding endpoint for embedding search docs.
"""Call OpenAI's embedding endpoint to embed search docs.
Args:
texts: The list of texts to embed.
chunk_size: The chunk size of embeddings. If `None`, will use the chunk size
specified by the class.
chunk_size: The chunk size of embeddings.
If `None`, will use the chunk size specified by the class.
kwargs: Additional keyword arguments to pass to the embedding API.
Returns:
@@ -701,8 +696,8 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
embeddings.extend(r["embedding"] for r in response["data"])
return embeddings
# NOTE: to keep things simple, we assume the list may contain texts longer
# than the maximum context and use length-safe embedding function.
# Unconditionally call _get_len_safe_embeddings to handle length safety.
# This could be optimized to avoid double work when all texts are short enough.
engine = cast(str, self.deployment)
return self._get_len_safe_embeddings(
texts, engine=engine, chunk_size=chunk_size, **kwargs
@@ -711,12 +706,13 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
async def aembed_documents(
self, texts: list[str], chunk_size: int | None = None, **kwargs: Any
) -> list[list[float]]:
"""Call out to OpenAI's embedding endpoint async for embedding search docs.
"""Asynchronously call OpenAI's embedding endpoint to embed search docs.
Args:
texts: The list of texts to embed.
chunk_size: The chunk size of embeddings. If `None`, will use the chunk size
specified by the class.
chunk_size: The chunk size of embeddings.
If `None`, will use the chunk size specified by the class.
kwargs: Additional keyword arguments to pass to the embedding API.
Returns:
@@ -735,8 +731,8 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
embeddings.extend(r["embedding"] for r in response["data"])
return embeddings
# NOTE: to keep things simple, we assume the list may contain texts longer
# than the maximum context and use length-safe embedding function.
# Unconditionally call _get_len_safe_embeddings to handle length safety.
# This could be optimized to avoid double work when all texts are short enough.
engine = cast(str, self.deployment)
return await self._aget_len_safe_embeddings(
texts, engine=engine, chunk_size=chunk_size, **kwargs

View File

@@ -33,7 +33,7 @@ def test_embed_documents_with_custom_chunk_size() -> None:
]
result = embeddings.embed_documents(texts, chunk_size=custom_chunk_size)
_, tokens, __ = embeddings._tokenize(texts, custom_chunk_size)
_, tokens, __, ___ = embeddings._tokenize(texts, custom_chunk_size)
mock_create.call_args
mock_create.assert_any_call(input=tokens[0:3], **embeddings._invocation_params)
mock_create.assert_any_call(input=tokens[3:4], **embeddings._invocation_params)

View File

@@ -566,7 +566,7 @@ requires-dist = [
{ name = "langchain-groq", marker = "extra == 'groq'" },
{ name = "langchain-huggingface", marker = "extra == 'huggingface'" },
{ name = "langchain-mistralai", marker = "extra == 'mistralai'" },
{ name = "langchain-model-profiles", marker = "extra == 'model-profiles'" },
{ name = "langchain-model-profiles", marker = "extra == 'model-profiles'", editable = "../../model-profiles" },
{ name = "langchain-ollama", marker = "extra == 'ollama'" },
{ name = "langchain-openai", marker = "extra == 'openai'", editable = "." },
{ name = "langchain-perplexity", marker = "extra == 'perplexity'" },
@@ -580,6 +580,7 @@ provides-extras = ["model-profiles", "community", "anthropic", "openai", "azure-
[package.metadata.requires-dev]
lint = [{ name = "ruff", specifier = ">=0.12.2,<0.13.0" }]
test = [
{ name = "langchain-model-profiles", editable = "../../model-profiles" },
{ name = "langchain-openai", editable = "." },
{ name = "langchain-tests", editable = "../../standard-tests" },
{ name = "pytest", specifier = ">=8.0.0,<9.0.0" },
@@ -595,6 +596,7 @@ test = [
test-integration = [
{ name = "cassio", specifier = ">=0.1.0,<1.0.0" },
{ name = "langchain-core", editable = "../../core" },
{ name = "langchain-model-profiles", editable = "../../model-profiles" },
{ name = "langchain-text-splitters", editable = "../../text-splitters" },
{ name = "langchainhub", specifier = ">=0.1.16,<1.0.0" },
{ name = "python-dotenv", specifier = ">=1.0.0,<2.0.0" },
@@ -608,7 +610,7 @@ typing = [
[[package]]
name = "langchain-core"
version = "1.0.4"
version = "1.0.5"
source = { editable = "../../core" }
dependencies = [
{ name = "jsonpatch" },