Compare commits

...

1 Commits

Author SHA1 Message Date
Bagatur
79d024d2ff core[minor]: Init ChatPromptTemplate with message-like representations 2024-01-31 15:40:37 -08:00

View File

@@ -593,6 +593,24 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
validate_template: bool = False
"""Whether or not to try validating the template."""
def __init__(
self,
messages: Sequence[MessageLikeRepresentation],
input_variables: Optional[Sequence[str]] = None,
**kwargs: Any,
) -> None:
_messages = [_convert_to_message(msg) for msg in messages]
if input_variables is None:
prompts = [
msg
for msg in _messages
if isinstance(msg, (BaseChatPromptTemplate, BaseMessagePromptTemplate))
]
input_variables = list(
set(iv for prompt in prompts for iv in prompt.input_variables)
)
super().__init__(messages=_messages, input_variables=input_variables, **kwargs)
@classmethod
def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
@@ -677,45 +695,9 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
message = HumanMessagePromptTemplate(prompt=prompt_template)
return cls.from_messages([message])
@classmethod
@deprecated("0.0.260", alternative="from_messages classmethod", pending=True)
def from_role_strings(
cls, string_messages: List[Tuple[str, str]]
) -> ChatPromptTemplate:
"""Create a chat prompt template from a list of (role, template) tuples.
Args:
string_messages: list of (role, template) tuples.
Returns:
a chat prompt template
"""
return cls(
messages=[
ChatMessagePromptTemplate.from_template(template, role=role)
for role, template in string_messages
]
)
@classmethod
@deprecated("0.0.260", alternative="from_messages classmethod", pending=True)
def from_strings(
cls, string_messages: List[Tuple[Type[BaseMessagePromptTemplate], str]]
) -> ChatPromptTemplate:
"""Create a chat prompt template from a list of (role class, template) tuples.
Args:
string_messages: list of (role class, template) tuples.
Returns:
a chat prompt template
"""
return cls.from_messages(string_messages)
@classmethod
def from_messages(
cls,
messages: Sequence[MessageLikeRepresentation],
cls, messages: Sequence[MessageLikeRepresentation], **kwargs: Any
) -> ChatPromptTemplate:
"""Create a chat prompt template from a variety of message formats.
@@ -747,6 +729,8 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
(message type, template); e.g., ("human", "{user_input}"),
(4) 2-tuple of (message class, template), (4) a string which is
shorthand for ("human", template); e.g., "{user_input}"
**kwargs: Additional keyword arguments to pass to ChatPromptTemplate
constructor.
Returns:
a chat prompt template
@@ -754,14 +738,14 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
_messages = [_convert_to_message(message) for message in messages]
# Automatically infer input variables from messages
input_vars: Set[str] = set()
input_vars: Set[str] = kwargs.pop("input_variables", set())
for _message in _messages:
if isinstance(
_message, (BaseChatPromptTemplate, BaseMessagePromptTemplate)
):
input_vars.update(_message.input_variables)
return cls(input_variables=sorted(input_vars), messages=_messages)
return cls(input_variables=sorted(input_vars), messages=_messages, **kwargs)
def format(self, **kwargs: Any) -> str:
"""Format the chat template into a string.
@@ -887,6 +871,41 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
# TODO: handle partials
return "\n\n".join(msg.pretty_repr(html=html) for msg in self.messages)
@classmethod
@deprecated("0.0.260", alternative="from_messages classmethod", pending=True)
def from_role_strings(
cls, string_messages: List[Tuple[str, str]]
) -> ChatPromptTemplate:
"""Create a chat prompt template from a list of (role, template) tuples.
Args:
string_messages: list of (role, template) tuples.
Returns:
a chat prompt template
"""
return cls(
messages=[
ChatMessagePromptTemplate.from_template(template, role=role)
for role, template in string_messages
]
)
@classmethod
@deprecated("0.0.260", alternative="from_messages classmethod", pending=True)
def from_strings(
cls, string_messages: List[Tuple[Type[BaseMessagePromptTemplate], str]]
) -> ChatPromptTemplate:
"""Create a chat prompt template from a list of (role class, template) tuples.
Args:
string_messages: list of (role class, template) tuples.
Returns:
a chat prompt template
"""
return cls.from_messages(string_messages)
def _create_template_from_message_type(
message_type: str, template: Union[str, list]