diff --git a/libs/langchain/langchain/chat_models/bedrock.py b/libs/langchain/langchain/chat_models/bedrock.py
index 642ef76cdf4..34bb97f0f79 100644
--- a/libs/langchain/langchain/chat_models/bedrock.py
+++ b/libs/langchain/langchain/chat_models/bedrock.py
@@ -5,6 +5,7 @@ from langchain.callbacks.manager import (
)
from langchain.chat_models.anthropic import convert_messages_to_prompt_anthropic
from langchain.chat_models.base import BaseChatModel
+from langchain.chat_models.meta import convert_messages_to_prompt_llama
from langchain.llms.bedrock import BedrockBase
from langchain.pydantic_v1 import Extra
from langchain.schema.messages import AIMessage, AIMessageChunk, BaseMessage
@@ -26,6 +27,8 @@ class ChatPromptAdapter:
) -> str:
if provider == "anthropic":
prompt = convert_messages_to_prompt_anthropic(messages=messages)
+ elif provider == "meta":
+ prompt = convert_messages_to_prompt_llama(messages=messages)
else:
raise NotImplementedError(
f"Provider {provider} model does not support chat."
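Sketch (not part of the patch): with the branch above made an elif, ChatPromptAdapter routes "meta" to the new Llama converter while unrecognized providers still raise. Assuming the enclosing classmethod is convert_messages_to_prompt (its name sits outside the hunk context); the message is illustrative.

from langchain.chat_models.bedrock import ChatPromptAdapter
from langchain.schema.messages import HumanMessage

# "meta" now dispatches to convert_messages_to_prompt_llama.
prompt = ChatPromptAdapter.convert_messages_to_prompt(
    provider="meta", messages=[HumanMessage(content="Hello")]
)
assert prompt == "[INST] Hello [/INST]"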
diff --git a/libs/langchain/langchain/chat_models/meta.py b/libs/langchain/langchain/chat_models/meta.py
new file mode 100644
index 00000000000..c087ee2b1d2
--- /dev/null
+++ b/libs/langchain/langchain/chat_models/meta.py
@@ -0,0 +1,29 @@
+from typing import List
+
+from langchain.schema.messages import (
+ AIMessage,
+ BaseMessage,
+ ChatMessage,
+ HumanMessage,
+ SystemMessage,
+)
+
+
+def _convert_one_message_to_text_llama(message: BaseMessage) -> str:
+ if isinstance(message, ChatMessage):
+ message_text = f"\n\n{message.role.capitalize()}: {message.content}"
+ elif isinstance(message, HumanMessage):
+ message_text = f"[INST] {message.content} [/INST]"
+ elif isinstance(message, AIMessage):
+ message_text = f"{message.content}"
+ elif isinstance(message, SystemMessage):
+ message_text = f"<<SYS>> {message.content} <</SYS>>"
+ else:
+ raise ValueError(f"Got unknown type {message}")
+ return message_text
+
+
+def convert_messages_to_prompt_llama(messages: List[BaseMessage]) -> str:
+ return "\n".join(
+ [_convert_one_message_to_text_llama(message) for message in messages]
+ )
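For illustration, a minimal run of the new module (mirroring the unit test added below):

from langchain.chat_models.meta import convert_messages_to_prompt_llama
from langchain.schema.messages import AIMessage, HumanMessage, SystemMessage

messages = [
    SystemMessage(content="You're an assistant"),
    HumanMessage(content="Hello"),
    AIMessage(content="Answer:"),
]
# Each message is wrapped in its Llama 2 chat markers, then newline-joined:
print(convert_messages_to_prompt_llama(messages))
# <<SYS>> You're an assistant <</SYS>>
# [INST] Hello [/INST]
# Answer: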
diff --git a/libs/langchain/langchain/llms/bedrock.py b/libs/langchain/langchain/llms/bedrock.py
index a7fe68b0421..e21bd4d088b 100644
--- a/libs/langchain/langchain/llms/bedrock.py
+++ b/libs/langchain/langchain/llms/bedrock.py
@@ -72,6 +72,7 @@ class LLMInputOutputAdapter:
"anthropic": "completion",
"amazon": "outputText",
"cohere": "text",
+ "meta": "generation",
}

@classmethod
@@ -81,7 +82,7 @@ class LLMInputOutputAdapter:
input_body = {**model_kwargs}
if provider == "anthropic":
input_body["prompt"] = _human_assistant_format(prompt)
- elif provider == "ai21" or provider == "cohere":
+ elif provider in ("ai21", "cohere", "meta"):
input_body["prompt"] = prompt
elif provider == "amazon":
input_body = dict()
@@ -107,6 +108,8 @@ class LLMInputOutputAdapter:
return response_body.get("completions")[0].get("data").get("text")
elif provider == "cohere":
return response_body.get("generations")[0].get("text")
+ elif provider == "meta":
+ return response_body.get("generation")
else:
return response_body.get("results")[0].get("outputText")
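A hedged sketch of the request/response shapes these hunks handle, assuming the enclosing classmethod is LLMInputOutputAdapter.prepare_input (its name sits outside the hunk context) and using an illustrative response body:

from langchain.llms.bedrock import LLMInputOutputAdapter

# Request: "meta" now passes the prompt through unchanged, like ai21 and cohere.
body = LLMInputOutputAdapter.prepare_input(
    provider="meta",
    prompt="[INST] Hello [/INST]",
    model_kwargs={"temperature": 0.5},
)
assert body == {"temperature": 0.5, "prompt": "[INST] Hello [/INST]"}

# Response: Llama models report their completion under the "generation" key.
response_body = {"generation": "Hi there!"}
assert response_body.get("generation") == "Hi there!"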
diff --git a/libs/langchain/tests/unit_tests/chat_models/test_bedrock.py b/libs/langchain/tests/unit_tests/chat_models/test_bedrock.py
new file mode 100644
index 00000000000..e92caf5c816
--- /dev/null
+++ b/libs/langchain/tests/unit_tests/chat_models/test_bedrock.py
@@ -0,0 +1,30 @@
+"""Test Bedrock chat model prompt formatting."""
+from typing import List
+
+import pytest
+
+from langchain.chat_models.meta import convert_messages_to_prompt_llama
+from langchain.schema import AIMessage, BaseMessage, HumanMessage, SystemMessage
+
+
+@pytest.mark.parametrize(
+ ("messages", "expected"),
+ [
+ ([HumanMessage(content="Hello")], "[INST] Hello [/INST]"),
+ (
+ [HumanMessage(content="Hello"), AIMessage(content="Answer:")],
+ "[INST] Hello [/INST]\nAnswer:",
+ ),
+ (
+ [
+ SystemMessage(content="You're an assistant"),
+ HumanMessage(content="Hello"),
+ AIMessage(content="Answer:"),
+ ],
+ "<<SYS>> You're an assistant <</SYS>>\n[INST] Hello [/INST]\nAnswer:",
+ ),
+ ],
+)
+def test_formatting(messages: List[BaseMessage], expected: str) -> None:
+ result = convert_messages_to_prompt_llama(messages)
+ assert result == expected
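Taken together, the patch lets Bedrock-hosted Llama 2 chat models be driven through the chat interface. A usage sketch, assuming AWS credentials with Bedrock access are configured; the model id is one of Bedrock's Llama 2 chat ids and is illustrative here:

from langchain.chat_models import BedrockChat
from langchain.schema.messages import HumanMessage

chat = BedrockChat(model_id="meta.llama2-13b-chat-v1")
response = chat([HumanMessage(content="Hello")])
print(response.content)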