mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-01 12:38:45 +00:00
Added prompt wrapping for Claude with Bedrock (#11090)
- **Description:** Prompt wrapping requirements have been implemented on the service side of AWS Bedrock for the Anthropic Claude models to provide parity between Anthropic's offering and Bedrock's offering. This overnight change broke most existing implementations of Claude, Bedrock and Langchain. This PR just steals the the Anthropic LLM implementation to enforce alias/role wrapping and implements it in the existing mechanism for building the request body. This has also been tested to fix the chat_model implementation as well. Happy to answer any further questions or make changes where necessary to get things patched and up to PyPi ASAP, TY. - **Issue:** No issue opened at the moment, though will update when these roll in. - **Dependencies:** None --------- Co-authored-by: Harrison Chase <hw.chase.17@gmail.com> Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
parent
b87cc8b31e
commit
23065f54c0
@ -8,6 +8,52 @@ from langchain.llms.utils import enforce_stop_tokens
|
||||
from langchain.pydantic_v1 import BaseModel, Extra, root_validator
|
||||
from langchain.schema.output import GenerationChunk
|
||||
|
||||
HUMAN_PROMPT = "\n\nHuman:"
|
||||
ASSISTANT_PROMPT = "\n\nAssistant:"
|
||||
ALTERNATION_ERROR = (
|
||||
"Error: Prompt must alternate between '\n\nHuman:' and '\n\nAssistant:'."
|
||||
)
|
||||
|
||||
|
||||
def _add_newlines_before_ha(input_text: str) -> str:
|
||||
new_text = input_text
|
||||
for word in ["Human:", "Assistant:"]:
|
||||
new_text = new_text.replace(word, "\n\n" + word)
|
||||
for i in range(2):
|
||||
new_text = new_text.replace("\n\n\n" + word, "\n\n" + word)
|
||||
return new_text
|
||||
|
||||
|
||||
def _human_assistant_format(input_text: str) -> str:
|
||||
if input_text.count("Human:") == 0 or (
|
||||
input_text.find("Human:") > input_text.find("Assistant:")
|
||||
and "Assistant:" in input_text
|
||||
):
|
||||
input_text = HUMAN_PROMPT + " " + input_text # SILENT CORRECTION
|
||||
if input_text.count("Assistant:") == 0:
|
||||
input_text = input_text + ASSISTANT_PROMPT # SILENT CORRECTION
|
||||
if input_text[: len("Human:")] == "Human:":
|
||||
input_text = "\n\n" + input_text
|
||||
input_text = _add_newlines_before_ha(input_text)
|
||||
count = 0
|
||||
# track alternation
|
||||
for i in range(len(input_text)):
|
||||
if input_text[i : i + len(HUMAN_PROMPT)] == HUMAN_PROMPT:
|
||||
if count % 2 == 0:
|
||||
count += 1
|
||||
else:
|
||||
raise ValueError(ALTERNATION_ERROR)
|
||||
if input_text[i : i + len(ASSISTANT_PROMPT)] == ASSISTANT_PROMPT:
|
||||
if count % 2 == 1:
|
||||
count += 1
|
||||
else:
|
||||
raise ValueError(ALTERNATION_ERROR)
|
||||
|
||||
if count % 2 == 1: # Only saw Human, no Assistant
|
||||
input_text = input_text + ASSISTANT_PROMPT # SILENT CORRECTION
|
||||
|
||||
return input_text
|
||||
|
||||
|
||||
class LLMInputOutputAdapter:
|
||||
"""Adapter class to prepare the inputs from Langchain to a format
|
||||
@ -26,7 +72,9 @@ class LLMInputOutputAdapter:
|
||||
cls, provider: str, prompt: str, model_kwargs: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
input_body = {**model_kwargs}
|
||||
if provider == "anthropic" or provider == "ai21":
|
||||
if provider == "anthropic":
|
||||
input_body["prompt"] = _human_assistant_format(prompt)
|
||||
elif provider == "ai21":
|
||||
input_body["prompt"] = prompt
|
||||
elif provider == "amazon":
|
||||
input_body = dict()
|
||||
|
252
libs/langchain/tests/unit_tests/llms/test_bedrock.py
Normal file
252
libs/langchain/tests/unit_tests/llms/test_bedrock.py
Normal file
@ -0,0 +1,252 @@
|
||||
import pytest
|
||||
|
||||
from langchain.llms.bedrock import ALTERNATION_ERROR, _human_assistant_format
|
||||
|
||||
TEST_CASES = {
|
||||
"""Hey""": """
|
||||
|
||||
Human: Hey
|
||||
|
||||
Assistant:""",
|
||||
"""
|
||||
|
||||
Human: Hello
|
||||
|
||||
Assistant:""": """
|
||||
|
||||
Human: Hello
|
||||
|
||||
Assistant:""",
|
||||
"""Human: Hello
|
||||
|
||||
Assistant:""": """
|
||||
|
||||
Human: Hello
|
||||
|
||||
Assistant:""",
|
||||
"""
|
||||
Human: Hello
|
||||
|
||||
Assistant:""": """
|
||||
|
||||
Human: Hello
|
||||
|
||||
Assistant:""",
|
||||
"""
|
||||
|
||||
Human: Human: Hello
|
||||
|
||||
Assistant:""": (
|
||||
"Error: Prompt must alternate between '\n\nHuman:' and '\n\nAssistant:'."
|
||||
),
|
||||
"""Human: Hello
|
||||
|
||||
Assistant: Hello
|
||||
|
||||
Human: Hello
|
||||
|
||||
Assistant:""": """
|
||||
|
||||
Human: Hello
|
||||
|
||||
Assistant: Hello
|
||||
|
||||
Human: Hello
|
||||
|
||||
Assistant:""",
|
||||
"""
|
||||
|
||||
Human: Hello
|
||||
|
||||
Assistant: Hello
|
||||
|
||||
Human: Hello
|
||||
|
||||
Assistant:""": """
|
||||
|
||||
Human: Hello
|
||||
|
||||
Assistant: Hello
|
||||
|
||||
Human: Hello
|
||||
|
||||
Assistant:""",
|
||||
"""
|
||||
|
||||
Human: Hello
|
||||
|
||||
Assistant: Hello
|
||||
|
||||
Human: Hello
|
||||
|
||||
Assistant: Hello
|
||||
|
||||
Assistant: Hello""": ALTERNATION_ERROR,
|
||||
"""
|
||||
|
||||
Human: Hi.
|
||||
|
||||
Assistant: Hi.
|
||||
|
||||
Human: Hi.
|
||||
|
||||
Human: Hi.
|
||||
|
||||
Assistant:""": ALTERNATION_ERROR,
|
||||
"""
|
||||
Human: Hello""": """
|
||||
|
||||
Human: Hello
|
||||
|
||||
Assistant:""",
|
||||
"""
|
||||
|
||||
Human: Hello
|
||||
Hello
|
||||
|
||||
Assistant""": """
|
||||
|
||||
Human: Hello
|
||||
Hello
|
||||
|
||||
Assistant
|
||||
|
||||
Assistant:""",
|
||||
"""Hello
|
||||
|
||||
Assistant:""": """
|
||||
|
||||
Human: Hello
|
||||
|
||||
Assistant:""",
|
||||
"""Hello
|
||||
|
||||
Human: Hello
|
||||
|
||||
""": """Hello
|
||||
|
||||
Human: Hello
|
||||
|
||||
|
||||
|
||||
Assistant:""",
|
||||
"""
|
||||
|
||||
Human: Assistant: Hello""": """
|
||||
|
||||
Human:
|
||||
|
||||
Assistant: Hello""",
|
||||
"""
|
||||
|
||||
Human: Human
|
||||
|
||||
Assistant: Assistant
|
||||
|
||||
Human: Assistant
|
||||
|
||||
Assistant: Human""": """
|
||||
|
||||
Human: Human
|
||||
|
||||
Assistant: Assistant
|
||||
|
||||
Human: Assistant
|
||||
|
||||
Assistant: Human""",
|
||||
"""
|
||||
Assistant: Hello there, your name is:
|
||||
|
||||
Human.
|
||||
|
||||
Human: Hello there, your name is:
|
||||
|
||||
Assistant.""": """
|
||||
|
||||
Human:
|
||||
|
||||
Assistant: Hello there, your name is:
|
||||
|
||||
Human.
|
||||
|
||||
Human: Hello there, your name is:
|
||||
|
||||
Assistant.
|
||||
|
||||
Assistant:""",
|
||||
"""
|
||||
|
||||
Human: Human: Hi
|
||||
|
||||
Assistant: Hi""": ALTERNATION_ERROR,
|
||||
"""Human: Hi
|
||||
|
||||
Human: Hi""": ALTERNATION_ERROR,
|
||||
"""
|
||||
|
||||
Assistant: Hi
|
||||
|
||||
Human: Hi""": """
|
||||
|
||||
Human:
|
||||
|
||||
Assistant: Hi
|
||||
|
||||
Human: Hi
|
||||
|
||||
Assistant:""",
|
||||
"""
|
||||
|
||||
Human: Hi
|
||||
|
||||
Assistant: Yo
|
||||
|
||||
Human: Hey
|
||||
|
||||
Assistant: Sup
|
||||
|
||||
Human: Hi
|
||||
|
||||
Assistant: Hi
|
||||
Human: Hi
|
||||
Assistant:""": """
|
||||
|
||||
Human: Hi
|
||||
|
||||
Assistant: Yo
|
||||
|
||||
Human: Hey
|
||||
|
||||
Assistant: Sup
|
||||
|
||||
Human: Hi
|
||||
|
||||
Assistant: Hi
|
||||
|
||||
Human: Hi
|
||||
|
||||
Assistant:""",
|
||||
"""
|
||||
|
||||
Hello.
|
||||
|
||||
Human: Hello.
|
||||
|
||||
Assistant:""": """
|
||||
|
||||
Hello.
|
||||
|
||||
Human: Hello.
|
||||
|
||||
Assistant:""",
|
||||
}
|
||||
|
||||
|
||||
def test__human_assistant_format() -> None:
|
||||
for input_text, expected_output in TEST_CASES.items():
|
||||
if expected_output == ALTERNATION_ERROR:
|
||||
with pytest.raises(ValueError):
|
||||
_human_assistant_format(input_text)
|
||||
else:
|
||||
output = _human_assistant_format(input_text)
|
||||
assert output == expected_output
|
Loading…
Reference in New Issue
Block a user