diff --git a/docs/extras/integrations/llms/bedrock.ipynb b/docs/extras/integrations/llms/bedrock.ipynb index 56847a00fd4..06ae9f4cee1 100644 --- a/docs/extras/integrations/llms/bedrock.ipynb +++ b/docs/extras/integrations/llms/bedrock.ipynb @@ -31,12 +31,11 @@ }, "outputs": [], "source": [ - "from langchain.llms.bedrock import Bedrock\n", + "from langchain.llms import Bedrock\n", "\n", "llm = Bedrock(\n", " credentials_profile_name=\"bedrock-admin\",\n", - " model_id=\"amazon.titan-tg1-large\",\n", - " endpoint_url=\"custom_endpoint_url\",\n", + " model_id=\"amazon.titan-tg1-large\"\n", ")" ] }, diff --git a/docs/extras/integrations/text_embedding/bedrock.ipynb b/docs/extras/integrations/text_embedding/bedrock.ipynb index a69c99de1dd..7c16cb8ead4 100644 --- a/docs/extras/integrations/text_embedding/bedrock.ipynb +++ b/docs/extras/integrations/text_embedding/bedrock.ipynb @@ -20,7 +20,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 5, "id": "282239c8-e03a-4abc-86c1-ca6120231a20", "metadata": {}, "outputs": [], @@ -28,7 +28,7 @@ "from langchain.embeddings import BedrockEmbeddings\n", "\n", "embeddings = BedrockEmbeddings(\n", - " credentials_profile_name=\"bedrock-admin\", endpoint_url=\"custom_endpoint_url\"\n", + " credentials_profile_name=\"bedrock-admin\", region_name=\"us-east-1\"\n", ")" ] }, @@ -49,7 +49,29 @@ "metadata": {}, "outputs": [], "source": [ - "embeddings.embed_documents([\"This is a content of the document\"])" + "embeddings.embed_documents([\"This is a content of the document\", \"This is another document\"])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9f6b364d", + "metadata": {}, + "outputs": [], + "source": [ + "# async embed query\n", + "await embeddings.aembed_query(\"This is a content of the document\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c9240a5a", + "metadata": {}, + "outputs": [], + "source": [ + "# async embed documents\n", + "await embeddings.aembed_documents([\"This is a content of the document\", \"This is another document\"])" ] } ], @@ -69,7 +91,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.11" + "version": "3.9.13" } }, "nbformat": 4, diff --git a/libs/langchain/langchain/embeddings/bedrock.py b/libs/langchain/langchain/embeddings/bedrock.py index 93b3d69bbef..9396b320c4b 100644 --- a/libs/langchain/langchain/embeddings/bedrock.py +++ b/libs/langchain/langchain/embeddings/bedrock.py @@ -1,5 +1,7 @@ +import asyncio import json import os +from functools import partial from typing import Any, Dict, List, Optional from pydantic import BaseModel, Extra, root_validator @@ -128,17 +130,11 @@ class BedrockEmbeddings(BaseModel, Embeddings): except Exception as e: raise ValueError(f"Error raised by inference endpoint: {e}") - def embed_documents( - self, texts: List[str], chunk_size: int = 1 - ) -> List[List[float]]: + def embed_documents(self, texts: List[str]) -> List[List[float]]: """Compute doc embeddings using a Bedrock model. Args: - texts: The list of texts to embed. - chunk_size: Bedrock currently only allows single string - inputs, so chunk size is always 1. This input is here - only for compatibility with the embeddings interface. - + texts: The list of texts to embed Returns: List of embeddings, one for each text. @@ -159,3 +155,31 @@ class BedrockEmbeddings(BaseModel, Embeddings): Embeddings for the text. """ return self._embedding_func(text) + + async def aembed_query(self, text: str) -> List[float]: + """Asynchronous compute query embeddings using a Bedrock model. + + Args: + text: The text to embed. + + Returns: + Embeddings for the text. + """ + + return await asyncio.get_running_loop().run_in_executor( + None, partial(self.embed_query, text) + ) + + async def aembed_documents(self, texts: List[str]) -> List[List[float]]: + """Asynchronous compute doc embeddings using a Bedrock model. + + Args: + texts: The list of texts to embed + + Returns: + List of embeddings, one for each text. + """ + + result = await asyncio.gather(*[self.aembed_query(text) for text in texts]) + + return list(result)