mirror of
https://github.com/hwchase17/langchain.git
synced 2026-06-09 10:17:00 +00:00
refactor(openai): embedding utils and calculations (#33982)
`_tokenize` now returns (`_iter`, `tokens`, `indices`, `token_counts`). The `token_counts` are calculated directly during tokenization, which is more accurate and efficient than splitting strings later.
This commit is contained in:
@@ -421,28 +421,28 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
|
||||
|
||||
def _tokenize(
|
||||
self, texts: list[str], chunk_size: int
|
||||
) -> tuple[Iterable[int], list[list[int] | str], list[int]]:
|
||||
"""Take the input `texts` and `chunk_size` and return 3 iterables as a tuple.
|
||||
) -> tuple[Iterable[int], list[list[int] | str], list[int], list[int]]:
|
||||
"""Tokenize and batch input texts.
|
||||
|
||||
We have `batches`, where batches are sets of individual texts
|
||||
we want responses from the openai api. The length of a single batch is
|
||||
`chunk_size` texts.
|
||||
Splits texts based on `embedding_ctx_length` and groups them into batches
|
||||
of size `chunk_size`.
|
||||
|
||||
Each individual text is also split into multiple texts based on the
|
||||
`embedding_ctx_length` parameter (based on number of tokens).
|
||||
Args:
|
||||
texts: The list of texts to tokenize.
|
||||
chunk_size: The maximum number of texts to include in a single batch.
|
||||
|
||||
This function returns a 3-tuple of the following:
|
||||
|
||||
_iter: An iterable of the starting index in `tokens` for each *batch*
|
||||
tokens: A list of tokenized texts, where each text has already been split
|
||||
into sub-texts based on the `embedding_ctx_length` parameter. In the
|
||||
case of tiktoken, this is a list of token arrays. In the case of
|
||||
HuggingFace transformers, this is a list of strings.
|
||||
indices: An iterable of the same length as `tokens` that maps each token-array
|
||||
to the index of the original text in `texts`.
|
||||
Returns:
|
||||
A tuple containing:
|
||||
1. An iterable of starting indices in the token list for each batch.
|
||||
2. A list of tokenized texts (token arrays for tiktoken, strings for
|
||||
HuggingFace).
|
||||
3. An iterable mapping each token array to the index of the original
|
||||
text. Same length as the token list.
|
||||
4. A list of token counts for each tokenized text.
|
||||
"""
|
||||
tokens: list[list[int] | str] = []
|
||||
indices: list[int] = []
|
||||
token_counts: list[int] = []
|
||||
model_name = self.tiktoken_model_name or self.model
|
||||
|
||||
# If tiktoken flag set to False
|
||||
@@ -474,6 +474,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
|
||||
chunk_text: str = tokenizer.decode(token_chunk)
|
||||
tokens.append(chunk_text)
|
||||
indices.append(i)
|
||||
token_counts.append(len(token_chunk))
|
||||
else:
|
||||
try:
|
||||
encoding = tiktoken.encoding_for_model(model_name)
|
||||
@@ -503,6 +504,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
|
||||
for j in range(0, len(token), self.embedding_ctx_length):
|
||||
tokens.append(token[j : j + self.embedding_ctx_length])
|
||||
indices.append(i)
|
||||
token_counts.append(len(token[j : j + self.embedding_ctx_length]))
|
||||
|
||||
if self.show_progress_bar:
|
||||
try:
|
||||
@@ -513,7 +515,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
|
||||
_iter = range(0, len(tokens), chunk_size)
|
||||
else:
|
||||
_iter = range(0, len(tokens), chunk_size)
|
||||
return _iter, tokens, indices
|
||||
return _iter, tokens, indices, token_counts
|
||||
|
||||
# please refer to
|
||||
# https://github.com/openai/openai-cookbook/blob/main/examples/Embedding_long_inputs.ipynb
|
||||
@@ -527,12 +529,12 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
|
||||
) -> list[list[float]]:
|
||||
"""Generate length-safe embeddings for a list of texts.
|
||||
|
||||
This method handles tokenization and embedding generation, respecting the set
|
||||
embedding context length and chunk size. It supports both tiktoken and
|
||||
HuggingFace tokenizer based on the tiktoken_enabled flag.
|
||||
This method handles tokenization and embedding generation, respecting the
|
||||
`embedding_ctx_length` and `chunk_size`. Supports both `tiktoken` and
|
||||
HuggingFace `transformers` based on the `tiktoken_enabled` flag.
|
||||
|
||||
Args:
|
||||
texts: A list of texts to embed.
|
||||
texts: The list of texts to embed.
|
||||
engine: The engine or model to use for embeddings.
|
||||
chunk_size: The size of chunks for processing embeddings.
|
||||
|
||||
@@ -541,12 +543,8 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
|
||||
"""
|
||||
_chunk_size = chunk_size or self.chunk_size
|
||||
client_kwargs = {**self._invocation_params, **kwargs}
|
||||
_iter, tokens, indices = self._tokenize(texts, _chunk_size)
|
||||
_iter, tokens, indices, token_counts = self._tokenize(texts, _chunk_size)
|
||||
batched_embeddings: list[list[float]] = []
|
||||
# Calculate token counts per chunk
|
||||
token_counts = [
|
||||
len(t) if isinstance(t, list) else len(t.split()) for t in tokens
|
||||
]
|
||||
|
||||
# Process in batches respecting the token limit
|
||||
i = 0
|
||||
@@ -603,12 +601,12 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
|
||||
) -> list[list[float]]:
|
||||
"""Asynchronously generate length-safe embeddings for a list of texts.
|
||||
|
||||
This method handles tokenization and asynchronous embedding generation,
|
||||
respecting the set embedding context length and chunk size. It supports both
|
||||
`tiktoken` and HuggingFace `tokenizer` based on the tiktoken_enabled flag.
|
||||
This method handles tokenization and embedding generation, respecting the
|
||||
`embedding_ctx_length` and `chunk_size`. Supports both `tiktoken` and
|
||||
HuggingFace `transformers` based on the `tiktoken_enabled` flag.
|
||||
|
||||
Args:
|
||||
texts: A list of texts to embed.
|
||||
texts: The list of texts to embed.
|
||||
engine: The engine or model to use for embeddings.
|
||||
chunk_size: The size of chunks for processing embeddings.
|
||||
|
||||
@@ -617,14 +615,10 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
|
||||
"""
|
||||
_chunk_size = chunk_size or self.chunk_size
|
||||
client_kwargs = {**self._invocation_params, **kwargs}
|
||||
_iter, tokens, indices = await run_in_executor(
|
||||
_iter, tokens, indices, token_counts = await run_in_executor(
|
||||
None, self._tokenize, texts, _chunk_size
|
||||
)
|
||||
batched_embeddings: list[list[float]] = []
|
||||
# Calculate token counts per chunk
|
||||
token_counts = [
|
||||
len(t) if isinstance(t, list) else len(t.split()) for t in tokens
|
||||
]
|
||||
|
||||
# Process in batches respecting the token limit
|
||||
i = 0
|
||||
@@ -676,12 +670,13 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
|
||||
def embed_documents(
|
||||
self, texts: list[str], chunk_size: int | None = None, **kwargs: Any
|
||||
) -> list[list[float]]:
|
||||
"""Call out to OpenAI's embedding endpoint for embedding search docs.
|
||||
"""Call OpenAI's embedding endpoint to embed search docs.
|
||||
|
||||
Args:
|
||||
texts: The list of texts to embed.
|
||||
chunk_size: The chunk size of embeddings. If `None`, will use the chunk size
|
||||
specified by the class.
|
||||
chunk_size: The chunk size of embeddings.
|
||||
|
||||
If `None`, will use the chunk size specified by the class.
|
||||
kwargs: Additional keyword arguments to pass to the embedding API.
|
||||
|
||||
Returns:
|
||||
@@ -701,8 +696,8 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
|
||||
embeddings.extend(r["embedding"] for r in response["data"])
|
||||
return embeddings
|
||||
|
||||
# NOTE: to keep things simple, we assume the list may contain texts longer
|
||||
# than the maximum context and use length-safe embedding function.
|
||||
# Unconditionally call _get_len_safe_embeddings to handle length safety.
|
||||
# This could be optimized to avoid double work when all texts are short enough.
|
||||
engine = cast(str, self.deployment)
|
||||
return self._get_len_safe_embeddings(
|
||||
texts, engine=engine, chunk_size=chunk_size, **kwargs
|
||||
@@ -711,12 +706,13 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
|
||||
async def aembed_documents(
|
||||
self, texts: list[str], chunk_size: int | None = None, **kwargs: Any
|
||||
) -> list[list[float]]:
|
||||
"""Call out to OpenAI's embedding endpoint async for embedding search docs.
|
||||
"""Asynchronously call OpenAI's embedding endpoint to embed search docs.
|
||||
|
||||
Args:
|
||||
texts: The list of texts to embed.
|
||||
chunk_size: The chunk size of embeddings. If `None`, will use the chunk size
|
||||
specified by the class.
|
||||
chunk_size: The chunk size of embeddings.
|
||||
|
||||
If `None`, will use the chunk size specified by the class.
|
||||
kwargs: Additional keyword arguments to pass to the embedding API.
|
||||
|
||||
Returns:
|
||||
@@ -735,8 +731,8 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
|
||||
embeddings.extend(r["embedding"] for r in response["data"])
|
||||
return embeddings
|
||||
|
||||
# NOTE: to keep things simple, we assume the list may contain texts longer
|
||||
# than the maximum context and use length-safe embedding function.
|
||||
# Unconditionally call _get_len_safe_embeddings to handle length safety.
|
||||
# This could be optimized to avoid double work when all texts are short enough.
|
||||
engine = cast(str, self.deployment)
|
||||
return await self._aget_len_safe_embeddings(
|
||||
texts, engine=engine, chunk_size=chunk_size, **kwargs
|
||||
|
||||
@@ -33,7 +33,7 @@ def test_embed_documents_with_custom_chunk_size() -> None:
|
||||
]
|
||||
|
||||
result = embeddings.embed_documents(texts, chunk_size=custom_chunk_size)
|
||||
_, tokens, __ = embeddings._tokenize(texts, custom_chunk_size)
|
||||
_, tokens, __, ___ = embeddings._tokenize(texts, custom_chunk_size)
|
||||
mock_create.call_args
|
||||
mock_create.assert_any_call(input=tokens[0:3], **embeddings._invocation_params)
|
||||
mock_create.assert_any_call(input=tokens[3:4], **embeddings._invocation_params)
|
||||
|
||||
6
libs/partners/openai/uv.lock
generated
6
libs/partners/openai/uv.lock
generated
@@ -566,7 +566,7 @@ requires-dist = [
|
||||
{ name = "langchain-groq", marker = "extra == 'groq'" },
|
||||
{ name = "langchain-huggingface", marker = "extra == 'huggingface'" },
|
||||
{ name = "langchain-mistralai", marker = "extra == 'mistralai'" },
|
||||
{ name = "langchain-model-profiles", marker = "extra == 'model-profiles'" },
|
||||
{ name = "langchain-model-profiles", marker = "extra == 'model-profiles'", editable = "../../model-profiles" },
|
||||
{ name = "langchain-ollama", marker = "extra == 'ollama'" },
|
||||
{ name = "langchain-openai", marker = "extra == 'openai'", editable = "." },
|
||||
{ name = "langchain-perplexity", marker = "extra == 'perplexity'" },
|
||||
@@ -580,6 +580,7 @@ provides-extras = ["model-profiles", "community", "anthropic", "openai", "azure-
|
||||
[package.metadata.requires-dev]
|
||||
lint = [{ name = "ruff", specifier = ">=0.12.2,<0.13.0" }]
|
||||
test = [
|
||||
{ name = "langchain-model-profiles", editable = "../../model-profiles" },
|
||||
{ name = "langchain-openai", editable = "." },
|
||||
{ name = "langchain-tests", editable = "../../standard-tests" },
|
||||
{ name = "pytest", specifier = ">=8.0.0,<9.0.0" },
|
||||
@@ -595,6 +596,7 @@ test = [
|
||||
test-integration = [
|
||||
{ name = "cassio", specifier = ">=0.1.0,<1.0.0" },
|
||||
{ name = "langchain-core", editable = "../../core" },
|
||||
{ name = "langchain-model-profiles", editable = "../../model-profiles" },
|
||||
{ name = "langchain-text-splitters", editable = "../../text-splitters" },
|
||||
{ name = "langchainhub", specifier = ">=0.1.16,<1.0.0" },
|
||||
{ name = "python-dotenv", specifier = ">=1.0.0,<2.0.0" },
|
||||
@@ -608,7 +610,7 @@ typing = [
|
||||
|
||||
[[package]]
|
||||
name = "langchain-core"
|
||||
version = "1.0.4"
|
||||
version = "1.0.5"
|
||||
source = { editable = "../../core" }
|
||||
dependencies = [
|
||||
{ name = "jsonpatch" },
|
||||
|
||||
Reference in New Issue
Block a user