core[patch]: ChatPromptTemplate.init same as ChatPromptTemplate.from_… (#24486)

This commit is contained in:
Bagatur 2024-07-26 10:48:39 -07:00 committed by GitHub
parent cc451effd1
commit ad7581751f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 101 additions and 39 deletions

View File

@ -820,11 +820,17 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
Examples:
.. versionchanged:: 0.2.24
You can pass any Message-like formats supported by
``ChatPromptTemplate.from_messages()`` directly to ``ChatPromptTemplate()``
init.
.. code-block:: python .. code-block:: python
from langchain_core.prompts import ChatPromptTemplate from langchain_core.prompts import ChatPromptTemplate
template = ChatPromptTemplate.from_messages([ template = ChatPromptTemplate([
("system", "You are a helpful AI bot. Your name is {name}."), ("system", "You are a helpful AI bot. Your name is {name}."),
("human", "Hello, how are you doing?"), ("human", "Hello, how are you doing?"),
("ai", "I'm doing well, thanks!"), ("ai", "I'm doing well, thanks!"),
@ -855,7 +861,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
# you can initialize the template with a MessagesPlaceholder # you can initialize the template with a MessagesPlaceholder
# either using the class directly or with the shorthand tuple syntax: # either using the class directly or with the shorthand tuple syntax:
template = ChatPromptTemplate.from_messages([ template = ChatPromptTemplate([
("system", "You are a helpful AI bot."), ("system", "You are a helpful AI bot."),
# Means the template will receive an optional list of messages under # Means the template will receive an optional list of messages under
# the "conversation" key # the "conversation" key
@ -897,7 +903,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
from langchain_core.prompts import ChatPromptTemplate from langchain_core.prompts import ChatPromptTemplate
template = ChatPromptTemplate.from_messages([ template = ChatPromptTemplate([
("system", "You are a helpful AI bot. Your name is Carl."), ("system", "You are a helpful AI bot. Your name is Carl."),
("human", "{user_input}"), ("human", "{user_input}"),
]) ])
@ -921,6 +927,86 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
validate_template: bool = False validate_template: bool = False
"""Whether or not to try validating the template.""" """Whether or not to try validating the template."""
def __init__(
    self,
    messages: Sequence[MessageLikeRepresentation],
    *,
    template_format: Literal["f-string", "mustache", "jinja2"] = "f-string",
    **kwargs: Any,
) -> None:
    """Create a chat prompt template from a variety of message formats.

    Args:
        messages: sequence of message representations. Each element may be
            a ``BaseMessagePromptTemplate``, a ``BaseMessage``, a 2-tuple of
            (message type, template) such as ``("human", "{user_input}")``,
            a 2-tuple of (message class, template), or a bare string, which
            is shorthand for ``("human", template)``.
        template_format: format of the template. Defaults to "f-string".
        **kwargs: additional fields forwarded to the parent initializer
            (e.g. ``input_variables``, ``optional_variables``,
            ``partial_variables``, ``validate_template``, ``input_types``).
            Explicitly passed values override the inferred ones.

    Examples:
        Instantiation from a list of message templates:

        .. code-block:: python

            template = ChatPromptTemplate([
                ("human", "Hello, how are you?"),
                ("ai", "I'm doing well, thanks!"),
                ("human", "That's good to hear."),
            ])

        Instantiation from mixed message formats:

        .. code-block:: python

            template = ChatPromptTemplate([
                SystemMessage(content="hello"),
                ("human", "Hello, how are you?"),
            ])
    """
    converted = [_convert_to_message(m, template_format) for m in messages]

    # Infer the template's variables by inspecting the converted messages.
    required_vars: Set[str] = set()
    optional_vars: Set[str] = set()
    placeholder_defaults: Dict[str, Any] = {}
    for msg in converted:
        if isinstance(msg, MessagesPlaceholder) and msg.optional:
            # Optional placeholders default to an empty message list and are
            # recorded as optional rather than required inputs.
            placeholder_defaults[msg.variable_name] = []
            optional_vars.add(msg.variable_name)
        elif isinstance(msg, (BaseChatPromptTemplate, BaseMessagePromptTemplate)):
            required_vars.update(msg.input_variables)

    merged: Dict[str, Any] = dict(
        input_variables=sorted(required_vars),
        optional_variables=sorted(optional_vars),
        partial_variables=placeholder_defaults,
    )
    # Caller-supplied kwargs win over the inferred defaults.
    merged.update(kwargs)
    cast(Type[ChatPromptTemplate], super()).__init__(messages=converted, **merged)
@classmethod @classmethod
def get_lc_namespace(cls) -> List[str]: def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object.""" """Get the namespace of the langchain object."""
@ -1097,29 +1183,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
Returns:
    a chat prompt template.
"""
_messages = [ return cls(messages, template_format=template_format)
_convert_to_message(message, template_format) for message in messages
]
# Automatically infer input variables from messages
input_vars: Set[str] = set()
optional_variables: Set[str] = set()
partial_vars: Dict[str, Any] = {}
for _message in _messages:
if isinstance(_message, MessagesPlaceholder) and _message.optional:
partial_vars[_message.variable_name] = []
optional_variables.add(_message.variable_name)
elif isinstance(
_message, (BaseChatPromptTemplate, BaseMessagePromptTemplate)
):
input_vars.update(_message.input_variables)
return cls(
input_variables=sorted(input_vars),
optional_variables=sorted(optional_variables),
messages=_messages,
partial_variables=partial_vars,
)
def format_messages(self, **kwargs: Any) -> List[BaseMessage]: def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
"""Format the chat template into a list of finalized messages. """Format the chat template into a list of finalized messages.

View File

@ -438,7 +438,7 @@ def test_chat_prompt_template_indexing() -> None:
message1 = SystemMessage(content="foo") message1 = SystemMessage(content="foo")
message2 = HumanMessage(content="bar") message2 = HumanMessage(content="bar")
message3 = HumanMessage(content="baz") message3 = HumanMessage(content="baz")
template = ChatPromptTemplate.from_messages([message1, message2, message3]) template = ChatPromptTemplate([message1, message2, message3])
assert template[0] == message1 assert template[0] == message1
assert template[1] == message2 assert template[1] == message2
@ -453,7 +453,7 @@ def test_chat_prompt_template_append_and_extend() -> None:
message1 = SystemMessage(content="foo") message1 = SystemMessage(content="foo")
message2 = HumanMessage(content="bar") message2 = HumanMessage(content="bar")
message3 = HumanMessage(content="baz") message3 = HumanMessage(content="baz")
template = ChatPromptTemplate.from_messages([message1]) template = ChatPromptTemplate([message1])
template.append(message2) template.append(message2)
template.append(message3) template.append(message3)
assert len(template) == 3 assert len(template) == 3
@ -480,7 +480,7 @@ def test_convert_to_message_is_strict() -> None:
def test_chat_message_partial() -> None: def test_chat_message_partial() -> None:
template = ChatPromptTemplate.from_messages( template = ChatPromptTemplate(
[ [
("system", "You are an AI assistant named {name}."), ("system", "You are an AI assistant named {name}."),
("human", "Hi I'm {user}"), ("human", "Hi I'm {user}"),
@ -734,14 +734,14 @@ def test_messages_placeholder_with_max() -> None:
def test_chat_prompt_message_placeholder_partial() -> None: def test_chat_prompt_message_placeholder_partial() -> None:
prompt = ChatPromptTemplate.from_messages([MessagesPlaceholder("history")]) prompt = ChatPromptTemplate([MessagesPlaceholder("history")])
prompt = prompt.partial(history=[("system", "foo")]) prompt = prompt.partial(history=[("system", "foo")])
assert prompt.format_messages() == [SystemMessage(content="foo")] assert prompt.format_messages() == [SystemMessage(content="foo")]
assert prompt.format_messages(history=[("system", "bar")]) == [ assert prompt.format_messages(history=[("system", "bar")]) == [
SystemMessage(content="bar") SystemMessage(content="bar")
] ]
prompt = ChatPromptTemplate.from_messages( prompt = ChatPromptTemplate(
[ [
MessagesPlaceholder("history", optional=True), MessagesPlaceholder("history", optional=True),
] ]
@ -752,7 +752,7 @@ def test_chat_prompt_message_placeholder_partial() -> None:
def test_chat_prompt_message_placeholder_tuple() -> None: def test_chat_prompt_message_placeholder_tuple() -> None:
prompt = ChatPromptTemplate.from_messages([("placeholder", "{convo}")]) prompt = ChatPromptTemplate([("placeholder", "{convo}")])
assert prompt.format_messages(convo=[("user", "foo")]) == [ assert prompt.format_messages(convo=[("user", "foo")]) == [
HumanMessage(content="foo") HumanMessage(content="foo")
] ]
@ -760,9 +760,7 @@ def test_chat_prompt_message_placeholder_tuple() -> None:
assert prompt.format_messages() == [] assert prompt.format_messages() == []
# Is optional = True # Is optional = True
optional_prompt = ChatPromptTemplate.from_messages( optional_prompt = ChatPromptTemplate([("placeholder", ["{convo}", False])])
[("placeholder", ["{convo}", False])]
)
assert optional_prompt.format_messages(convo=[("user", "foo")]) == [ assert optional_prompt.format_messages(convo=[("user", "foo")]) == [
HumanMessage(content="foo") HumanMessage(content="foo")
] ]
@ -771,7 +769,7 @@ def test_chat_prompt_message_placeholder_tuple() -> None:
async def test_messages_prompt_accepts_list() -> None: async def test_messages_prompt_accepts_list() -> None:
prompt = ChatPromptTemplate.from_messages([MessagesPlaceholder("history")]) prompt = ChatPromptTemplate([MessagesPlaceholder("history")])
value = prompt.invoke([("user", "Hi there")]) # type: ignore value = prompt.invoke([("user", "Hi there")]) # type: ignore
assert value.to_messages() == [HumanMessage(content="Hi there")] assert value.to_messages() == [HumanMessage(content="Hi there")]
@ -779,7 +777,7 @@ async def test_messages_prompt_accepts_list() -> None:
assert value.to_messages() == [HumanMessage(content="Hi there")] assert value.to_messages() == [HumanMessage(content="Hi there")]
# Assert still raises a nice error # Assert still raises a nice error
prompt = ChatPromptTemplate.from_messages( prompt = ChatPromptTemplate(
[("system", "You are a {foo}"), MessagesPlaceholder("history")] [("system", "You are a {foo}"), MessagesPlaceholder("history")]
) )
with pytest.raises(TypeError): with pytest.raises(TypeError):
@ -790,7 +788,7 @@ async def test_messages_prompt_accepts_list() -> None:
def test_chat_input_schema(snapshot: SnapshotAssertion) -> None: def test_chat_input_schema(snapshot: SnapshotAssertion) -> None:
prompt_all_required = ChatPromptTemplate.from_messages( prompt_all_required = ChatPromptTemplate(
messages=[MessagesPlaceholder("history", optional=False), ("user", "${input}")] messages=[MessagesPlaceholder("history", optional=False), ("user", "${input}")]
) )
assert set(prompt_all_required.input_variables) == {"input", "history"} assert set(prompt_all_required.input_variables) == {"input", "history"}
@ -798,7 +796,7 @@ def test_chat_input_schema(snapshot: SnapshotAssertion) -> None:
with pytest.raises(ValidationError): with pytest.raises(ValidationError):
prompt_all_required.input_schema(input="") prompt_all_required.input_schema(input="")
assert prompt_all_required.input_schema.schema() == snapshot(name="required") assert prompt_all_required.input_schema.schema() == snapshot(name="required")
prompt_optional = ChatPromptTemplate.from_messages( prompt_optional = ChatPromptTemplate(
messages=[MessagesPlaceholder("history", optional=True), ("user", "${input}")] messages=[MessagesPlaceholder("history", optional=True), ("user", "${input}")]
) )
# input variables only lists required variables # input variables only lists required variables