Bedrock embeddings async methods (#9024)

## Description
This PR adds the `aembed_query` and `aembed_documents` async methods for
improving the embeddings generation for large documents. The
implementation uses asyncio tasks and gather to achieve concurrency as
there is no bedrock async API in boto3.

### Maintainers
@agola11 
@aarora79  

### Open questions
To avoid throttling from the Bedrock API, should there be an option to
limit the concurrency of the calls?
This commit is contained in:
Piyush Jain 2023-08-10 14:21:03 -07:00 committed by GitHub
parent 67ca187560
commit 8eea46ed0e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 60 additions and 15 deletions

View File

@ -31,12 +31,11 @@
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"from langchain.llms.bedrock import Bedrock\n", "from langchain.llms import Bedrock\n",
"\n", "\n",
"llm = Bedrock(\n", "llm = Bedrock(\n",
" credentials_profile_name=\"bedrock-admin\",\n", " credentials_profile_name=\"bedrock-admin\",\n",
" model_id=\"amazon.titan-tg1-large\",\n", " model_id=\"amazon.titan-tg1-large\"\n",
" endpoint_url=\"custom_endpoint_url\",\n",
")" ")"
] ]
}, },

View File

@ -20,7 +20,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 5,
"id": "282239c8-e03a-4abc-86c1-ca6120231a20", "id": "282239c8-e03a-4abc-86c1-ca6120231a20",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@ -28,7 +28,7 @@
"from langchain.embeddings import BedrockEmbeddings\n", "from langchain.embeddings import BedrockEmbeddings\n",
"\n", "\n",
"embeddings = BedrockEmbeddings(\n", "embeddings = BedrockEmbeddings(\n",
" credentials_profile_name=\"bedrock-admin\", endpoint_url=\"custom_endpoint_url\"\n", " credentials_profile_name=\"bedrock-admin\", region_name=\"us-east-1\"\n",
")" ")"
] ]
}, },
@ -49,7 +49,29 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"embeddings.embed_documents([\"This is a content of the document\"])" "embeddings.embed_documents([\"This is a content of the document\", \"This is another document\"])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9f6b364d",
"metadata": {},
"outputs": [],
"source": [
"# async embed query\n",
"await embeddings.aembed_query(\"This is a content of the document\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c9240a5a",
"metadata": {},
"outputs": [],
"source": [
"# async embed documents\n",
"await embeddings.aembed_documents([\"This is a content of the document\", \"This is another document\"])"
] ]
} }
], ],
@ -69,7 +91,7 @@
"name": "python", "name": "python",
"nbconvert_exporter": "python", "nbconvert_exporter": "python",
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.10.11" "version": "3.9.13"
} }
}, },
"nbformat": 4, "nbformat": 4,

View File

@ -1,5 +1,7 @@
import asyncio
import json import json
import os import os
from functools import partial
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Extra, root_validator from pydantic import BaseModel, Extra, root_validator
@ -128,17 +130,11 @@ class BedrockEmbeddings(BaseModel, Embeddings):
except Exception as e: except Exception as e:
raise ValueError(f"Error raised by inference endpoint: {e}") raise ValueError(f"Error raised by inference endpoint: {e}")
def embed_documents( def embed_documents(self, texts: List[str]) -> List[List[float]]:
self, texts: List[str], chunk_size: int = 1
) -> List[List[float]]:
"""Compute doc embeddings using a Bedrock model. """Compute doc embeddings using a Bedrock model.
Args: Args:
texts: The list of texts to embed. texts: The list of texts to embed
chunk_size: Bedrock currently only allows single string
inputs, so chunk size is always 1. This input is here
only for compatibility with the embeddings interface.
Returns: Returns:
List of embeddings, one for each text. List of embeddings, one for each text.
@ -159,3 +155,31 @@ class BedrockEmbeddings(BaseModel, Embeddings):
Embeddings for the text. Embeddings for the text.
""" """
return self._embedding_func(text) return self._embedding_func(text)
async def aembed_query(self, text: str) -> List[float]:
"""Asynchronous compute query embeddings using a Bedrock model.
Args:
text: The text to embed.
Returns:
Embeddings for the text.
"""
return await asyncio.get_running_loop().run_in_executor(
None, partial(self.embed_query, text)
)
async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
"""Asynchronous compute doc embeddings using a Bedrock model.
Args:
texts: The list of texts to embed
Returns:
List of embeddings, one for each text.
"""
result = await asyncio.gather(*[self.aembed_query(text) for text in texts])
return list(result)