mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-20 22:03:52 +00:00
core[patch]: simple prompt pretty printing (#15968)
This commit is contained in:
parent
3f75fd41cc
commit
bccb07f93e
@ -4,6 +4,8 @@ from typing import TYPE_CHECKING, Any, Dict, List, Sequence, Union
|
||||
|
||||
from langchain_core.load.serializable import Serializable
|
||||
from langchain_core.pydantic_v1 import Extra, Field
|
||||
from langchain_core.utils import get_bolded_text
|
||||
from langchain_core.utils.interactive_env import is_interactive_env
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_core.prompts.chat import ChatPromptTemplate
|
||||
@ -42,6 +44,14 @@ class BaseMessage(Serializable):
|
||||
prompt = ChatPromptTemplate(messages=[self])
|
||||
return prompt + other
|
||||
|
||||
def pretty_repr(self, html: bool = False) -> str:
|
||||
title = get_msg_title_repr(self.type.title() + " Message", bold=html)
|
||||
# TODO: handle non-string content.
|
||||
return f"{title}\n\n{self.content}"
|
||||
|
||||
def pretty_print(self) -> None:
|
||||
print(self.pretty_repr(html=is_interactive_env()))
|
||||
|
||||
|
||||
def merge_content(
|
||||
first_content: Union[str, List[Union[str, Dict]]],
|
||||
@ -176,3 +186,13 @@ def messages_to_dict(messages: Sequence[BaseMessage]) -> List[dict]:
|
||||
List of messages as dicts.
|
||||
"""
|
||||
return [message_to_dict(m) for m in messages]
|
||||
|
||||
|
||||
def get_msg_title_repr(title: str, *, bold: bool = False) -> str:
|
||||
padded = " " + title + " "
|
||||
sep_len = (80 - len(padded)) // 2
|
||||
sep = "=" * sep_len
|
||||
second_sep = sep + "=" if len(padded) % 2 else sep
|
||||
if bold:
|
||||
padded = get_bolded_text(padded)
|
||||
return f"{sep}{padded}{second_sep}"
|
||||
|
@ -28,11 +28,14 @@ from langchain_core.messages import (
|
||||
HumanMessage,
|
||||
SystemMessage,
|
||||
)
|
||||
from langchain_core.messages.base import get_msg_title_repr
|
||||
from langchain_core.prompt_values import ChatPromptValue, PromptValue
|
||||
from langchain_core.prompts.base import BasePromptTemplate
|
||||
from langchain_core.prompts.prompt import PromptTemplate
|
||||
from langchain_core.prompts.string import StringPromptTemplate
|
||||
from langchain_core.pydantic_v1 import Field, root_validator
|
||||
from langchain_core.utils import get_colored_text
|
||||
from langchain_core.utils.interactive_env import is_interactive_env
|
||||
|
||||
|
||||
class BaseMessagePromptTemplate(Serializable, ABC):
|
||||
@ -68,6 +71,13 @@ class BaseMessagePromptTemplate(Serializable, ABC):
|
||||
List of input variables.
|
||||
"""
|
||||
|
||||
def pretty_repr(self, html: bool = False) -> str:
|
||||
"""Human-readable representation."""
|
||||
raise NotImplementedError
|
||||
|
||||
def pretty_print(self) -> None:
|
||||
print(self.pretty_repr(html=is_interactive_env()))
|
||||
|
||||
def __add__(self, other: Any) -> ChatPromptTemplate:
|
||||
"""Combine two prompt templates.
|
||||
|
||||
@ -95,9 +105,7 @@ class MessagesPlaceholder(BaseMessagePromptTemplate):
|
||||
return ["langchain", "prompts", "chat"]
|
||||
|
||||
def __init__(self, variable_name: str, *, optional: bool = False, **kwargs: Any):
|
||||
return super().__init__(
|
||||
variable_name=variable_name, optional=optional, **kwargs
|
||||
)
|
||||
super().__init__(variable_name=variable_name, optional=optional, **kwargs)
|
||||
|
||||
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
|
||||
"""Format messages from kwargs.
|
||||
@ -135,6 +143,15 @@ class MessagesPlaceholder(BaseMessagePromptTemplate):
|
||||
"""
|
||||
return [self.variable_name] if not self.optional else []
|
||||
|
||||
def pretty_repr(self, html: bool = False) -> str:
|
||||
var = "{" + self.variable_name + "}"
|
||||
if html:
|
||||
title = get_msg_title_repr("Messages Placeholder", bold=True)
|
||||
var = get_colored_text(var, "yellow")
|
||||
else:
|
||||
title = get_msg_title_repr("Messages Placeholder")
|
||||
return f"{title}\n\n{var}"
|
||||
|
||||
|
||||
MessagePromptTemplateT = TypeVar(
|
||||
"MessagePromptTemplateT", bound="BaseStringMessagePromptTemplate"
|
||||
@ -237,6 +254,12 @@ class BaseStringMessagePromptTemplate(BaseMessagePromptTemplate, ABC):
|
||||
"""
|
||||
return self.prompt.input_variables
|
||||
|
||||
def pretty_repr(self, html: bool = False) -> str:
|
||||
# TODO: Handle partials
|
||||
title = self.__class__.__name__.replace("MessagePromptTemplate", " Message")
|
||||
title = get_msg_title_repr(title, bold=html)
|
||||
return f"{title}\n\n{self.prompt.pretty_repr(html=html)}"
|
||||
|
||||
|
||||
class ChatMessagePromptTemplate(BaseStringMessagePromptTemplate):
|
||||
"""Chat message prompt template."""
|
||||
@ -369,6 +392,13 @@ class BaseChatPromptTemplate(BasePromptTemplate, ABC):
|
||||
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
|
||||
"""Format kwargs into a list of messages."""
|
||||
|
||||
def pretty_repr(self, html: bool = False) -> str:
|
||||
"""Human-readable representation."""
|
||||
raise NotImplementedError
|
||||
|
||||
def pretty_print(self) -> None:
|
||||
print(self.pretty_repr(html=is_interactive_env()))
|
||||
|
||||
|
||||
MessageLike = Union[BaseMessagePromptTemplate, BaseMessage, BaseChatPromptTemplate]
|
||||
|
||||
@ -701,6 +731,10 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def pretty_repr(self, html: bool = False) -> str:
|
||||
# TODO: handle partials
|
||||
return "\n\n".join(msg.pretty_repr(html=html) for msg in self.messages)
|
||||
|
||||
|
||||
def _create_template_from_message_type(
|
||||
message_type: str, template: str
|
||||
|
@ -340,3 +340,6 @@ class FewShotChatMessagePromptTemplate(
|
||||
"""
|
||||
messages = self.format_messages(**kwargs)
|
||||
return get_buffer_string(messages)
|
||||
|
||||
def pretty_repr(self, html: bool = False) -> str:
|
||||
raise NotImplementedError()
|
||||
|
@ -8,7 +8,9 @@ from typing import Any, Callable, Dict, List, Set
|
||||
|
||||
from langchain_core.prompt_values import PromptValue, StringPromptValue
|
||||
from langchain_core.prompts.base import BasePromptTemplate
|
||||
from langchain_core.utils import get_colored_text
|
||||
from langchain_core.utils.formatting import formatter
|
||||
from langchain_core.utils.interactive_env import is_interactive_env
|
||||
|
||||
|
||||
def jinja2_formatter(template: str, **kwargs: Any) -> str:
|
||||
@ -159,3 +161,17 @@ class StringPromptTemplate(BasePromptTemplate, ABC):
|
||||
def format_prompt(self, **kwargs: Any) -> PromptValue:
|
||||
"""Create Chat Messages."""
|
||||
return StringPromptValue(text=self.format(**kwargs))
|
||||
|
||||
def pretty_repr(self, html: bool = False) -> str:
|
||||
# TODO: handle partials
|
||||
dummy_vars = {
|
||||
input_var: "{" + f"{input_var}" + "}" for input_var in self.input_variables
|
||||
}
|
||||
if html:
|
||||
dummy_vars = {
|
||||
k: get_colored_text(v, "yellow") for k, v in dummy_vars.items()
|
||||
}
|
||||
return self.format(**dummy_vars)
|
||||
|
||||
def pretty_print(self) -> None:
|
||||
print(self.pretty_repr(html=is_interactive_env()))
|
||||
|
5
libs/core/langchain_core/utils/interactive_env.py
Normal file
5
libs/core/langchain_core/utils/interactive_env.py
Normal file
@ -0,0 +1,5 @@
|
||||
def is_interactive_env() -> bool:
|
||||
"""Determine if running within IPython or Jupyter."""
|
||||
import sys
|
||||
|
||||
return hasattr(sys, "ps2")
|
@ -101,3 +101,6 @@ class AutoGPTPrompt(BaseChatPromptTemplate, BaseModel): # type: ignore[misc]
|
||||
messages += historical_messages
|
||||
messages.append(input_message)
|
||||
return messages
|
||||
|
||||
def pretty_repr(self, html: bool = False) -> str:
|
||||
raise NotImplementedError
|
||||
|
Loading…
Reference in New Issue
Block a user