Mirror of https://github.com/hwchase17/langchain.git (synced 2025-08-13 14:50:00 +00:00)
community[patch]: BedrockChat -> Support Titan express as chat model (#15408)
The Titan Express model was not supported as a chat model because LangChain messages were not translated into a single text prompt before being sent to Bedrock. This patch adds an "amazon" branch to ChatPromptAdapter that reuses the Anthropic message converter with Titan-style "User:"/"Bot:" role markers.

Co-authored-by: Guillem Orellana Trullols <guillem.orellana_trullols@siemens.com>
commit aad2aa7188 (parent 1b9001db47)
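With this change, Titan Express works through BedrockChat like the other supported providers. A minimal sketch of the call path the commit enables, assuming a configured boto3 bedrock-runtime client (the client setup below is illustrative, not part of this commit):

```python
import boto3
from langchain_community.chat_models import BedrockChat
from langchain_core.messages import HumanMessage

# Illustrative setup; any configured bedrock-runtime client works here.
client = boto3.client("bedrock-runtime", region_name="us-east-1")

# Before this commit, provider "amazon" fell through to
# NotImplementedError in ChatPromptAdapter; now the chat messages are
# flattened into a text prompt that Titan Express accepts.
chat = BedrockChat(model_id="amazon.titan-text-express-v1", client=client)
print(chat.invoke([HumanMessage(content="hello there")]).content)
```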
@@ -32,6 +32,12 @@ class ChatPromptAdapter:
             prompt = convert_messages_to_prompt_anthropic(messages=messages)
         elif provider == "meta":
             prompt = convert_messages_to_prompt_llama(messages=messages)
+        elif provider == "amazon":
+            prompt = convert_messages_to_prompt_anthropic(
+                messages=messages,
+                human_prompt="\n\nUser:",
+                ai_prompt="\n\nBot:",
+            )
         else:
             raise NotImplementedError(
                 f"Provider {provider} model does not support chat."
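The new branch reuses the Anthropic converter and only swaps the role markers. A sketch of what the conversion is expected to produce, assuming the converter prefixes each message with its role marker and appends the AI marker to cue the model to respond (the exact output shape is an assumption based on how that converter joins messages):

```python
from langchain_community.chat_models.anthropic import (
    convert_messages_to_prompt_anthropic,
)
from langchain_core.messages import AIMessage, HumanMessage

messages = [
    HumanMessage(content="hello there"),
    AIMessage(content="Hi back"),
    HumanMessage(content="how are you?"),
]

prompt = convert_messages_to_prompt_anthropic(
    messages=messages,
    human_prompt="\n\nUser:",
    ai_prompt="\n\nBot:",
)
# Expected shape (assumption):
#   "\n\nUser: hello there\n\nBot: Hi back\n\nUser: how are you?\n\nBot:"
# Each message gets its role marker, and the trailing "\n\nBot:" cues
# Titan Express to generate the next turn.
print(repr(prompt))
```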
@@ -272,10 +272,12 @@ class BedrockBase(BaseModel, ABC):

         try:
             response = self.client.invoke_model(
-                body=body, modelId=self.model_id, accept=accept, contentType=contentType
+                body=body,
+                modelId=self.model_id,
+                accept=accept,
+                contentType=contentType,
             )
             text = LLMInputOutputAdapter.prepare_output(provider, response)

         except Exception as e:
             raise ValueError(f"Error raised by bedrock service: {e}").with_traceback(
                 e.__traceback__
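This hunk only reflows the invoke_model arguments onto separate lines; behavior is unchanged. For context, prepare_output is where the provider-specific response shapes matter: Anthropic returns a completion field while Titan returns a results list, which is why the test below mocks two different JSON bodies. A rough sketch of that dispatch, with names simplified from the real LLMInputOutputAdapter (an assumption, not the library's exact code):

```python
import json
from typing import Any


def prepare_output_sketch(provider: str, response: Any) -> str:
    """Simplified stand-in for LLMInputOutputAdapter.prepare_output."""
    if provider == "anthropic":
        # Anthropic bodies arrive as bytes and carry a "completion" field.
        body = json.loads(response["body"].read().decode())
        return body["completion"]
    # Titan bodies carry {"results": [{"outputText": ...}]}.
    body = json.loads(response["body"].read())
    return body["results"][0]["outputText"]
```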
@@ -37,17 +37,24 @@ def test_formatting(messages: List[BaseMessage], expected: str) -> None:
     assert result == expected


-def test_anthropic_bedrock() -> None:
+@pytest.mark.parametrize(
+    "model_id",
+    ["anthropic.claude-v2", "amazon.titan-text-express-v1"],
+)
+def test_different_models_bedrock(model_id: str) -> None:
+    provider = model_id.split(".")[0]
     client = MagicMock()
-    respbody = MagicMock(
-        read=MagicMock(
-            return_value=MagicMock(
-                decode=MagicMock(return_value=b'{"completion":"Hi back"}')
-            )
-        )
-    )
-    client.invoke_model.return_value = {"body": respbody}
-    model = BedrockChat(model_id="anthropic.claude-v2", client=client)
+    respbody = MagicMock()
+    if provider == "anthropic":
+        respbody.read.return_value = MagicMock(
+            decode=MagicMock(return_value=b'{"completion":"Hi back"}'),
+        )
+        client.invoke_model.return_value = {"body": respbody}
+    elif provider == "amazon":
+        respbody.read.return_value = '{"results": [{"outputText": "Hi back"}]}'
+        client.invoke_model.return_value = {"body": respbody}
+
+    model = BedrockChat(model_id=model_id, client=client)

     # should not throw an error
     model.invoke("hello there")
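A note on the two mock shapes: the anthropic case stubs read() to return an object with a decode() method because the adapter decodes that provider's body, while the amazon case returns a plain JSON string that json.loads accepts directly. A tiny self-contained illustration of the amazon-side wiring (independent of LangChain):

```python
import json
from unittest.mock import MagicMock

respbody = MagicMock()
respbody.read.return_value = '{"results": [{"outputText": "Hi back"}]}'

client = MagicMock()
client.invoke_model.return_value = {"body": respbody}

# The code under test does roughly: json.loads(response["body"].read())
response = client.invoke_model(body="{}", modelId="amazon.titan-text-express-v1")
parsed = json.loads(response["body"].read())
assert parsed["results"][0]["outputText"] == "Hi back"
```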