mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-11 05:45:01 +00:00
Support a few list like operations on ChatPromptTemplate (#9077)
Make it easier to work with chat prompt template
This commit is contained in:
parent
e4418d1b7e
commit
44bc89b7bf
@ -3,7 +3,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Callable, List, Sequence, Tuple, Type, TypeVar, Union
|
from typing import Any, Callable, List, Sequence, Tuple, Type, TypeVar, Union, overload
|
||||||
|
|
||||||
from pydantic import Field, root_validator
|
from pydantic import Field, root_validator
|
||||||
|
|
||||||
@ -317,6 +317,16 @@ class BaseChatPromptTemplate(BasePromptTemplate, ABC):
|
|||||||
"""Format kwargs into a list of messages."""
|
"""Format kwargs into a list of messages."""
|
||||||
|
|
||||||
|
|
||||||
|
MessageLike = Union[BaseMessagePromptTemplate, BaseMessage, BaseChatPromptTemplate]
|
||||||
|
|
||||||
|
MessageLikeRepresentation = Union[
|
||||||
|
MessageLike,
|
||||||
|
Tuple[str, str],
|
||||||
|
Tuple[Type, str],
|
||||||
|
str,
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
class ChatPromptTemplate(BaseChatPromptTemplate, ABC):
|
class ChatPromptTemplate(BaseChatPromptTemplate, ABC):
|
||||||
"""A prompt template for chat models.
|
"""A prompt template for chat models.
|
||||||
|
|
||||||
@ -343,9 +353,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate, ABC):
|
|||||||
|
|
||||||
input_variables: List[str]
|
input_variables: List[str]
|
||||||
"""List of input variables in template messages. Used for validation."""
|
"""List of input variables in template messages. Used for validation."""
|
||||||
messages: List[
|
messages: List[MessageLike]
|
||||||
Union[BaseMessagePromptTemplate, BaseMessage, BaseChatPromptTemplate]
|
|
||||||
]
|
|
||||||
"""List of messages consisting of either message prompt templates or messages."""
|
"""List of messages consisting of either message prompt templates or messages."""
|
||||||
|
|
||||||
def __add__(self, other: Any) -> ChatPromptTemplate:
|
def __add__(self, other: Any) -> ChatPromptTemplate:
|
||||||
@ -364,6 +372,9 @@ class ChatPromptTemplate(BaseChatPromptTemplate, ABC):
|
|||||||
other, (BaseMessagePromptTemplate, BaseMessage, BaseChatPromptTemplate)
|
other, (BaseMessagePromptTemplate, BaseMessage, BaseChatPromptTemplate)
|
||||||
):
|
):
|
||||||
return ChatPromptTemplate(messages=self.messages + [other])
|
return ChatPromptTemplate(messages=self.messages + [other])
|
||||||
|
elif isinstance(other, (list, tuple)):
|
||||||
|
_other = ChatPromptTemplate.from_messages(other)
|
||||||
|
return ChatPromptTemplate(messages=self.messages + _other.messages)
|
||||||
elif isinstance(other, str):
|
elif isinstance(other, str):
|
||||||
prompt = HumanMessagePromptTemplate.from_template(other)
|
prompt = HumanMessagePromptTemplate.from_template(other)
|
||||||
return ChatPromptTemplate(messages=self.messages + [prompt])
|
return ChatPromptTemplate(messages=self.messages + [prompt])
|
||||||
@ -457,16 +468,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate, ABC):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def from_messages(
|
def from_messages(
|
||||||
cls,
|
cls,
|
||||||
messages: Sequence[
|
messages: Sequence[MessageLikeRepresentation],
|
||||||
Union[
|
|
||||||
BaseMessagePromptTemplate,
|
|
||||||
BaseChatPromptTemplate,
|
|
||||||
BaseMessage,
|
|
||||||
Tuple[str, str],
|
|
||||||
Tuple[Type, str],
|
|
||||||
str,
|
|
||||||
]
|
|
||||||
],
|
|
||||||
) -> ChatPromptTemplate:
|
) -> ChatPromptTemplate:
|
||||||
"""Create a chat prompt template from a variety of message formats.
|
"""Create a chat prompt template from a variety of message formats.
|
||||||
|
|
||||||
@ -556,8 +558,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate, ABC):
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
def partial(self, **kwargs: Union[str, Callable[[], str]]) -> ChatPromptTemplate:
|
def partial(self, **kwargs: Union[str, Callable[[], str]]) -> ChatPromptTemplate:
|
||||||
"""Return a new ChatPromptTemplate with some of the input variables already
|
"""Get a new ChatPromptTemplate with some input variables already filled in.
|
||||||
filled in.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
**kwargs: keyword arguments to use for filling in template variables. Ought
|
**kwargs: keyword arguments to use for filling in template variables. Ought
|
||||||
@ -592,6 +593,41 @@ class ChatPromptTemplate(BaseChatPromptTemplate, ABC):
|
|||||||
prompt_dict["partial_variables"] = {**self.partial_variables, **kwargs}
|
prompt_dict["partial_variables"] = {**self.partial_variables, **kwargs}
|
||||||
return type(self)(**prompt_dict)
|
return type(self)(**prompt_dict)
|
||||||
|
|
||||||
|
def append(self, message: MessageLikeRepresentation) -> None:
|
||||||
|
"""Append message to the end of the chat template.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
message: representation of a message to append.
|
||||||
|
"""
|
||||||
|
self.messages.append(_convert_to_message(message))
|
||||||
|
|
||||||
|
def extend(self, messages: Sequence[MessageLikeRepresentation]) -> None:
|
||||||
|
"""Extend the chat template with a sequence of messages."""
|
||||||
|
self.messages.extend([_convert_to_message(message) for message in messages])
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def __getitem__(self, index: int) -> MessageLike:
|
||||||
|
...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def __getitem__(self, index: slice) -> ChatPromptTemplate:
|
||||||
|
...
|
||||||
|
|
||||||
|
def __getitem__(
|
||||||
|
self, index: Union[int, slice]
|
||||||
|
) -> Union[MessageLike, ChatPromptTemplate]:
|
||||||
|
"""Use to index into the chat template."""
|
||||||
|
if isinstance(index, slice):
|
||||||
|
start, stop, step = index.indices(len(self.messages))
|
||||||
|
messages = self.messages[start:stop:step]
|
||||||
|
return ChatPromptTemplate.from_messages(messages)
|
||||||
|
else:
|
||||||
|
return self.messages[index]
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
"""Get the length of the chat template."""
|
||||||
|
return len(self.messages)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _prompt_type(self) -> str:
|
def _prompt_type(self) -> str:
|
||||||
"""Name of prompt type."""
|
"""Name of prompt type."""
|
||||||
@ -635,14 +671,7 @@ def _create_template_from_message_type(
|
|||||||
|
|
||||||
|
|
||||||
def _convert_to_message(
|
def _convert_to_message(
|
||||||
message: Union[
|
message: MessageLikeRepresentation,
|
||||||
BaseMessagePromptTemplate,
|
|
||||||
BaseChatPromptTemplate,
|
|
||||||
BaseMessage,
|
|
||||||
Tuple[str, str],
|
|
||||||
Tuple[Type, str],
|
|
||||||
str,
|
|
||||||
]
|
|
||||||
) -> Union[BaseMessage, BaseMessagePromptTemplate, BaseChatPromptTemplate]:
|
) -> Union[BaseMessage, BaseMessagePromptTemplate, BaseChatPromptTemplate]:
|
||||||
"""Instantiate a message from a variety of message formats.
|
"""Instantiate a message from a variety of message formats.
|
||||||
|
|
||||||
|
@ -282,6 +282,42 @@ def test_convert_to_message(
|
|||||||
assert _convert_to_message(args) == expected
|
assert _convert_to_message(args) == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_chat_prompt_template_indexing() -> None:
|
||||||
|
message1 = SystemMessage(content="foo")
|
||||||
|
message2 = HumanMessage(content="bar")
|
||||||
|
message3 = HumanMessage(content="baz")
|
||||||
|
template = ChatPromptTemplate.from_messages([message1, message2, message3])
|
||||||
|
assert template[0] == message1
|
||||||
|
assert template[1] == message2
|
||||||
|
|
||||||
|
# Slice starting from index 1
|
||||||
|
slice_template = template[1:]
|
||||||
|
assert slice_template[0] == message2
|
||||||
|
assert len(slice_template) == 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_chat_prompt_template_append_and_extend() -> None:
|
||||||
|
"""Test append and extend methods of ChatPromptTemplate."""
|
||||||
|
message1 = SystemMessage(content="foo")
|
||||||
|
message2 = HumanMessage(content="bar")
|
||||||
|
message3 = HumanMessage(content="baz")
|
||||||
|
template = ChatPromptTemplate.from_messages([message1])
|
||||||
|
template.append(message2)
|
||||||
|
template.append(message3)
|
||||||
|
assert len(template) == 3
|
||||||
|
template.extend([message2, message3])
|
||||||
|
assert len(template) == 5
|
||||||
|
assert template.messages == [
|
||||||
|
message1,
|
||||||
|
message2,
|
||||||
|
message3,
|
||||||
|
message2,
|
||||||
|
message3,
|
||||||
|
]
|
||||||
|
template.append(("system", "hello!"))
|
||||||
|
assert template[-1] == SystemMessagePromptTemplate.from_template("hello!")
|
||||||
|
|
||||||
|
|
||||||
def test_convert_to_message_is_strict() -> None:
|
def test_convert_to_message_is_strict() -> None:
|
||||||
"""Verify that _convert_to_message is strict."""
|
"""Verify that _convert_to_message is strict."""
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
|
Loading…
Reference in New Issue
Block a user