diff --git a/libs/langchain/langchain/embeddings/bedrock.py b/libs/langchain/langchain/embeddings/bedrock.py index b65200901bb..55064f15544 100644 --- a/libs/langchain/langchain/embeddings/bedrock.py +++ b/libs/langchain/langchain/embeddings/bedrock.py @@ -112,20 +112,36 @@ class BedrockEmbeddings(BaseModel, Embeddings): """Call out to Bedrock embedding endpoint.""" # replace newlines, which can negatively affect performance. text = text.replace(os.linesep, " ") - _model_kwargs = self.model_kwargs or {} - input_body = {**_model_kwargs, "inputText": text} + # format input body for provider + provider = self.model_id.split(".")[0] + _model_kwargs = self.model_kwargs or {} + input_body = {**_model_kwargs} + if provider == "cohere": + if "input_type" not in input_body.keys(): + input_body["input_type"] = "search_document" + input_body["texts"] = [text] + else: + # includes common provider == "amazon" + input_body["inputText"] = text body = json.dumps(input_body) try: + # invoke bedrock API response = self.client.invoke_model( body=body, modelId=self.model_id, accept="application/json", contentType="application/json", ) + + # format output based on provider response_body = json.loads(response.get("body").read()) - return response_body.get("embedding") + if provider == "cohere": + return response_body.get("embeddings")[0] + else: + # includes common provider == "amazon" + return response_body.get("embedding") except Exception as e: raise ValueError(f"Error raised by inference endpoint: {e}")