This commit is contained in:
Eugene Yurtsev
2023-07-26 14:21:49 -04:00
parent 64c38d0fa1
commit b06c2ea366

View File

@@ -7,13 +7,19 @@ from langchain.prompts import PromptTemplate
from langchain.prompts.chat import (
AIMessagePromptTemplate,
BaseMessagePromptTemplate,
ChatMessage,
ChatMessagePromptTemplate,
ChatPromptTemplate,
ChatPromptValue,
HumanMessagePromptTemplate,
SystemMessagePromptTemplate,
)
from langchain.schema.messages import HumanMessage
from langchain.schema.messages import (
AIMessage,
BaseMessage,
HumanMessage,
SystemMessage,
)
def create_messages() -> List[BaseMessagePromptTemplate]:
@@ -133,7 +139,9 @@ def test_chat_prompt_template_from_messages() -> None:
def test_chat_prompt_template_with_messages() -> None:
messages = create_messages() + [HumanMessage(content="foo")]
messages: List[BaseMessagePromptTemplate] = create_messages() + [
HumanMessage(content="foo")
]
chat_prompt_template = ChatPromptTemplate.from_messages(messages)
assert sorted(chat_prompt_template.input_variables) == sorted(
["context", "foo", "bar"]
@@ -175,7 +183,7 @@ def test_chat_valid_with_partial_variables() -> None:
input_variables=["question", "context"],
partial_variables={"formatins": "some structure"},
)
assert set(prompt.input_variables) == set(["question", "context"])
assert set(prompt.input_variables) == {"question", "context"}
assert prompt.partial_variables == {"formatins": "some structure"}
@@ -188,5 +196,25 @@ def test_chat_valid_infer_variables() -> None:
prompt = ChatPromptTemplate(
messages=messages, partial_variables={"formatins": "some structure"}
)
assert set(prompt.input_variables) == set(["question", "context"])
assert set(prompt.input_variables) == {"question", "context"}
assert prompt.partial_variables == {"formatins": "some structure"}
def test_chat_from_role_strings() -> None:
"""Test instantiation of chat template from role strings."""
template = ChatPromptTemplate.from_role_strings(
[
("system", "You are a bot."),
("ai", "hello!"),
("human", "{question}"),
("other", "{quack}"),
]
)
messages = template.format_messages(question="How are you?", quack="duck")
assert messages == [
SystemMessage(content="You are a bot."),
AIMessage(content="hello!"),
HumanMessage(content="How are you?"),
ChatMessage(content="duck", role="other"),
]