core[patch]: simple prompt pretty printing (#15968)

This commit is contained in:
Bagatur 2024-01-12 21:08:51 -05:00 committed by GitHub
parent 3f75fd41cc
commit bccb07f93e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 84 additions and 3 deletions

View File

@ -4,6 +4,8 @@ from typing import TYPE_CHECKING, Any, Dict, List, Sequence, Union
from langchain_core.load.serializable import Serializable
from langchain_core.pydantic_v1 import Extra, Field
from langchain_core.utils import get_bolded_text
from langchain_core.utils.interactive_env import is_interactive_env
if TYPE_CHECKING:
from langchain_core.prompts.chat import ChatPromptTemplate
@ -42,6 +44,14 @@ class BaseMessage(Serializable):
prompt = ChatPromptTemplate(messages=[self])
return prompt + other
def pretty_repr(self, html: bool = False) -> str:
title = get_msg_title_repr(self.type.title() + " Message", bold=html)
# TODO: handle non-string content.
return f"{title}\n\n{self.content}"
def pretty_print(self) -> None:
print(self.pretty_repr(html=is_interactive_env()))
def merge_content(
first_content: Union[str, List[Union[str, Dict]]],
@ -176,3 +186,13 @@ def messages_to_dict(messages: Sequence[BaseMessage]) -> List[dict]:
List of messages as dicts.
"""
return [message_to_dict(m) for m in messages]
def get_msg_title_repr(title: str, *, bold: bool = False) -> str:
padded = " " + title + " "
sep_len = (80 - len(padded)) // 2
sep = "=" * sep_len
second_sep = sep + "=" if len(padded) % 2 else sep
if bold:
padded = get_bolded_text(padded)
return f"{sep}{padded}{second_sep}"

View File

@ -28,11 +28,14 @@ from langchain_core.messages import (
HumanMessage,
SystemMessage,
)
from langchain_core.messages.base import get_msg_title_repr
from langchain_core.prompt_values import ChatPromptValue, PromptValue
from langchain_core.prompts.base import BasePromptTemplate
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.prompts.string import StringPromptTemplate
from langchain_core.pydantic_v1 import Field, root_validator
from langchain_core.utils import get_colored_text
from langchain_core.utils.interactive_env import is_interactive_env
class BaseMessagePromptTemplate(Serializable, ABC):
@ -68,6 +71,13 @@ class BaseMessagePromptTemplate(Serializable, ABC):
List of input variables.
"""
def pretty_repr(self, html: bool = False) -> str:
"""Human-readable representation."""
raise NotImplementedError
def pretty_print(self) -> None:
print(self.pretty_repr(html=is_interactive_env()))
def __add__(self, other: Any) -> ChatPromptTemplate:
"""Combine two prompt templates.
@ -95,9 +105,7 @@ class MessagesPlaceholder(BaseMessagePromptTemplate):
return ["langchain", "prompts", "chat"]
def __init__(self, variable_name: str, *, optional: bool = False, **kwargs: Any):
return super().__init__(
variable_name=variable_name, optional=optional, **kwargs
)
super().__init__(variable_name=variable_name, optional=optional, **kwargs)
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
"""Format messages from kwargs.
@ -135,6 +143,15 @@ class MessagesPlaceholder(BaseMessagePromptTemplate):
"""
return [self.variable_name] if not self.optional else []
def pretty_repr(self, html: bool = False) -> str:
var = "{" + self.variable_name + "}"
if html:
title = get_msg_title_repr("Messages Placeholder", bold=True)
var = get_colored_text(var, "yellow")
else:
title = get_msg_title_repr("Messages Placeholder")
return f"{title}\n\n{var}"
MessagePromptTemplateT = TypeVar(
"MessagePromptTemplateT", bound="BaseStringMessagePromptTemplate"
@ -237,6 +254,12 @@ class BaseStringMessagePromptTemplate(BaseMessagePromptTemplate, ABC):
"""
return self.prompt.input_variables
def pretty_repr(self, html: bool = False) -> str:
# TODO: Handle partials
title = self.__class__.__name__.replace("MessagePromptTemplate", " Message")
title = get_msg_title_repr(title, bold=html)
return f"{title}\n\n{self.prompt.pretty_repr(html=html)}"
class ChatMessagePromptTemplate(BaseStringMessagePromptTemplate):
"""Chat message prompt template."""
@ -369,6 +392,13 @@ class BaseChatPromptTemplate(BasePromptTemplate, ABC):
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
"""Format kwargs into a list of messages."""
def pretty_repr(self, html: bool = False) -> str:
"""Human-readable representation."""
raise NotImplementedError
def pretty_print(self) -> None:
print(self.pretty_repr(html=is_interactive_env()))
MessageLike = Union[BaseMessagePromptTemplate, BaseMessage, BaseChatPromptTemplate]
@ -701,6 +731,10 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
"""
raise NotImplementedError()
def pretty_repr(self, html: bool = False) -> str:
# TODO: handle partials
return "\n\n".join(msg.pretty_repr(html=html) for msg in self.messages)
def _create_template_from_message_type(
message_type: str, template: str

View File

@ -340,3 +340,6 @@ class FewShotChatMessagePromptTemplate(
"""
messages = self.format_messages(**kwargs)
return get_buffer_string(messages)
def pretty_repr(self, html: bool = False) -> str:
raise NotImplementedError()

View File

@ -8,7 +8,9 @@ from typing import Any, Callable, Dict, List, Set
from langchain_core.prompt_values import PromptValue, StringPromptValue
from langchain_core.prompts.base import BasePromptTemplate
from langchain_core.utils import get_colored_text
from langchain_core.utils.formatting import formatter
from langchain_core.utils.interactive_env import is_interactive_env
def jinja2_formatter(template: str, **kwargs: Any) -> str:
@ -159,3 +161,17 @@ class StringPromptTemplate(BasePromptTemplate, ABC):
def format_prompt(self, **kwargs: Any) -> PromptValue:
"""Create Chat Messages."""
return StringPromptValue(text=self.format(**kwargs))
def pretty_repr(self, html: bool = False) -> str:
# TODO: handle partials
dummy_vars = {
input_var: "{" + f"{input_var}" + "}" for input_var in self.input_variables
}
if html:
dummy_vars = {
k: get_colored_text(v, "yellow") for k, v in dummy_vars.items()
}
return self.format(**dummy_vars)
def pretty_print(self) -> None:
print(self.pretty_repr(html=is_interactive_env()))

View File

@ -0,0 +1,5 @@
def is_interactive_env() -> bool:
"""Determine if running within IPython or Jupyter."""
import sys
return hasattr(sys, "ps2")

View File

@ -101,3 +101,6 @@ class AutoGPTPrompt(BaseChatPromptTemplate, BaseModel): # type: ignore[misc]
messages += historical_messages
messages.append(input_message)
return messages
def pretty_repr(self, html: bool = False) -> str:
raise NotImplementedError