mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-09 13:00:34 +00:00
Bedrock embeddings async methods (#9024)
## Description This PR adds the `aembed_query` and `aembed_documents` async methods to improve embedding generation for large documents. Because boto3 provides no async Bedrock API, the implementation uses asyncio tasks and `asyncio.gather` to achieve concurrency. ### Maintainers @agola11 @aarora79 ### Open questions To avoid throttling by the Bedrock API, should there be an option to limit the concurrency of the calls?
This commit is contained in:
parent
67ca187560
commit
8eea46ed0e
@ -31,12 +31,11 @@
|
|||||||
},
|
},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"from langchain.llms.bedrock import Bedrock\n",
|
"from langchain.llms import Bedrock\n",
|
||||||
"\n",
|
"\n",
|
||||||
"llm = Bedrock(\n",
|
"llm = Bedrock(\n",
|
||||||
" credentials_profile_name=\"bedrock-admin\",\n",
|
" credentials_profile_name=\"bedrock-admin\",\n",
|
||||||
" model_id=\"amazon.titan-tg1-large\",\n",
|
" model_id=\"amazon.titan-tg1-large\"\n",
|
||||||
" endpoint_url=\"custom_endpoint_url\",\n",
|
|
||||||
")"
|
")"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
@ -20,7 +20,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": 5,
|
||||||
"id": "282239c8-e03a-4abc-86c1-ca6120231a20",
|
"id": "282239c8-e03a-4abc-86c1-ca6120231a20",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
@ -28,7 +28,7 @@
|
|||||||
"from langchain.embeddings import BedrockEmbeddings\n",
|
"from langchain.embeddings import BedrockEmbeddings\n",
|
||||||
"\n",
|
"\n",
|
||||||
"embeddings = BedrockEmbeddings(\n",
|
"embeddings = BedrockEmbeddings(\n",
|
||||||
" credentials_profile_name=\"bedrock-admin\", endpoint_url=\"custom_endpoint_url\"\n",
|
" credentials_profile_name=\"bedrock-admin\", region_name=\"us-east-1\"\n",
|
||||||
")"
|
")"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
@ -49,7 +49,29 @@
|
|||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"embeddings.embed_documents([\"This is a content of the document\"])"
|
"embeddings.embed_documents([\"This is a content of the document\", \"This is another document\"])"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "9f6b364d",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# async embed query\n",
|
||||||
|
"await embeddings.aembed_query(\"This is a content of the document\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "c9240a5a",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# async embed documents\n",
|
||||||
|
"await embeddings.aembed_documents([\"This is a content of the document\", \"This is another document\"])"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
@ -69,7 +91,7 @@
|
|||||||
"name": "python",
|
"name": "python",
|
||||||
"nbconvert_exporter": "python",
|
"nbconvert_exporter": "python",
|
||||||
"pygments_lexer": "ipython3",
|
"pygments_lexer": "ipython3",
|
||||||
"version": "3.10.11"
|
"version": "3.9.13"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"nbformat": 4,
|
"nbformat": 4,
|
||||||
|
@ -1,5 +1,7 @@
|
|||||||
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
from functools import partial
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
from pydantic import BaseModel, Extra, root_validator
|
from pydantic import BaseModel, Extra, root_validator
|
||||||
@ -128,17 +130,11 @@ class BedrockEmbeddings(BaseModel, Embeddings):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise ValueError(f"Error raised by inference endpoint: {e}")
|
raise ValueError(f"Error raised by inference endpoint: {e}")
|
||||||
|
|
||||||
def embed_documents(
|
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||||
self, texts: List[str], chunk_size: int = 1
|
|
||||||
) -> List[List[float]]:
|
|
||||||
"""Compute doc embeddings using a Bedrock model.
|
"""Compute doc embeddings using a Bedrock model.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
texts: The list of texts to embed.
|
texts: The list of texts to embed
|
||||||
chunk_size: Bedrock currently only allows single string
|
|
||||||
inputs, so chunk size is always 1. This input is here
|
|
||||||
only for compatibility with the embeddings interface.
|
|
||||||
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List of embeddings, one for each text.
|
List of embeddings, one for each text.
|
||||||
@ -159,3 +155,31 @@ class BedrockEmbeddings(BaseModel, Embeddings):
|
|||||||
Embeddings for the text.
|
Embeddings for the text.
|
||||||
"""
|
"""
|
||||||
return self._embedding_func(text)
|
return self._embedding_func(text)
|
||||||
|
|
||||||
|
async def aembed_query(self, text: str) -> List[float]:
|
||||||
|
"""Asynchronous compute query embeddings using a Bedrock model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: The text to embed.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Embeddings for the text.
|
||||||
|
"""
|
||||||
|
|
||||||
|
return await asyncio.get_running_loop().run_in_executor(
|
||||||
|
None, partial(self.embed_query, text)
|
||||||
|
)
|
||||||
|
|
||||||
|
async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||||
|
"""Asynchronous compute doc embeddings using a Bedrock model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
texts: The list of texts to embed
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of embeddings, one for each text.
|
||||||
|
"""
|
||||||
|
|
||||||
|
result = await asyncio.gather(*[self.aembed_query(text) for text in texts])
|
||||||
|
|
||||||
|
return list(result)
|
||||||
|
Loading…
Reference in New Issue
Block a user