diff --git a/libs/langchain/langchain/chat_models/bedrock.py b/libs/langchain/langchain/chat_models/bedrock.py index 642ef76cdf4..34bb97f0f79 100644 --- a/libs/langchain/langchain/chat_models/bedrock.py +++ b/libs/langchain/langchain/chat_models/bedrock.py @@ -5,6 +5,7 @@ from langchain.callbacks.manager import ( ) from langchain.chat_models.anthropic import convert_messages_to_prompt_anthropic from langchain.chat_models.base import BaseChatModel +from langchain.chat_models.meta import convert_messages_to_prompt_llama from langchain.llms.bedrock import BedrockBase from langchain.pydantic_v1 import Extra from langchain.schema.messages import AIMessage, AIMessageChunk, BaseMessage @@ -26,6 +27,8 @@ class ChatPromptAdapter: ) -> str: if provider == "anthropic": prompt = convert_messages_to_prompt_anthropic(messages=messages) + if provider == "meta": + prompt = convert_messages_to_prompt_llama(messages=messages) else: raise NotImplementedError( f"Provider {provider} model does not support chat." diff --git a/libs/langchain/langchain/chat_models/meta.py b/libs/langchain/langchain/chat_models/meta.py new file mode 100644 index 00000000000..c087ee2b1d2 --- /dev/null +++ b/libs/langchain/langchain/chat_models/meta.py @@ -0,0 +1,29 @@ +from typing import List + +from langchain.schema.messages import ( + AIMessage, + BaseMessage, + ChatMessage, + HumanMessage, + SystemMessage, +) + + +def _convert_one_message_to_text_llama(message: BaseMessage) -> str: + if isinstance(message, ChatMessage): + message_text = f"\n\n{message.role.capitalize()}: {message.content}" + elif isinstance(message, HumanMessage): + message_text = f"[INST] {message.content} [/INST]" + elif isinstance(message, AIMessage): + message_text = f"{message.content}" + elif isinstance(message, SystemMessage): + message_text = f"<> {message.content} <>" + else: + raise ValueError(f"Got unknown type {message}") + return message_text + + +def convert_messages_to_prompt_llama(messages: List[BaseMessage]) -> str: + return "\n".join( + [_convert_one_message_to_text_llama(message) for message in messages] + ) diff --git a/libs/langchain/langchain/llms/bedrock.py b/libs/langchain/langchain/llms/bedrock.py index a7fe68b0421..e21bd4d088b 100644 --- a/libs/langchain/langchain/llms/bedrock.py +++ b/libs/langchain/langchain/llms/bedrock.py @@ -72,6 +72,7 @@ class LLMInputOutputAdapter: "anthropic": "completion", "amazon": "outputText", "cohere": "text", + "meta": "generation", } @classmethod @@ -81,7 +82,7 @@ class LLMInputOutputAdapter: input_body = {**model_kwargs} if provider == "anthropic": input_body["prompt"] = _human_assistant_format(prompt) - elif provider == "ai21" or provider == "cohere": + elif provider in ("ai21", "cohere", "meta"): input_body["prompt"] = prompt elif provider == "amazon": input_body = dict() @@ -107,6 +108,8 @@ class LLMInputOutputAdapter: return response_body.get("completions")[0].get("data").get("text") elif provider == "cohere": return response_body.get("generations")[0].get("text") + elif provider == "meta": + return response_body.get("generation") else: return response_body.get("results")[0].get("outputText") diff --git a/libs/langchain/tests/unit_tests/chat_models/test_bedrock.py b/libs/langchain/tests/unit_tests/chat_models/test_bedrock.py new file mode 100644 index 00000000000..e92caf5c816 --- /dev/null +++ b/libs/langchain/tests/unit_tests/chat_models/test_bedrock.py @@ -0,0 +1,30 @@ +"""Test Anthropic Chat API wrapper.""" +from typing import List + +import pytest + +from langchain.chat_models.meta import convert_messages_to_prompt_llama +from langchain.schema import AIMessage, BaseMessage, HumanMessage, SystemMessage + + +@pytest.mark.parametrize( + ("messages", "expected"), + [ + ([HumanMessage(content="Hello")], "[INST] Hello [/INST]"), + ( + [HumanMessage(content="Hello"), AIMessage(content="Answer:")], + "[INST] Hello [/INST]\nAnswer:", + ), + ( + [ + SystemMessage(content="You're an assistant"), + HumanMessage(content="Hello"), + AIMessage(content="Answer:"), + ], + "<> You're an assistant <>\n[INST] Hello [/INST]\nAnswer:", + ), + ], +) +def test_formatting(messages: List[BaseMessage], expected: str) -> None: + result = convert_messages_to_prompt_llama(messages) + assert result == expected