diff --git a/libs/partners/openai/langchain_openai/embeddings/base.py b/libs/partners/openai/langchain_openai/embeddings/base.py
index 6cc017ad480..3d8b65b823f 100644
--- a/libs/partners/openai/langchain_openai/embeddings/base.py
+++ b/libs/partners/openai/langchain_openai/embeddings/base.py
@@ -37,6 +37,67 @@ from langchain_core.utils import (
 
 logger = logging.getLogger(__name__)
 
 
+def _process_batched_chunked_embeddings(
+    num_texts: int,
+    tokens: List[Union[List[int], str]],
+    batched_embeddings: List[List[float]],
+    indices: List[int],
+    skip_empty: bool,
+) -> List[Optional[List[float]]]:
+    # for each text, this is the list of embeddings (list of lists of floats)
+    # corresponding to the chunks of the text
+    results: List[List[List[float]]] = [[] for _ in range(num_texts)]
+
+    # for each text, this is the token length of each chunk
+    # for transformers tokenization, this is the string length
+    # for tiktoken, this is the number of tokens
+    num_tokens_in_batch: List[List[int]] = [[] for _ in range(num_texts)]
+
+    for i in range(len(indices)):
+        if skip_empty and len(batched_embeddings[i]) == 1:
+            continue
+        results[indices[i]].append(batched_embeddings[i])
+        num_tokens_in_batch[indices[i]].append(len(tokens[i]))
+
+    # for each text, this is the final embedding
+    embeddings: List[Optional[List[float]]] = []
+    for i in range(num_texts):
+        # an embedding for each chunk
+        _result: List[List[float]] = results[i]
+
+        if len(_result) == 0:
+            # this will be populated with the embedding of an empty string
+            # in the sync or async code calling this
+            embeddings.append(None)
+            continue
+
+        elif len(_result) == 1:
+            # if only one embedding was produced, use it
+            embeddings.append(_result[0])
+            continue
+
+        else:
+            # otherwise, take a weighted average of the chunk embeddings
+            # should be same as
+            # average = np.average(_result, axis=0, weights=num_tokens_in_batch[i])
+            total_weight = sum(num_tokens_in_batch[i])
+            average = [
+                sum(
+                    val * weight
+                    for val, weight in zip(embedding, num_tokens_in_batch[i])
+                )
+                / total_weight
+                for embedding in zip(*_result)
+            ]
+
+            # should be same as
+            # embeddings.append((average / np.linalg.norm(average)).tolist())
+            magnitude = sum(val**2 for val in average) ** 0.5
+            embeddings.append([val / magnitude for val in average])
+
+    return embeddings
+
+
 class OpenAIEmbeddings(BaseModel, Embeddings):
     """OpenAI embedding models.
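Reviewer note: the weighted-average-then-renormalize arithmetic in `_process_batched_chunked_embeddings` is the same math as the two inlined copies this patch deletes below. A standalone sketch (not part of the patch; the inputs and the `dim_vals`/`chunk_weights` names are invented for illustration) cross-checking the pure-Python form against the numpy one-liners quoted in the comments:

import numpy as np

chunk_embeddings = [[1.0, 0.0], [0.0, 1.0]]  # two chunk embeddings of one text
chunk_weights = [3, 1]                       # token count of each chunk

# pure-Python path, as in the helper; note zip(*...) iterates per *dimension*
total_weight = sum(chunk_weights)
average = [
    sum(val * weight for val, weight in zip(dim_vals, chunk_weights))
    / total_weight
    for dim_vals in zip(*chunk_embeddings)
]
magnitude = sum(val**2 for val in average) ** 0.5
normalized = [val / magnitude for val in average]

# numpy path quoted in the helper's comments
expected = np.average(chunk_embeddings, axis=0, weights=chunk_weights)
expected = (expected / np.linalg.norm(expected)).tolist()

assert np.allclose(normalized, expected)

Matching the numpy forms numerically while staying in pure Python is what lets the module avoid a hard numpy dependency.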
@@ -248,9 +309,29 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
 
     def _tokenize(
         self, texts: List[str], chunk_size: int
-    ) -> Tuple[Iterable[int], List[List[float]], List[int]]:
-        tokens = []
-        indices = []
+    ) -> Tuple[Iterable[int], List[Union[List[int], str]], List[int]]:
+        """
+        Take the input `texts` and `chunk_size` and return 3 iterables as a tuple:
+
+        We have `batches`, where batches are sets of individual texts
+        for which we want responses from the OpenAI API. The length of a single
+        batch is `chunk_size` texts.
+
+        Each individual text is also split into multiple texts based on the
+        `embedding_ctx_length` parameter (based on number of tokens).
+
+        This function returns a 3-tuple of the following:
+
+        _iter: An iterable of the starting index in `tokens` for each *batch*
+        tokens: A list of tokenized texts, where each text has already been split
+            into sub-texts based on the `embedding_ctx_length` parameter. In the
+            case of tiktoken, this is a list of token arrays. In the case of
+            HuggingFace transformers, this is a list of strings.
+        indices: An iterable of the same length as `tokens` that maps each
+            token array to the index of the original text in `texts`.
+        """
+        tokens: List[Union[List[int], str]] = []
+        indices: List[int] = []
         model_name = self.tiktoken_model_name or self.model
 
         # If tiktoken flag set to False
@@ -269,14 +350,16 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
             )
             for i, text in enumerate(texts):
                 # Tokenize the text using HuggingFace transformers
-                tokenized = tokenizer.encode(text, add_special_tokens=False)
+                tokenized: List[int] = tokenizer.encode(text, add_special_tokens=False)
 
                 # Split tokens into chunks respecting the embedding_ctx_length
                 for j in range(0, len(tokenized), self.embedding_ctx_length):
-                    token_chunk = tokenized[j : j + self.embedding_ctx_length]
+                    token_chunk: List[int] = tokenized[
+                        j : j + self.embedding_ctx_length
+                    ]
 
                     # Convert token IDs back to a string
-                    chunk_text = tokenizer.decode(token_chunk)
+                    chunk_text: str = tokenizer.decode(token_chunk)
                     tokens.append(chunk_text)
                     indices.append(i)
         else:
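Reviewer note: the new docstring is easiest to verify with a toy example. A standalone sketch (not part of the patch; `str.split` is an invented stand-in for the tiktoken/HuggingFace tokenizers, and all values are made up) of the chunk/index bookkeeping `_tokenize` returns:

from typing import List, Union

embedding_ctx_length = 2  # toy value; the real one counts model tokens
texts = ["a b c d e", "f"]

tokens: List[Union[List[int], str]] = []
indices: List[int] = []
for i, text in enumerate(texts):
    tokenized = text.split()  # stand-in for the real tokenizer
    for j in range(0, len(tokenized), embedding_ctx_length):
        # the HuggingFace branch stores decoded strings, as here
        tokens.append(" ".join(tokenized[j : j + embedding_ctx_length]))
        indices.append(i)

# text 0 was split into three chunks, text 1 into one
assert tokens == ["a b", "c d", "e", "f"]
assert indices == [0, 0, 0, 1]

# _iter marks where each batch of `chunk_size` chunks starts in `tokens`
chunk_size = 3
_iter = range(0, len(tokens), chunk_size)
assert list(_iter) == [0, 3]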
@@ -351,43 +434,23 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
                 response = response.model_dump()
             batched_embeddings.extend(r["embedding"] for r in response["data"])
 
-        results: List[List[List[float]]] = [[] for _ in range(len(texts))]
-        num_tokens_in_batch: List[List[int]] = [[] for _ in range(len(texts))]
-        for i in range(len(indices)):
-            if self.skip_empty and len(batched_embeddings[i]) == 1:
-                continue
-            results[indices[i]].append(batched_embeddings[i])
-            num_tokens_in_batch[indices[i]].append(len(tokens[i]))
+        embeddings = _process_batched_chunked_embeddings(
+            len(texts), tokens, batched_embeddings, indices, self.skip_empty
+        )
+        _cached_empty_embedding: Optional[List[float]] = None
 
-        embeddings: List[List[float]] = [[] for _ in range(len(texts))]
-        for i in range(len(texts)):
-            _result = results[i]
-            if len(_result) == 0:
+        def empty_embedding() -> List[float]:
+            nonlocal _cached_empty_embedding
+            if _cached_empty_embedding is None:
                 average_embedded = self.client.create(
                     input="", **self._invocation_params
                 )
                 if not isinstance(average_embedded, dict):
                     average_embedded = average_embedded.model_dump()
-                average = average_embedded["data"][0]["embedding"]
-            else:
-                # should be same as
-                # average = np.average(_result, axis=0, weights=num_tokens_in_batch[i])
-                total_weight = sum(num_tokens_in_batch[i])
-                average = [
-                    sum(
-                        val * weight
-                        for val, weight in zip(embedding, num_tokens_in_batch[i])
-                    )
-                    / total_weight
-                    for embedding in zip(*_result)
-                ]
+                _cached_empty_embedding = average_embedded["data"][0]["embedding"]
+            return _cached_empty_embedding
 
-            # should be same as
-            # embeddings[i] = (average / np.linalg.norm(average)).tolist()
-            magnitude = sum(val**2 for val in average) ** 0.5
-            embeddings[i] = [val / magnitude for val in average]
-
-        return embeddings
+        return [e if e is not None else empty_embedding() for e in embeddings]
 
     # please refer to
     # https://github.com/openai/openai-cookbook/blob/main/examples/Embedding_long_inputs.ipynb
@@ -423,40 +486,23 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
                 response = response.model_dump()
             batched_embeddings.extend(r["embedding"] for r in response["data"])
 
-        results: List[List[List[float]]] = [[] for _ in range(len(texts))]
-        num_tokens_in_batch: List[List[int]] = [[] for _ in range(len(texts))]
-        for i in range(len(indices)):
-            results[indices[i]].append(batched_embeddings[i])
-            num_tokens_in_batch[indices[i]].append(len(tokens[i]))
+        embeddings = _process_batched_chunked_embeddings(
+            len(texts), tokens, batched_embeddings, indices, self.skip_empty
+        )
+        _cached_empty_embedding: Optional[List[float]] = None
 
-        embeddings: List[List[float]] = [[] for _ in range(len(texts))]
-        for i in range(len(texts)):
-            _result = results[i]
-            if len(_result) == 0:
+        async def empty_embedding() -> List[float]:
+            nonlocal _cached_empty_embedding
+            if _cached_empty_embedding is None:
                 average_embedded = await self.async_client.create(
                     input="", **self._invocation_params
                 )
                 if not isinstance(average_embedded, dict):
                     average_embedded = average_embedded.model_dump()
-                average = average_embedded["data"][0]["embedding"]
-            else:
-                # should be same as
-                # average = np.average(_result, axis=0, weights=num_tokens_in_batch[i])
-                total_weight = sum(num_tokens_in_batch[i])
-                average = [
-                    sum(
-                        val * weight
-                        for val, weight in zip(embedding, num_tokens_in_batch[i])
-                    )
-                    / total_weight
-                    for embedding in zip(*_result)
-                ]
-                # should be same as
-                # embeddings[i] = (average / np.linalg.norm(average)).tolist()
-                magnitude = sum(val**2 for val in average) ** 0.5
-                embeddings[i] = [val / magnitude for val in average]
+                _cached_empty_embedding = average_embedded["data"][0]["embedding"]
+            return _cached_empty_embedding
 
-        return embeddings
+        return [e if e is not None else await empty_embedding() for e in embeddings]
 
     def embed_documents(
         self, texts: List[str], chunk_size: Optional[int] = 0
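Reviewer note: both rewritten methods now fetch the empty-string embedding lazily and at most once per call, via the `nonlocal` cache in the `empty_embedding` closures. A standalone sketch (not part of the patch; `fake_create` and `fill_empties` are invented stand-ins for `self.client.create` and the surrounding method) showing that memoization behavior:

from typing import List, Optional

calls = 0

def fake_create() -> dict:
    # invented stand-in for self.client.create(input="", **params)
    global calls
    calls += 1
    return {"data": [{"embedding": [0.6, 0.8]}]}

def fill_empties(embeddings: List[Optional[List[float]]]) -> List[List[float]]:
    _cached_empty_embedding: Optional[List[float]] = None

    def empty_embedding() -> List[float]:
        nonlocal _cached_empty_embedding
        if _cached_empty_embedding is None:
            _cached_empty_embedding = fake_create()["data"][0]["embedding"]
        return _cached_empty_embedding

    return [e if e is not None else empty_embedding() for e in embeddings]

result = fill_empties([[1.0, 0.0], None, None])
assert result[1] == result[2] == [0.6, 0.8]
assert calls == 1  # the empty-string embedding was requested only once

Texts with no non-empty chunks cost zero extra API requests when they don't occur, and exactly one request total when they do, instead of one request per empty text as before.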