diff --git a/libs/langchain/tests/unit_tests/prompts/test_chat.py b/libs/langchain/tests/unit_tests/prompts/test_chat.py index d355b90dfc4..bf0d0d8f1ba 100644 --- a/libs/langchain/tests/unit_tests/prompts/test_chat.py +++ b/libs/langchain/tests/unit_tests/prompts/test_chat.py @@ -7,13 +7,19 @@ from langchain.prompts import PromptTemplate from langchain.prompts.chat import ( AIMessagePromptTemplate, BaseMessagePromptTemplate, + ChatMessage, ChatMessagePromptTemplate, ChatPromptTemplate, ChatPromptValue, HumanMessagePromptTemplate, SystemMessagePromptTemplate, ) -from langchain.schema.messages import HumanMessage +from langchain.schema.messages import ( + AIMessage, + BaseMessage, + HumanMessage, + SystemMessage, +) def create_messages() -> List[BaseMessagePromptTemplate]: @@ -133,7 +139,9 @@ def test_chat_prompt_template_from_messages() -> None: def test_chat_prompt_template_with_messages() -> None: - messages = create_messages() + [HumanMessage(content="foo")] + messages: List[BaseMessagePromptTemplate] = create_messages() + [ + HumanMessage(content="foo") + ] chat_prompt_template = ChatPromptTemplate.from_messages(messages) assert sorted(chat_prompt_template.input_variables) == sorted( ["context", "foo", "bar"] @@ -175,7 +183,7 @@ def test_chat_valid_with_partial_variables() -> None: input_variables=["question", "context"], partial_variables={"formatins": "some structure"}, ) - assert set(prompt.input_variables) == set(["question", "context"]) + assert set(prompt.input_variables) == {"question", "context"} assert prompt.partial_variables == {"formatins": "some structure"} @@ -188,5 +196,25 @@ def test_chat_valid_infer_variables() -> None: prompt = ChatPromptTemplate( messages=messages, partial_variables={"formatins": "some structure"} ) - assert set(prompt.input_variables) == set(["question", "context"]) + assert set(prompt.input_variables) == {"question", "context"} assert prompt.partial_variables == {"formatins": "some structure"} + + +def test_chat_from_role_strings() -> None: + """Test instantiation of chat template from role strings.""" + template = ChatPromptTemplate.from_role_strings( + [ + ("system", "You are a bot."), + ("ai", "hello!"), + ("human", "{question}"), + ("other", "{quack}"), + ] + ) + + messages = template.format_messages(question="How are you?", quack="duck") + assert messages == [ + SystemMessage(content="You are a bot."), + AIMessage(content="hello!"), + HumanMessage(content="How are you?"), + ChatMessage(content="duck", role="other"), + ]