Add llama2-13b-chat-v1 support to chat_models.BedrockChat (#13403)

Hi 👋 We are working with Llama 2 on Bedrock and would like to add it to
LangChain. We saw a [pull
request](https://github.com/langchain-ai/langchain/pull/13322) that adds it
to the `llm.Bedrock` class, but since Llama 2 is a chat model, we would
like to support it in `BedrockChat` as well.

- **Description:** Add support for Llama 2 to `BedrockChat` in
`chat_models` (a usage sketch follows this list)
- **Issue:** [#13316](https://github.com/langchain-ai/langchain/issues/13316)
- **Dependencies:** None
- **Twitter handle:** `@SimonBockaert` `@WouterDurnez`
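
A minimal usage sketch for reviewers (the model id and `model_kwargs` below are illustrative assumptions; Bedrock access via standard `boto3` credentials is presumed):

```python
from langchain.chat_models import BedrockChat
from langchain.schema.messages import HumanMessage, SystemMessage

# Assumed Bedrock model id for Llama 2 13B chat; region/credentials come
# from the standard AWS configuration.
chat = BedrockChat(
    model_id="meta.llama2-13b-chat-v1",
    model_kwargs={"temperature": 0.1},
)
reply = chat([SystemMessage(content="You're an assistant"), HumanMessage(content="Hello")])
print(reply.content)
```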

---------

Co-authored-by: wouter.durnez <wouter.durnez@showpad.com>
Co-authored-by: Simon Bockaert <simon.bockaert@showpad.com>
Commit ef7802b325 (parent a93616e972) by Wouter Durnez, 2023-11-20 03:44:58 +01:00, committed by GitHub
4 changed files with 66 additions and 1 deletion

langchain/chat_models/bedrock.py

@@ -5,6 +5,7 @@ from langchain.callbacks.manager import (
)
from langchain.chat_models.anthropic import convert_messages_to_prompt_anthropic
from langchain.chat_models.base import BaseChatModel
from langchain.chat_models.meta import convert_messages_to_prompt_llama
from langchain.llms.bedrock import BedrockBase
from langchain.pydantic_v1 import Extra
from langchain.schema.messages import AIMessage, AIMessageChunk, BaseMessage
@@ -26,6 +27,8 @@ class ChatPromptAdapter:
    ) -> str:
        if provider == "anthropic":
            prompt = convert_messages_to_prompt_anthropic(messages=messages)
        elif provider == "meta":
            prompt = convert_messages_to_prompt_llama(messages=messages)
        else:
            raise NotImplementedError(
                f"Provider {provider} model does not support chat."

langchain/chat_models/meta.py (new file)

@@ -0,0 +1,29 @@
from typing import List

from langchain.schema.messages import (
    AIMessage,
    BaseMessage,
    ChatMessage,
    HumanMessage,
    SystemMessage,
)


def _convert_one_message_to_text_llama(message: BaseMessage) -> str:
    if isinstance(message, ChatMessage):
        message_text = f"\n\n{message.role.capitalize()}: {message.content}"
    elif isinstance(message, HumanMessage):
        message_text = f"[INST] {message.content} [/INST]"
    elif isinstance(message, AIMessage):
        message_text = f"{message.content}"
    elif isinstance(message, SystemMessage):
        message_text = f"<<SYS>> {message.content} <</SYS>>"
    else:
        raise ValueError(f"Got unknown type {message}")
    return message_text


def convert_messages_to_prompt_llama(messages: List[BaseMessage]) -> str:
    """Convert a list of messages to a Llama 2 chat prompt."""
    return "\n".join(
        [_convert_one_message_to_text_llama(message) for message in messages]
    )
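
To make the new prompt format concrete, here is what the helper above yields for a short exchange (REPL-style sketch; output shown in comments):

```python
from langchain.chat_models.meta import convert_messages_to_prompt_llama
from langchain.schema.messages import AIMessage, HumanMessage, SystemMessage

prompt = convert_messages_to_prompt_llama(
    [
        SystemMessage(content="You're an assistant"),
        HumanMessage(content="Hello"),
        AIMessage(content="Hi there!"),
    ]
)
print(prompt)
# <<SYS>> You're an assistant <</SYS>>
# [INST] Hello [/INST]
# Hi there!
```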

langchain/llms/bedrock.py

@@ -72,6 +72,7 @@ class LLMInputOutputAdapter:
        "anthropic": "completion",
        "amazon": "outputText",
        "cohere": "text",
        "meta": "generation",
    }

    @classmethod
@@ -81,7 +82,7 @@ class LLMInputOutputAdapter:
        input_body = {**model_kwargs}
        if provider == "anthropic":
            input_body["prompt"] = _human_assistant_format(prompt)
-       elif provider == "ai21" or provider == "cohere":
+       elif provider in ("ai21", "cohere", "meta"):
            input_body["prompt"] = prompt
        elif provider == "amazon":
            input_body = dict()
@@ -107,6 +108,8 @@ class LLMInputOutputAdapter:
            return response_body.get("completions")[0].get("data").get("text")
        elif provider == "cohere":
            return response_body.get("generations")[0].get("text")
        elif provider == "meta":
            return response_body.get("generation")
        else:
            return response_body.get("results")[0].get("outputText")

unit tests (new file)

@@ -0,0 +1,30 @@
"""Test Llama 2 chat prompt formatting."""
from typing import List

import pytest

from langchain.chat_models.meta import convert_messages_to_prompt_llama
from langchain.schema import AIMessage, BaseMessage, HumanMessage, SystemMessage


@pytest.mark.parametrize(
    ("messages", "expected"),
    [
        ([HumanMessage(content="Hello")], "[INST] Hello [/INST]"),
        (
            [HumanMessage(content="Hello"), AIMessage(content="Answer:")],
            "[INST] Hello [/INST]\nAnswer:",
        ),
        (
            [
                SystemMessage(content="You're an assistant"),
                HumanMessage(content="Hello"),
                AIMessage(content="Answer:"),
            ],
            "<<SYS>> You're an assistant <</SYS>>\n[INST] Hello [/INST]\nAnswer:",
        ),
    ],
)
def test_formatting(messages: List[BaseMessage], expected: str) -> None:
    result = convert_messages_to_prompt_llama(messages)
    assert result == expected