Bedrock embeddings async methods (#9024)

## Description
This PR adds the `aembed_query` and `aembed_documents` async methods for
improving the embeddings generation for large documents. The
implementation uses asyncio tasks and gather to achieve concurrency as
there is no bedrock async API in boto3.

### Maintainers
@agola11 
@aarora79  

### Open questions
To avoid throttling from the Bedrock API, should there be an option to
limit the concurrency of the calls?
This commit is contained in:
Piyush Jain
2023-08-10 14:21:03 -07:00
committed by GitHub
parent 67ca187560
commit 8eea46ed0e
3 changed files with 60 additions and 15 deletions

View File

@@ -1,5 +1,7 @@
import asyncio
import json
import os
from functools import partial
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Extra, root_validator
@@ -128,17 +130,11 @@ class BedrockEmbeddings(BaseModel, Embeddings):
except Exception as e:
raise ValueError(f"Error raised by inference endpoint: {e}")
def embed_documents(
self, texts: List[str], chunk_size: int = 1
) -> List[List[float]]:
def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""Compute doc embeddings using a Bedrock model.
Args:
texts: The list of texts to embed.
chunk_size: Bedrock currently only allows single string
inputs, so chunk size is always 1. This input is here
only for compatibility with the embeddings interface.
texts: The list of texts to embed
Returns:
List of embeddings, one for each text.
@@ -159,3 +155,31 @@ class BedrockEmbeddings(BaseModel, Embeddings):
Embeddings for the text.
"""
return self._embedding_func(text)
async def aembed_query(self, text: str) -> List[float]:
"""Asynchronous compute query embeddings using a Bedrock model.
Args:
text: The text to embed.
Returns:
Embeddings for the text.
"""
return await asyncio.get_running_loop().run_in_executor(
None, partial(self.embed_query, text)
)
async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
"""Asynchronous compute doc embeddings using a Bedrock model.
Args:
texts: The list of texts to embed
Returns:
List of embeddings, one for each text.
"""
result = await asyncio.gather(*[self.aembed_query(text) for text in texts])
return list(result)