mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-07 21:50:25 +00:00
Add llama2-13b-chat-v1 support to chat_models.BedrockChat
(#13403)
Hi 👋 We are working with Llama2 on Bedrock, and would like to add it to Langchain. We saw a [pull request](https://github.com/langchain-ai/langchain/pull/13322) to add it to the `llm.Bedrock` class, but since it concerns a chat model, we would like to add it to `BedrockChat` as well. - **Description:** Add support for Llama2 to `BedrockChat` in `chat_models` - **Issue:** the issue # it fixes (if applicable) [#13316](https://github.com/langchain-ai/langchain/issues/13316) - **Dependencies:** any dependencies required for this change `None` - **Tag maintainer:** / - **Twitter handle:** `@SimonBockaert @WouterDurnez` --------- Co-authored-by: wouter.durnez <wouter.durnez@showpad.com> Co-authored-by: Simon Bockaert <simon.bockaert@showpad.com>
This commit is contained in:
parent
a93616e972
commit
ef7802b325
@ -5,6 +5,7 @@ from langchain.callbacks.manager import (
|
||||
)
|
||||
from langchain.chat_models.anthropic import convert_messages_to_prompt_anthropic
|
||||
from langchain.chat_models.base import BaseChatModel
|
||||
from langchain.chat_models.meta import convert_messages_to_prompt_llama
|
||||
from langchain.llms.bedrock import BedrockBase
|
||||
from langchain.pydantic_v1 import Extra
|
||||
from langchain.schema.messages import AIMessage, AIMessageChunk, BaseMessage
|
||||
@ -26,6 +27,8 @@ class ChatPromptAdapter:
|
||||
) -> str:
|
||||
if provider == "anthropic":
|
||||
prompt = convert_messages_to_prompt_anthropic(messages=messages)
|
||||
if provider == "meta":
|
||||
prompt = convert_messages_to_prompt_llama(messages=messages)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"Provider {provider} model does not support chat."
|
||||
|
29
libs/langchain/langchain/chat_models/meta.py
Normal file
29
libs/langchain/langchain/chat_models/meta.py
Normal file
@ -0,0 +1,29 @@
|
||||
from typing import List
|
||||
|
||||
from langchain.schema.messages import (
|
||||
AIMessage,
|
||||
BaseMessage,
|
||||
ChatMessage,
|
||||
HumanMessage,
|
||||
SystemMessage,
|
||||
)
|
||||
|
||||
|
||||
def _convert_one_message_to_text_llama(message: BaseMessage) -> str:
|
||||
if isinstance(message, ChatMessage):
|
||||
message_text = f"\n\n{message.role.capitalize()}: {message.content}"
|
||||
elif isinstance(message, HumanMessage):
|
||||
message_text = f"[INST] {message.content} [/INST]"
|
||||
elif isinstance(message, AIMessage):
|
||||
message_text = f"{message.content}"
|
||||
elif isinstance(message, SystemMessage):
|
||||
message_text = f"<<SYS>> {message.content} <</SYS>>"
|
||||
else:
|
||||
raise ValueError(f"Got unknown type {message}")
|
||||
return message_text
|
||||
|
||||
|
||||
def convert_messages_to_prompt_llama(messages: List[BaseMessage]) -> str:
|
||||
return "\n".join(
|
||||
[_convert_one_message_to_text_llama(message) for message in messages]
|
||||
)
|
@ -72,6 +72,7 @@ class LLMInputOutputAdapter:
|
||||
"anthropic": "completion",
|
||||
"amazon": "outputText",
|
||||
"cohere": "text",
|
||||
"meta": "generation",
|
||||
}
|
||||
|
||||
@classmethod
|
||||
@ -81,7 +82,7 @@ class LLMInputOutputAdapter:
|
||||
input_body = {**model_kwargs}
|
||||
if provider == "anthropic":
|
||||
input_body["prompt"] = _human_assistant_format(prompt)
|
||||
elif provider == "ai21" or provider == "cohere":
|
||||
elif provider in ("ai21", "cohere", "meta"):
|
||||
input_body["prompt"] = prompt
|
||||
elif provider == "amazon":
|
||||
input_body = dict()
|
||||
@ -107,6 +108,8 @@ class LLMInputOutputAdapter:
|
||||
return response_body.get("completions")[0].get("data").get("text")
|
||||
elif provider == "cohere":
|
||||
return response_body.get("generations")[0].get("text")
|
||||
elif provider == "meta":
|
||||
return response_body.get("generation")
|
||||
else:
|
||||
return response_body.get("results")[0].get("outputText")
|
||||
|
||||
|
30
libs/langchain/tests/unit_tests/chat_models/test_bedrock.py
Normal file
30
libs/langchain/tests/unit_tests/chat_models/test_bedrock.py
Normal file
@ -0,0 +1,30 @@
|
||||
"""Test Anthropic Chat API wrapper."""
|
||||
from typing import List
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain.chat_models.meta import convert_messages_to_prompt_llama
|
||||
from langchain.schema import AIMessage, BaseMessage, HumanMessage, SystemMessage
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("messages", "expected"),
|
||||
[
|
||||
([HumanMessage(content="Hello")], "[INST] Hello [/INST]"),
|
||||
(
|
||||
[HumanMessage(content="Hello"), AIMessage(content="Answer:")],
|
||||
"[INST] Hello [/INST]\nAnswer:",
|
||||
),
|
||||
(
|
||||
[
|
||||
SystemMessage(content="You're an assistant"),
|
||||
HumanMessage(content="Hello"),
|
||||
AIMessage(content="Answer:"),
|
||||
],
|
||||
"<<SYS>> You're an assistant <</SYS>>\n[INST] Hello [/INST]\nAnswer:",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_formatting(messages: List[BaseMessage], expected: str) -> None:
|
||||
result = convert_messages_to_prompt_llama(messages)
|
||||
assert result == expected
|
Loading…
Reference in New Issue
Block a user