mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-17 18:23:59 +00:00
ChatPromptTemplate: Update doc-strings, update from_role_strings behavior (#8308)
* Update doc-strings in ChatPromptTemplate * Update from_role_strings classmethod to use well known roles
This commit is contained in:
parent
2c2fd9ff13
commit
862e9aed66
@ -220,7 +220,10 @@ class SystemMessagePromptTemplate(BaseStringMessagePromptTemplate):
|
|||||||
|
|
||||||
|
|
||||||
class ChatPromptValue(PromptValue):
|
class ChatPromptValue(PromptValue):
|
||||||
"""Chat prompt value."""
|
"""Chat prompt value.
|
||||||
|
|
||||||
|
A type of a prompt value that is built from messages.
|
||||||
|
"""
|
||||||
|
|
||||||
messages: List[BaseMessage]
|
messages: List[BaseMessage]
|
||||||
"""List of messages."""
|
"""List of messages."""
|
||||||
@ -258,12 +261,65 @@ class BaseChatPromptTemplate(BasePromptTemplate, ABC):
|
|||||||
|
|
||||||
|
|
||||||
class ChatPromptTemplate(BaseChatPromptTemplate, ABC):
|
class ChatPromptTemplate(BaseChatPromptTemplate, ABC):
|
||||||
"""Chat prompt template. This is a prompt that is sent to the user."""
|
"""Use to create flexible templated prompts for chat models.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
|
||||||
|
Instantiation from role strings:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
from langchain.prompts import ChatPromptTemplate
|
||||||
|
|
||||||
|
prompt_template = ChatPromptTemplate.from_role_strings(
|
||||||
|
[
|
||||||
|
('system', "You are a helpful bot. Your name is {bot_name}."),
|
||||||
|
('human', "{user_input}")
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
prompt_template.format_messages(
|
||||||
|
bot_name="bobby",
|
||||||
|
user_input="Hello! What is your name?"
|
||||||
|
)
|
||||||
|
|
||||||
|
Instantiation from messages:
|
||||||
|
|
||||||
|
This is useful if it's important to distinguish between messages that
|
||||||
|
are templates and messages that are already formatted.
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
from langchain.prompts import (
|
||||||
|
ChatPromptTemplate,
|
||||||
|
HumanMessagePromptTemplate,
|
||||||
|
SystemMessagePromptTemplate,
|
||||||
|
)
|
||||||
|
|
||||||
|
from langchain.schema import AIMessage
|
||||||
|
|
||||||
|
prompt_template = ChatPromptTemplate.from_messages(
|
||||||
|
[
|
||||||
|
SystemMessagePromptTemplate.from_template(
|
||||||
|
"You are a helpful bot. Your name is {bot_name}."
|
||||||
|
),
|
||||||
|
AIMessage(content="Hello!"), # Already formatted message
|
||||||
|
HumanMessagePromptTemplate.from_template(
|
||||||
|
"{user_input}"
|
||||||
|
),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
prompt_template.format_messages(
|
||||||
|
bot_name="bobby",
|
||||||
|
user_input="Hello! What is your name?"
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
|
||||||
input_variables: List[str]
|
input_variables: List[str]
|
||||||
"""List of input variables."""
|
"""List of input variables."""
|
||||||
messages: List[Union[BaseMessagePromptTemplate, BaseMessage]]
|
messages: List[Union[BaseMessagePromptTemplate, BaseMessage]]
|
||||||
"""List of messages."""
|
"""List of messages consisting of either message prompt templates or messages."""
|
||||||
|
|
||||||
def __add__(self, other: Any) -> ChatPromptTemplate:
|
def __add__(self, other: Any) -> ChatPromptTemplate:
|
||||||
# Allow for easy combining
|
# Allow for easy combining
|
||||||
@ -279,9 +335,10 @@ class ChatPromptTemplate(BaseChatPromptTemplate, ABC):
|
|||||||
|
|
||||||
@root_validator(pre=True)
|
@root_validator(pre=True)
|
||||||
def validate_input_variables(cls, values: dict) -> dict:
|
def validate_input_variables(cls, values: dict) -> dict:
|
||||||
"""
|
"""Validate input variables.
|
||||||
Validate input variables. If input_variables is not set, it will be set to
|
|
||||||
the union of all input variables in the messages.
|
If input_variables is not set, it will be set to the union of
|
||||||
|
all input variables in the messages.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
values: values to validate.
|
values: values to validate.
|
||||||
@ -309,10 +366,13 @@ class ChatPromptTemplate(BaseChatPromptTemplate, ABC):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_template(cls, template: str, **kwargs: Any) -> ChatPromptTemplate:
|
def from_template(cls, template: str, **kwargs: Any) -> ChatPromptTemplate:
|
||||||
"""Create a class from a template.
|
"""Create a chat prompt template from a template string.
|
||||||
|
|
||||||
|
Creates a chat template consisting of a single message assumed to be from
|
||||||
|
the human.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
template: template string.
|
template: template string
|
||||||
**kwargs: keyword arguments to pass to the constructor.
|
**kwargs: keyword arguments to pass to the constructor.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@ -328,31 +388,41 @@ class ChatPromptTemplate(BaseChatPromptTemplate, ABC):
|
|||||||
) -> ChatPromptTemplate:
|
) -> ChatPromptTemplate:
|
||||||
"""Create a class from a list of (role, template) tuples.
|
"""Create a class from a list of (role, template) tuples.
|
||||||
|
|
||||||
|
The roles "human", "ai", and "system" are special and will be converted
|
||||||
|
to the appropriate message class. All other roles will be converted to a
|
||||||
|
generic ChatMessagePromptTemplate.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
string_messages: list of (role, template) tuples.
|
string_messages: list of (role, template) tuples.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A new instance of this class.
|
a chat prompt template
|
||||||
"""
|
"""
|
||||||
messages = [
|
messages: List[BaseMessagePromptTemplate] = []
|
||||||
ChatMessagePromptTemplate(
|
message: BaseMessagePromptTemplate
|
||||||
prompt=PromptTemplate.from_template(template), role=role
|
for role, template in string_messages:
|
||||||
)
|
if role == "human":
|
||||||
for role, template in string_messages
|
message = HumanMessagePromptTemplate.from_template(template)
|
||||||
]
|
elif role == "ai":
|
||||||
|
message = AIMessagePromptTemplate.from_template(template)
|
||||||
|
elif role == "system":
|
||||||
|
message = SystemMessagePromptTemplate.from_template(template)
|
||||||
|
else:
|
||||||
|
message = ChatMessagePromptTemplate.from_template(template, role=role)
|
||||||
|
messages.append(message)
|
||||||
return cls.from_messages(messages)
|
return cls.from_messages(messages)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_strings(
|
def from_strings(
|
||||||
cls, string_messages: List[Tuple[Type[BaseMessagePromptTemplate], str]]
|
cls, string_messages: List[Tuple[Type[BaseMessagePromptTemplate], str]]
|
||||||
) -> ChatPromptTemplate:
|
) -> ChatPromptTemplate:
|
||||||
"""Create a class from a list of (role, template) tuples.
|
"""Create a class from a list of (role class, template) tuples.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
string_messages: list of (role, template) tuples.
|
string_messages: list of (role class, template) tuples.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A new instance of this class.
|
a chat prompt template
|
||||||
"""
|
"""
|
||||||
messages = [
|
messages = [
|
||||||
role(prompt=PromptTemplate.from_template(template))
|
role(prompt=PromptTemplate.from_template(template))
|
||||||
@ -364,14 +434,13 @@ class ChatPromptTemplate(BaseChatPromptTemplate, ABC):
|
|||||||
def from_messages(
|
def from_messages(
|
||||||
cls, messages: Sequence[Union[BaseMessagePromptTemplate, BaseMessage]]
|
cls, messages: Sequence[Union[BaseMessagePromptTemplate, BaseMessage]]
|
||||||
) -> ChatPromptTemplate:
|
) -> ChatPromptTemplate:
|
||||||
"""
|
"""Create a chat template from a sequence of messages.
|
||||||
Create a class from a list of messages.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
messages: list of messages.
|
messages: sequence of templated or regular messages
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A new instance of this class.
|
a chat prompt template
|
||||||
"""
|
"""
|
||||||
input_vars = set()
|
input_vars = set()
|
||||||
for message in messages:
|
for message in messages:
|
||||||
@ -380,17 +449,26 @@ class ChatPromptTemplate(BaseChatPromptTemplate, ABC):
|
|||||||
return cls(input_variables=list(input_vars), messages=messages)
|
return cls(input_variables=list(input_vars), messages=messages)
|
||||||
|
|
||||||
def format(self, **kwargs: Any) -> str:
|
def format(self, **kwargs: Any) -> str:
|
||||||
|
"""Format the chat template into a string.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
**kwargs: keyword arguments to use for filling in template variables
|
||||||
|
in all the template messages in this chat template.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
formatted string
|
||||||
|
"""
|
||||||
return self.format_prompt(**kwargs).to_string()
|
return self.format_prompt(**kwargs).to_string()
|
||||||
|
|
||||||
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
|
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
|
||||||
"""
|
"""Format the chat template into a list of finalized messages.
|
||||||
Format kwargs into a list of messages.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
**kwargs: keyword arguments to use for formatting.
|
**kwargs: keyword arguments to use for filling in template variables
|
||||||
|
in all the template messages in this chat template.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List of messages.
|
list of formatted messages
|
||||||
"""
|
"""
|
||||||
kwargs = self._merge_partial_and_user_variables(**kwargs)
|
kwargs = self._merge_partial_and_user_variables(**kwargs)
|
||||||
result = []
|
result = []
|
||||||
@ -414,11 +492,11 @@ class ChatPromptTemplate(BaseChatPromptTemplate, ABC):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def _prompt_type(self) -> str:
|
def _prompt_type(self) -> str:
|
||||||
|
"""Name of prompt type."""
|
||||||
return "chat"
|
return "chat"
|
||||||
|
|
||||||
def save(self, file_path: Union[Path, str]) -> None:
|
def save(self, file_path: Union[Path, str]) -> None:
|
||||||
"""
|
"""Save prompt to file.
|
||||||
Save prompt to file.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
file_path: path to file.
|
file_path: path to file.
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List
|
from typing import List, Union
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
@ -7,13 +7,19 @@ from langchain.prompts import PromptTemplate
|
|||||||
from langchain.prompts.chat import (
|
from langchain.prompts.chat import (
|
||||||
AIMessagePromptTemplate,
|
AIMessagePromptTemplate,
|
||||||
BaseMessagePromptTemplate,
|
BaseMessagePromptTemplate,
|
||||||
|
ChatMessage,
|
||||||
ChatMessagePromptTemplate,
|
ChatMessagePromptTemplate,
|
||||||
ChatPromptTemplate,
|
ChatPromptTemplate,
|
||||||
ChatPromptValue,
|
ChatPromptValue,
|
||||||
HumanMessagePromptTemplate,
|
HumanMessagePromptTemplate,
|
||||||
SystemMessagePromptTemplate,
|
SystemMessagePromptTemplate,
|
||||||
)
|
)
|
||||||
from langchain.schema.messages import HumanMessage
|
from langchain.schema.messages import (
|
||||||
|
AIMessage,
|
||||||
|
BaseMessage,
|
||||||
|
HumanMessage,
|
||||||
|
SystemMessage,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def create_messages() -> List[BaseMessagePromptTemplate]:
|
def create_messages() -> List[BaseMessagePromptTemplate]:
|
||||||
@ -133,7 +139,9 @@ def test_chat_prompt_template_from_messages() -> None:
|
|||||||
|
|
||||||
|
|
||||||
def test_chat_prompt_template_with_messages() -> None:
|
def test_chat_prompt_template_with_messages() -> None:
|
||||||
messages = create_messages() + [HumanMessage(content="foo")]
|
messages: List[
|
||||||
|
Union[BaseMessagePromptTemplate, BaseMessage]
|
||||||
|
] = create_messages() + [HumanMessage(content="foo")]
|
||||||
chat_prompt_template = ChatPromptTemplate.from_messages(messages)
|
chat_prompt_template = ChatPromptTemplate.from_messages(messages)
|
||||||
assert sorted(chat_prompt_template.input_variables) == sorted(
|
assert sorted(chat_prompt_template.input_variables) == sorted(
|
||||||
["context", "foo", "bar"]
|
["context", "foo", "bar"]
|
||||||
@ -175,7 +183,7 @@ def test_chat_valid_with_partial_variables() -> None:
|
|||||||
input_variables=["question", "context"],
|
input_variables=["question", "context"],
|
||||||
partial_variables={"formatins": "some structure"},
|
partial_variables={"formatins": "some structure"},
|
||||||
)
|
)
|
||||||
assert set(prompt.input_variables) == set(["question", "context"])
|
assert set(prompt.input_variables) == {"question", "context"}
|
||||||
assert prompt.partial_variables == {"formatins": "some structure"}
|
assert prompt.partial_variables == {"formatins": "some structure"}
|
||||||
|
|
||||||
|
|
||||||
@ -188,5 +196,25 @@ def test_chat_valid_infer_variables() -> None:
|
|||||||
prompt = ChatPromptTemplate(
|
prompt = ChatPromptTemplate(
|
||||||
messages=messages, partial_variables={"formatins": "some structure"}
|
messages=messages, partial_variables={"formatins": "some structure"}
|
||||||
)
|
)
|
||||||
assert set(prompt.input_variables) == set(["question", "context"])
|
assert set(prompt.input_variables) == {"question", "context"}
|
||||||
assert prompt.partial_variables == {"formatins": "some structure"}
|
assert prompt.partial_variables == {"formatins": "some structure"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_chat_from_role_strings() -> None:
|
||||||
|
"""Test instantiation of chat template from role strings."""
|
||||||
|
template = ChatPromptTemplate.from_role_strings(
|
||||||
|
[
|
||||||
|
("system", "You are a bot."),
|
||||||
|
("ai", "hello!"),
|
||||||
|
("human", "{question}"),
|
||||||
|
("other", "{quack}"),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
messages = template.format_messages(question="How are you?", quack="duck")
|
||||||
|
assert messages == [
|
||||||
|
SystemMessage(content="You are a bot."),
|
||||||
|
AIMessage(content="hello!"),
|
||||||
|
HumanMessage(content="How are you?"),
|
||||||
|
ChatMessage(content="duck", role="other"),
|
||||||
|
]
|
||||||
|
Loading…
Reference in New Issue
Block a user