From bccb07f93ea2e3e2c3eb000e7ed3d589d3be5f3b Mon Sep 17 00:00:00 2001
From: Bagatur <22008038+baskaryan@users.noreply.github.com>
Date: Fri, 12 Jan 2024 21:08:51 -0500
Subject: [PATCH] core[patch]: simple prompt pretty printing (#15968)

---
 libs/core/langchain_core/messages/base.py    | 20 ++++++++++
 libs/core/langchain_core/prompts/chat.py     | 40 +++++++++++++++++--
 libs/core/langchain_core/prompts/few_shot.py |  3 ++
 libs/core/langchain_core/prompts/string.py   | 16 ++++++++
 .../langchain_core/utils/interactive_env.py  |  5 +++
 .../autonomous_agents/autogpt/prompt.py      |  3 ++
 6 files changed, 84 insertions(+), 3 deletions(-)
 create mode 100644 libs/core/langchain_core/utils/interactive_env.py

diff --git a/libs/core/langchain_core/messages/base.py b/libs/core/langchain_core/messages/base.py
index 2076d34060e..b2fd76b6592 100644
--- a/libs/core/langchain_core/messages/base.py
+++ b/libs/core/langchain_core/messages/base.py
@@ -4,6 +4,8 @@ from typing import TYPE_CHECKING, Any, Dict, List, Sequence, Union
 
 from langchain_core.load.serializable import Serializable
 from langchain_core.pydantic_v1 import Extra, Field
+from langchain_core.utils import get_bolded_text
+from langchain_core.utils.interactive_env import is_interactive_env
 
 if TYPE_CHECKING:
     from langchain_core.prompts.chat import ChatPromptTemplate
@@ -42,6 +44,14 @@ class BaseMessage(Serializable):
         prompt = ChatPromptTemplate(messages=[self])
         return prompt + other
 
+    def pretty_repr(self, html: bool = False) -> str:
+        title = get_msg_title_repr(self.type.title() + " Message", bold=html)
+        # TODO: handle non-string content.
+        return f"{title}\n\n{self.content}"
+
+    def pretty_print(self) -> None:
+        print(self.pretty_repr(html=is_interactive_env()))
+
 
 def merge_content(
     first_content: Union[str, List[Union[str, Dict]]],
@@ -176,3 +186,13 @@ def messages_to_dict(messages: Sequence[BaseMessage]) -> List[dict]:
         List of messages as dicts.
     """
     return [message_to_dict(m) for m in messages]
+
+
+def get_msg_title_repr(title: str, *, bold: bool = False) -> str:
+    padded = " " + title + " "
+    sep_len = (80 - len(padded)) // 2
+    sep = "=" * sep_len
+    second_sep = sep + "=" if len(padded) % 2 else sep
+    if bold:
+        padded = get_bolded_text(padded)
+    return f"{sep}{padded}{second_sep}"
diff --git a/libs/core/langchain_core/prompts/chat.py b/libs/core/langchain_core/prompts/chat.py
index 94ce5e6040d..f0de1aa88b5 100644
--- a/libs/core/langchain_core/prompts/chat.py
+++ b/libs/core/langchain_core/prompts/chat.py
@@ -28,11 +28,14 @@ from langchain_core.messages import (
     HumanMessage,
     SystemMessage,
 )
+from langchain_core.messages.base import get_msg_title_repr
 from langchain_core.prompt_values import ChatPromptValue, PromptValue
 from langchain_core.prompts.base import BasePromptTemplate
 from langchain_core.prompts.prompt import PromptTemplate
 from langchain_core.prompts.string import StringPromptTemplate
 from langchain_core.pydantic_v1 import Field, root_validator
+from langchain_core.utils import get_colored_text
+from langchain_core.utils.interactive_env import is_interactive_env
 
 
 class BaseMessagePromptTemplate(Serializable, ABC):
@@ -68,6 +71,13 @@ class BaseMessagePromptTemplate(Serializable, ABC):
             List of input variables.
         """
 
+    def pretty_repr(self, html: bool = False) -> str:
+        """Human-readable representation."""
+        raise NotImplementedError
+
+    def pretty_print(self) -> None:
+        print(self.pretty_repr(html=is_interactive_env()))
+
     def __add__(self, other: Any) -> ChatPromptTemplate:
         """Combine two prompt templates.
 
@@ -95,9 +105,7 @@ class MessagesPlaceholder(BaseMessagePromptTemplate):
         return ["langchain", "prompts", "chat"]
 
     def __init__(self, variable_name: str, *, optional: bool = False, **kwargs: Any):
-        return super().__init__(
-            variable_name=variable_name, optional=optional, **kwargs
-        )
+        super().__init__(variable_name=variable_name, optional=optional, **kwargs)
 
     def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
         """Format messages from kwargs.
@@ -135,6 +143,15 @@ class MessagesPlaceholder(BaseMessagePromptTemplate):
         """
         return [self.variable_name] if not self.optional else []
 
+    def pretty_repr(self, html: bool = False) -> str:
+        var = "{" + self.variable_name + "}"
+        if html:
+            title = get_msg_title_repr("Messages Placeholder", bold=True)
+            var = get_colored_text(var, "yellow")
+        else:
+            title = get_msg_title_repr("Messages Placeholder")
+        return f"{title}\n\n{var}"
+
 
 MessagePromptTemplateT = TypeVar(
     "MessagePromptTemplateT", bound="BaseStringMessagePromptTemplate"
@@ -237,6 +254,12 @@ class BaseStringMessagePromptTemplate(BaseMessagePromptTemplate, ABC):
         """
         return self.prompt.input_variables
 
+    def pretty_repr(self, html: bool = False) -> str:
+        # TODO: Handle partials
+        title = self.__class__.__name__.replace("MessagePromptTemplate", " Message")
+        title = get_msg_title_repr(title, bold=html)
+        return f"{title}\n\n{self.prompt.pretty_repr(html=html)}"
+
 
 class ChatMessagePromptTemplate(BaseStringMessagePromptTemplate):
     """Chat message prompt template."""
@@ -369,6 +392,13 @@ class BaseChatPromptTemplate(BasePromptTemplate, ABC):
     def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
         """Format kwargs into a list of messages."""
 
+    def pretty_repr(self, html: bool = False) -> str:
+        """Human-readable representation."""
+        raise NotImplementedError
+
+    def pretty_print(self) -> None:
+        print(self.pretty_repr(html=is_interactive_env()))
+
 
 MessageLike = Union[BaseMessagePromptTemplate, BaseMessage, BaseChatPromptTemplate]
 
@@ -701,6 +731,10 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
         """
         raise NotImplementedError()
 
+    def pretty_repr(self, html: bool = False) -> str:
+        # TODO: handle partials
+        return "\n\n".join(msg.pretty_repr(html=html) for msg in self.messages)
+
 
 def _create_template_from_message_type(
     message_type: str, template: str
diff --git a/libs/core/langchain_core/prompts/few_shot.py b/libs/core/langchain_core/prompts/few_shot.py
index 76df8669047..8be9d0efb61 100644
--- a/libs/core/langchain_core/prompts/few_shot.py
+++ b/libs/core/langchain_core/prompts/few_shot.py
@@ -340,3 +340,6 @@ class FewShotChatMessagePromptTemplate(
         """
         messages = self.format_messages(**kwargs)
         return get_buffer_string(messages)
+
+    def pretty_repr(self, html: bool = False) -> str:
+        raise NotImplementedError()
diff --git a/libs/core/langchain_core/prompts/string.py b/libs/core/langchain_core/prompts/string.py
index 5981fe9988d..0923da9fbee 100644
--- a/libs/core/langchain_core/prompts/string.py
+++ b/libs/core/langchain_core/prompts/string.py
@@ -8,7 +8,9 @@ from typing import Any, Callable, Dict, List, Set
 
 from langchain_core.prompt_values import PromptValue, StringPromptValue
 from langchain_core.prompts.base import BasePromptTemplate
+from langchain_core.utils import get_colored_text
 from langchain_core.utils.formatting import formatter
+from langchain_core.utils.interactive_env import is_interactive_env
 
 
 def jinja2_formatter(template: str, **kwargs: Any) -> str:
@@ -159,3 +161,17 @@ class StringPromptTemplate(BasePromptTemplate, ABC):
     def format_prompt(self, **kwargs: Any) -> PromptValue:
         """Create Chat Messages."""
         return StringPromptValue(text=self.format(**kwargs))
+
+    def pretty_repr(self, html: bool = False) -> str:
+        # TODO: handle partials
+        dummy_vars = {
+            input_var: "{" + f"{input_var}" + "}" for input_var in self.input_variables
+        }
+        if html:
+            dummy_vars = {
+                k: get_colored_text(v, "yellow") for k, v in dummy_vars.items()
+            }
+        return self.format(**dummy_vars)
+
+    def pretty_print(self) -> None:
+        print(self.pretty_repr(html=is_interactive_env()))
diff --git a/libs/core/langchain_core/utils/interactive_env.py b/libs/core/langchain_core/utils/interactive_env.py
new file mode 100644
index 00000000000..7752b6b403c
--- /dev/null
+++ b/libs/core/langchain_core/utils/interactive_env.py
@@ -0,0 +1,5 @@
+def is_interactive_env() -> bool:
+    """Determine if running within IPython or Jupyter."""
+    import sys
+
+    return hasattr(sys, "ps2")
diff --git a/libs/experimental/langchain_experimental/autonomous_agents/autogpt/prompt.py b/libs/experimental/langchain_experimental/autonomous_agents/autogpt/prompt.py
index 7a5b4c831b8..0155aa86df1 100644
--- a/libs/experimental/langchain_experimental/autonomous_agents/autogpt/prompt.py
+++ b/libs/experimental/langchain_experimental/autonomous_agents/autogpt/prompt.py
@@ -101,3 +101,6 @@ class AutoGPTPrompt(BaseChatPromptTemplate, BaseModel):  # type: ignore[misc]
         messages += historical_messages
         messages.append(input_message)
         return messages
+
+    def pretty_repr(self, html: bool = False) -> str:
+        raise NotImplementedError
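
Usage sketch: a minimal example of the pretty-printing introduced by this patch, assuming langchain_core with the change applied; the prompt contents below are illustrative, not part of the patch.

    from langchain_core.prompts import ChatPromptTemplate

    # Each message prompt template gains pretty_repr()/pretty_print();
    # ChatPromptTemplate.pretty_repr() joins its messages' pretty_repr()
    # output with blank lines.
    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", "You are a helpful assistant named {name}."),
            ("human", "{question}"),
        ]
    )

    # Prints an 80-character "== System Message ==" style banner per message
    # (built by get_msg_title_repr) followed by the unfilled template text.
    # In IPython/Jupyter (is_interactive_env() is True) the banners are bolded
    # and the {name}/{question} placeholders are colored via get_colored_text.
    prompt.pretty_print()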