mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-21 14:18:52 +00:00
core[patch]: simple prompt pretty printing (#15968)
This commit is contained in:
parent
3f75fd41cc
commit
bccb07f93e
@ -4,6 +4,8 @@ from typing import TYPE_CHECKING, Any, Dict, List, Sequence, Union
|
|||||||
|
|
||||||
from langchain_core.load.serializable import Serializable
|
from langchain_core.load.serializable import Serializable
|
||||||
from langchain_core.pydantic_v1 import Extra, Field
|
from langchain_core.pydantic_v1 import Extra, Field
|
||||||
|
from langchain_core.utils import get_bolded_text
|
||||||
|
from langchain_core.utils.interactive_env import is_interactive_env
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from langchain_core.prompts.chat import ChatPromptTemplate
|
from langchain_core.prompts.chat import ChatPromptTemplate
|
||||||
@ -42,6 +44,14 @@ class BaseMessage(Serializable):
|
|||||||
prompt = ChatPromptTemplate(messages=[self])
|
prompt = ChatPromptTemplate(messages=[self])
|
||||||
return prompt + other
|
return prompt + other
|
||||||
|
|
||||||
|
def pretty_repr(self, html: bool = False) -> str:
|
||||||
|
title = get_msg_title_repr(self.type.title() + " Message", bold=html)
|
||||||
|
# TODO: handle non-string content.
|
||||||
|
return f"{title}\n\n{self.content}"
|
||||||
|
|
||||||
|
def pretty_print(self) -> None:
|
||||||
|
print(self.pretty_repr(html=is_interactive_env()))
|
||||||
|
|
||||||
|
|
||||||
def merge_content(
|
def merge_content(
|
||||||
first_content: Union[str, List[Union[str, Dict]]],
|
first_content: Union[str, List[Union[str, Dict]]],
|
||||||
@ -176,3 +186,13 @@ def messages_to_dict(messages: Sequence[BaseMessage]) -> List[dict]:
|
|||||||
List of messages as dicts.
|
List of messages as dicts.
|
||||||
"""
|
"""
|
||||||
return [message_to_dict(m) for m in messages]
|
return [message_to_dict(m) for m in messages]
|
||||||
|
|
||||||
|
|
||||||
|
def get_msg_title_repr(title: str, *, bold: bool = False) -> str:
|
||||||
|
padded = " " + title + " "
|
||||||
|
sep_len = (80 - len(padded)) // 2
|
||||||
|
sep = "=" * sep_len
|
||||||
|
second_sep = sep + "=" if len(padded) % 2 else sep
|
||||||
|
if bold:
|
||||||
|
padded = get_bolded_text(padded)
|
||||||
|
return f"{sep}{padded}{second_sep}"
|
||||||
|
@ -28,11 +28,14 @@ from langchain_core.messages import (
|
|||||||
HumanMessage,
|
HumanMessage,
|
||||||
SystemMessage,
|
SystemMessage,
|
||||||
)
|
)
|
||||||
|
from langchain_core.messages.base import get_msg_title_repr
|
||||||
from langchain_core.prompt_values import ChatPromptValue, PromptValue
|
from langchain_core.prompt_values import ChatPromptValue, PromptValue
|
||||||
from langchain_core.prompts.base import BasePromptTemplate
|
from langchain_core.prompts.base import BasePromptTemplate
|
||||||
from langchain_core.prompts.prompt import PromptTemplate
|
from langchain_core.prompts.prompt import PromptTemplate
|
||||||
from langchain_core.prompts.string import StringPromptTemplate
|
from langchain_core.prompts.string import StringPromptTemplate
|
||||||
from langchain_core.pydantic_v1 import Field, root_validator
|
from langchain_core.pydantic_v1 import Field, root_validator
|
||||||
|
from langchain_core.utils import get_colored_text
|
||||||
|
from langchain_core.utils.interactive_env import is_interactive_env
|
||||||
|
|
||||||
|
|
||||||
class BaseMessagePromptTemplate(Serializable, ABC):
|
class BaseMessagePromptTemplate(Serializable, ABC):
|
||||||
@ -68,6 +71,13 @@ class BaseMessagePromptTemplate(Serializable, ABC):
|
|||||||
List of input variables.
|
List of input variables.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
def pretty_repr(self, html: bool = False) -> str:
|
||||||
|
"""Human-readable representation."""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def pretty_print(self) -> None:
|
||||||
|
print(self.pretty_repr(html=is_interactive_env()))
|
||||||
|
|
||||||
def __add__(self, other: Any) -> ChatPromptTemplate:
|
def __add__(self, other: Any) -> ChatPromptTemplate:
|
||||||
"""Combine two prompt templates.
|
"""Combine two prompt templates.
|
||||||
|
|
||||||
@ -95,9 +105,7 @@ class MessagesPlaceholder(BaseMessagePromptTemplate):
|
|||||||
return ["langchain", "prompts", "chat"]
|
return ["langchain", "prompts", "chat"]
|
||||||
|
|
||||||
def __init__(self, variable_name: str, *, optional: bool = False, **kwargs: Any):
|
def __init__(self, variable_name: str, *, optional: bool = False, **kwargs: Any):
|
||||||
return super().__init__(
|
super().__init__(variable_name=variable_name, optional=optional, **kwargs)
|
||||||
variable_name=variable_name, optional=optional, **kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
|
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
|
||||||
"""Format messages from kwargs.
|
"""Format messages from kwargs.
|
||||||
@ -135,6 +143,15 @@ class MessagesPlaceholder(BaseMessagePromptTemplate):
|
|||||||
"""
|
"""
|
||||||
return [self.variable_name] if not self.optional else []
|
return [self.variable_name] if not self.optional else []
|
||||||
|
|
||||||
|
def pretty_repr(self, html: bool = False) -> str:
|
||||||
|
var = "{" + self.variable_name + "}"
|
||||||
|
if html:
|
||||||
|
title = get_msg_title_repr("Messages Placeholder", bold=True)
|
||||||
|
var = get_colored_text(var, "yellow")
|
||||||
|
else:
|
||||||
|
title = get_msg_title_repr("Messages Placeholder")
|
||||||
|
return f"{title}\n\n{var}"
|
||||||
|
|
||||||
|
|
||||||
MessagePromptTemplateT = TypeVar(
|
MessagePromptTemplateT = TypeVar(
|
||||||
"MessagePromptTemplateT", bound="BaseStringMessagePromptTemplate"
|
"MessagePromptTemplateT", bound="BaseStringMessagePromptTemplate"
|
||||||
@ -237,6 +254,12 @@ class BaseStringMessagePromptTemplate(BaseMessagePromptTemplate, ABC):
|
|||||||
"""
|
"""
|
||||||
return self.prompt.input_variables
|
return self.prompt.input_variables
|
||||||
|
|
||||||
|
def pretty_repr(self, html: bool = False) -> str:
|
||||||
|
# TODO: Handle partials
|
||||||
|
title = self.__class__.__name__.replace("MessagePromptTemplate", " Message")
|
||||||
|
title = get_msg_title_repr(title, bold=html)
|
||||||
|
return f"{title}\n\n{self.prompt.pretty_repr(html=html)}"
|
||||||
|
|
||||||
|
|
||||||
class ChatMessagePromptTemplate(BaseStringMessagePromptTemplate):
|
class ChatMessagePromptTemplate(BaseStringMessagePromptTemplate):
|
||||||
"""Chat message prompt template."""
|
"""Chat message prompt template."""
|
||||||
@ -369,6 +392,13 @@ class BaseChatPromptTemplate(BasePromptTemplate, ABC):
|
|||||||
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
|
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
|
||||||
"""Format kwargs into a list of messages."""
|
"""Format kwargs into a list of messages."""
|
||||||
|
|
||||||
|
def pretty_repr(self, html: bool = False) -> str:
|
||||||
|
"""Human-readable representation."""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def pretty_print(self) -> None:
|
||||||
|
print(self.pretty_repr(html=is_interactive_env()))
|
||||||
|
|
||||||
|
|
||||||
MessageLike = Union[BaseMessagePromptTemplate, BaseMessage, BaseChatPromptTemplate]
|
MessageLike = Union[BaseMessagePromptTemplate, BaseMessage, BaseChatPromptTemplate]
|
||||||
|
|
||||||
@ -701,6 +731,10 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
|
|||||||
"""
|
"""
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def pretty_repr(self, html: bool = False) -> str:
|
||||||
|
# TODO: handle partials
|
||||||
|
return "\n\n".join(msg.pretty_repr(html=html) for msg in self.messages)
|
||||||
|
|
||||||
|
|
||||||
def _create_template_from_message_type(
|
def _create_template_from_message_type(
|
||||||
message_type: str, template: str
|
message_type: str, template: str
|
||||||
|
@ -340,3 +340,6 @@ class FewShotChatMessagePromptTemplate(
|
|||||||
"""
|
"""
|
||||||
messages = self.format_messages(**kwargs)
|
messages = self.format_messages(**kwargs)
|
||||||
return get_buffer_string(messages)
|
return get_buffer_string(messages)
|
||||||
|
|
||||||
|
def pretty_repr(self, html: bool = False) -> str:
|
||||||
|
raise NotImplementedError()
|
||||||
|
@ -8,7 +8,9 @@ from typing import Any, Callable, Dict, List, Set
|
|||||||
|
|
||||||
from langchain_core.prompt_values import PromptValue, StringPromptValue
|
from langchain_core.prompt_values import PromptValue, StringPromptValue
|
||||||
from langchain_core.prompts.base import BasePromptTemplate
|
from langchain_core.prompts.base import BasePromptTemplate
|
||||||
|
from langchain_core.utils import get_colored_text
|
||||||
from langchain_core.utils.formatting import formatter
|
from langchain_core.utils.formatting import formatter
|
||||||
|
from langchain_core.utils.interactive_env import is_interactive_env
|
||||||
|
|
||||||
|
|
||||||
def jinja2_formatter(template: str, **kwargs: Any) -> str:
|
def jinja2_formatter(template: str, **kwargs: Any) -> str:
|
||||||
@ -159,3 +161,17 @@ class StringPromptTemplate(BasePromptTemplate, ABC):
|
|||||||
def format_prompt(self, **kwargs: Any) -> PromptValue:
|
def format_prompt(self, **kwargs: Any) -> PromptValue:
|
||||||
"""Create Chat Messages."""
|
"""Create Chat Messages."""
|
||||||
return StringPromptValue(text=self.format(**kwargs))
|
return StringPromptValue(text=self.format(**kwargs))
|
||||||
|
|
||||||
|
def pretty_repr(self, html: bool = False) -> str:
|
||||||
|
# TODO: handle partials
|
||||||
|
dummy_vars = {
|
||||||
|
input_var: "{" + f"{input_var}" + "}" for input_var in self.input_variables
|
||||||
|
}
|
||||||
|
if html:
|
||||||
|
dummy_vars = {
|
||||||
|
k: get_colored_text(v, "yellow") for k, v in dummy_vars.items()
|
||||||
|
}
|
||||||
|
return self.format(**dummy_vars)
|
||||||
|
|
||||||
|
def pretty_print(self) -> None:
|
||||||
|
print(self.pretty_repr(html=is_interactive_env()))
|
||||||
|
5
libs/core/langchain_core/utils/interactive_env.py
Normal file
5
libs/core/langchain_core/utils/interactive_env.py
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
def is_interactive_env() -> bool:
|
||||||
|
"""Determine if running within IPython or Jupyter."""
|
||||||
|
import sys
|
||||||
|
|
||||||
|
return hasattr(sys, "ps2")
|
@ -101,3 +101,6 @@ class AutoGPTPrompt(BaseChatPromptTemplate, BaseModel): # type: ignore[misc]
|
|||||||
messages += historical_messages
|
messages += historical_messages
|
||||||
messages.append(input_message)
|
messages.append(input_message)
|
||||||
return messages
|
return messages
|
||||||
|
|
||||||
|
def pretty_repr(self, html: bool = False) -> str:
|
||||||
|
raise NotImplementedError
|
||||||
|
Loading…
Reference in New Issue
Block a user