core: docstrings messages (#23788)

Added missed docstrings. Formatted docstrings to the consistent form.
This commit is contained in:
Leonid Ganeline 2024-07-03 08:25:00 -07:00 committed by GitHub
parent 54e730f6e4
commit 30fdc2dbe7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 183 additions and 42 deletions

View File

@ -73,17 +73,27 @@ class AIMessage(BaseMessage):
"""
type: Literal["ai"] = "ai"
"""The type of the message (used for deserialization)."""
"""The type of the message (used for deserialization). Defaults to "ai"."""
def __init__(
self, content: Union[str, List[Union[str, Dict]]], **kwargs: Any
) -> None:
"""Pass in content as positional arg."""
"""Pass in content as positional arg.
Args:
content: The content of the message.
**kwargs: Additional arguments to pass to the parent class.
"""
super().__init__(content=content, **kwargs)
@classmethod
def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
"""Get the namespace of the langchain object.
Returns:
The namespace of the langchain object.
Defaults to ["langchain", "schema", "messages"].
"""
return ["langchain", "schema", "messages"]
@property
@ -117,7 +127,15 @@ class AIMessage(BaseMessage):
return values
def pretty_repr(self, html: bool = False) -> str:
"""Return a pretty representation of the message."""
"""Return a pretty representation of the message.
Args:
html: Whether to return an HTML-formatted string.
Defaults to False.
Returns:
A pretty representation of the message.
"""
base = super().pretty_repr(html=html)
lines = []
@ -157,14 +175,21 @@ class AIMessageChunk(AIMessage, BaseMessageChunk):
# Ignoring mypy re-assignment here since we're overriding the value
# to make sure that the chunk variant can be discriminated from the
# non-chunk variant.
type: Literal["AIMessageChunk"] = "AIMessageChunk" # type: ignore[assignment]
type: Literal["AIMessageChunk"] = "AIMessageChunk" # type: ignore
"""The type of the message (used for deserialization).
Defaults to "AIMessageChunk"."""
tool_call_chunks: List[ToolCallChunk] = []
"""If provided, tool call chunks associated with the message."""
@classmethod
def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
"""Get the namespace of the langchain object.
Returns:
The namespace of the langchain object.
Defaults to ["langchain", "schema", "messages"].
"""
return ["langchain", "schema", "messages"]
@property
@ -177,6 +202,17 @@ class AIMessageChunk(AIMessage, BaseMessageChunk):
@root_validator(pre=False, skip_on_failure=True)
def init_tool_calls(cls, values: dict) -> dict:
"""Initialize tool calls from tool call chunks.
Args:
values: The values to validate.
Returns:
The values with tool calls initialized.
Raises:
ValueError: If the tool call chunks are malformed.
"""
if not values["tool_call_chunks"]:
if values["tool_calls"]:
values["tool_call_chunks"] = [

View File

@ -57,17 +57,29 @@ class BaseMessage(Serializable):
def __init__(
self, content: Union[str, List[Union[str, Dict]]], **kwargs: Any
) -> None:
"""Pass in content as positional arg."""
"""Pass in content as positional arg.
Args:
content: The string contents of the message.
**kwargs: Additional fields to pass to the
"""
super().__init__(content=content, **kwargs)
@classmethod
def is_lc_serializable(cls) -> bool:
"""Return whether this class is serializable."""
"""Return whether this class is serializable. This is used to determine
whether the class should be included in the langchain schema.
Returns:
True if the class is serializable, False otherwise.
"""
return True
@classmethod
def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
"""Get the namespace of the langchain object.
Default is ["langchain", "schema", "messages"].
"""
return ["langchain", "schema", "messages"]
def __add__(self, other: Any) -> ChatPromptTemplate:
@ -78,6 +90,15 @@ class BaseMessage(Serializable):
return prompt + other
def pretty_repr(self, html: bool = False) -> str:
"""Get a pretty representation of the message.
Args:
html: Whether to format the message as HTML. If True, the message will be
formatted with HTML tags. Default is False.
Returns:
A pretty representation of the message.
"""
title = get_msg_title_repr(self.type.title() + " Message", bold=html)
# TODO: handle non-string content.
if self.name is not None:
@ -95,8 +116,8 @@ def merge_content(
"""Merge two message contents.
Args:
first_content: The first content.
second_content: The second content.
first_content: The first content. Can be a string or a list.
second_content: The second content. Can be a string or a list.
Returns:
The merged content.
@ -133,7 +154,9 @@ class BaseMessageChunk(BaseMessage):
@classmethod
def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
"""Get the namespace of the langchain object.
Default is ["langchain", "schema", "messages"].
"""
return ["langchain", "schema", "messages"]
def __add__(self, other: Any) -> BaseMessageChunk: # type: ignore
@ -142,6 +165,16 @@ class BaseMessageChunk(BaseMessage):
This functionality is useful to combine message chunks yielded from
a streaming model into a complete message.
Args:
other: Another message chunk to concatenate with this one.
Returns:
A new message chunk that is the concatenation of this message chunk
and the other message chunk.
Raises:
TypeError: If the other object is not a message chunk.
For example,
`AIMessageChunk(content="Hello") + AIMessageChunk(content=" World")`
@ -177,7 +210,8 @@ def message_to_dict(message: BaseMessage) -> dict:
message: Message to convert.
Returns:
Message as a dict.
Message as a dict. The dict will have a "type" key with the message type
and a "data" key with the message data as a dict.
"""
return {"type": message.type, "data": message.dict()}
@ -199,7 +233,7 @@ def get_msg_title_repr(title: str, *, bold: bool = False) -> str:
Args:
title: The title.
bold: Whether to bold the title.
bold: Whether to bold the title. Default is False.
Returns:
The title representation.

View File

@ -15,11 +15,13 @@ class ChatMessage(BaseMessage):
"""The speaker / role of the Message."""
type: Literal["chat"] = "chat"
"""The type of the message (used during serialization)."""
"""The type of the message (used during serialization). Defaults to "chat"."""
@classmethod
def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
"""Get the namespace of the langchain object.
Default is ["langchain", "schema", "messages"].
"""
return ["langchain", "schema", "messages"]
@ -33,11 +35,14 @@ class ChatMessageChunk(ChatMessage, BaseMessageChunk):
# to make sure that the chunk variant can be discriminated from the
# non-chunk variant.
type: Literal["ChatMessageChunk"] = "ChatMessageChunk" # type: ignore
"""The type of the message (used during serialization)."""
"""The type of the message (used during serialization).
Defaults to "ChatMessageChunk"."""
@classmethod
def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
"""Get the namespace of the langchain object.
Default is ["langchain", "schema", "messages"].
"""
return ["langchain", "schema", "messages"]
def __add__(self, other: Any) -> BaseMessageChunk: # type: ignore

View File

@ -23,11 +23,12 @@ class FunctionMessage(BaseMessage):
"""The name of the function that was executed."""
type: Literal["function"] = "function"
"""The type of the message (used for serialization)."""
"""The type of the message (used for serialization). Defaults to "function"."""
@classmethod
def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
"""Get the namespace of the langchain object.
Default is ["langchain", "schema", "messages"]."""
return ["langchain", "schema", "messages"]
@ -41,10 +42,13 @@ class FunctionMessageChunk(FunctionMessage, BaseMessageChunk):
# to make sure that the chunk variant can be discriminated from the
# non-chunk variant.
type: Literal["FunctionMessageChunk"] = "FunctionMessageChunk" # type: ignore[assignment]
"""The type of the message (used for serialization).
Defaults to "FunctionMessageChunk"."""
@classmethod
def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
"""Get the namespace of the langchain object.
Default is ["langchain", "schema", "messages"]."""
return ["langchain", "schema", "messages"]
def __add__(self, other: Any) -> BaseMessageChunk: # type: ignore

View File

@ -32,19 +32,27 @@ class HumanMessage(BaseMessage):
"""Use to denote that a message is part of an example conversation.
At the moment, this is ignored by most models. Usage is discouraged.
Defaults to False.
"""
type: Literal["human"] = "human"
"""The type of the message (used for serialization). Defaults to "human"."""
@classmethod
def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
"""Get the namespace of the langchain object.
Default is ["langchain", "schema", "messages"]."""
return ["langchain", "schema", "messages"]
def __init__(
self, content: Union[str, List[Union[str, Dict]]], **kwargs: Any
) -> None:
"""Pass in content as positional arg."""
"""Pass in content as positional arg.
Args:
content: The string contents of the message.
**kwargs: Additional fields to pass to the message.
"""
super().__init__(content=content, **kwargs)
@ -58,8 +66,11 @@ class HumanMessageChunk(HumanMessage, BaseMessageChunk):
# to make sure that the chunk variant can be discriminated from the
# non-chunk variant.
type: Literal["HumanMessageChunk"] = "HumanMessageChunk" # type: ignore[assignment]
"""The type of the message (used for serialization).
Defaults to "HumanMessageChunk"."""
@classmethod
def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
"""Get the namespace of the langchain object.
Default is ["langchain", "schema", "messages"]."""
return ["langchain", "schema", "messages"]

View File

@ -9,8 +9,18 @@ class RemoveMessage(BaseMessage):
"""Message responsible for deleting other messages."""
type: Literal["remove"] = "remove"
"""The type of the message (used for serialization). Defaults to "remove"."""
def __init__(self, id: str, **kwargs: Any) -> None:
"""Create a RemoveMessage.
Args:
id: The ID of the message to remove.
**kwargs: Additional fields to pass to the message.
Raises:
ValueError: If the 'content' field is passed in kwargs.
"""
if kwargs.pop("content", None):
raise ValueError("RemoveMessage does not support 'content' field.")
@ -18,7 +28,8 @@ class RemoveMessage(BaseMessage):
@classmethod
def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
"""Get the namespace of the langchain object.
Default is ["langchain", "schema", "messages"]."""
return ["langchain", "schema", "messages"]

View File

@ -30,16 +30,23 @@ class SystemMessage(BaseMessage):
"""
type: Literal["system"] = "system"
"""The type of the message (used for serialization). Defaults to "system"."""
@classmethod
def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
"""Get the namespace of the langchain object.
Default is ["langchain", "schema", "messages"]."""
return ["langchain", "schema", "messages"]
def __init__(
self, content: Union[str, List[Union[str, Dict]]], **kwargs: Any
) -> None:
"""Pass in content as positional arg."""
"""Pass in content as positional arg.
Args:
content: The string contents of the message.
**kwargs: Additional fields to pass to the message.
"""
super().__init__(content=content, **kwargs)
@ -53,8 +60,11 @@ class SystemMessageChunk(SystemMessage, BaseMessageChunk):
# to make sure that the chunk variant can be discriminated from the
# non-chunk variant.
type: Literal["SystemMessageChunk"] = "SystemMessageChunk" # type: ignore[assignment]
"""The type of the message (used for serialization).
Defaults to "SystemMessageChunk"."""
@classmethod
def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
"""Get the namespace of the langchain object.
Default is ["langchain", "schema", "messages"]."""
return ["langchain", "schema", "messages"]

View File

@ -37,16 +37,23 @@ class ToolMessage(BaseMessage):
# """Whether the tool errored."""
type: Literal["tool"] = "tool"
"""The type of the message (used for serialization). Defaults to "tool"."""
@classmethod
def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
"""Get the namespace of the langchain object.
Default is ["langchain", "schema", "messages"]."""
return ["langchain", "schema", "messages"]
def __init__(
self, content: Union[str, List[Union[str, Dict]]], **kwargs: Any
) -> None:
"""Pass in content as positional arg."""
"""Pass in content as positional arg.
Args:
content: The string contents of the message.
**kwargs: Additional fields to pass to the message
"""
super().__init__(content=content, **kwargs)

View File

@ -6,6 +6,7 @@ Some examples of what you can do with these functions include:
* Convert messages from dicts to Message objects (deserialization)
* Filter messages from a list of messages based on name, type or id etc.
"""
from __future__ import annotations
import inspect
@ -53,11 +54,15 @@ def get_buffer_string(
Args:
messages: Messages to be converted to strings.
human_prefix: The prefix to prepend to contents of HumanMessages.
ai_prefix: THe prefix to prepend to contents of AIMessages.
Default is "Human".
ai_prefix: THe prefix to prepend to contents of AIMessages. Default is "AI".
Returns:
A single string concatenation of all input messages.
Raises:
ValueError: If an unsupported message type is encountered.
Example:
.. code-block:: python
@ -173,11 +178,20 @@ def _create_message_from_message_type(
"""Create a message from a message type and content string.
Args:
message_type: str the type of the message (e.g., "human", "ai", etc.)
content: str the content string.
message_type: (str) the type of the message (e.g., "human", "ai", etc.).
content: (str) the content string.
name: (str) the name of the message. Default is None.
tool_call_id: (str) the tool call id. Default is None.
tool_calls: (List[Dict[str, Any]]) the tool calls. Default is None.
id: (str) the id of the message. Default is None.
**additional_kwargs: (Dict[str, Any]) additional keyword arguments.
Returns:
a message of the appropriate type.
Raises:
ValueError: if the message type is not one of "human", "user", "ai",
"assistant", "system", "function", or "tool".
"""
kwargs: Dict[str, Any] = {}
if name is not None:
@ -203,7 +217,7 @@ def _create_message_from_message_type(
else:
raise ValueError(
f"Unexpected message type: {message_type}. Use one of 'human',"
f" 'user', 'ai', 'assistant', or 'system'."
f" 'user', 'ai', 'assistant', 'function', 'tool', or 'system'."
)
return message
@ -220,10 +234,14 @@ def _convert_to_message(message: MessageLikeRepresentation) -> BaseMessage:
- string: shorthand for ("human", template); e.g., "{user_input}"
Args:
message: a representation of a message in one of the supported formats
message: a representation of a message in one of the supported formats.
Returns:
an instance of a message or a message template
an instance of a message or a message template.
Raises:
NotImplementedError: if the message type is not supported.
ValueError: if the message dict does not contain the required keys.
"""
if isinstance(message, BaseMessage):
_message = message
@ -315,16 +333,16 @@ def filter_messages(
Args:
messages: Sequence Message-like objects to filter.
include_names: Message names to include.
exclude_names: Messages names to exclude.
include_names: Message names to include. Default is None.
exclude_names: Messages names to exclude. Default is None.
include_types: Message types to include. Can be specified as string names (e.g.
"system", "human", "ai", ...) or as BaseMessage classes (e.g.
SystemMessage, HumanMessage, AIMessage, ...).
SystemMessage, HumanMessage, AIMessage, ...). Default is None.
exclude_types: Message types to exclude. Can be specified as string names (e.g.
"system", "human", "ai", ...) or as BaseMessage classes (e.g.
SystemMessage, HumanMessage, AIMessage, ...).
include_ids: Message IDs to include.
exclude_ids: Message IDs to exclude.
SystemMessage, HumanMessage, AIMessage, ...). Default is None.
include_ids: Message IDs to include. Default is None.
exclude_ids: Message IDs to exclude. Default is None.
Returns:
A list of Messages that meets at least one of the incl_* conditions and none
@ -509,10 +527,12 @@ def trim_messages(
strategy: Strategy for trimming.
- "first": Keep the first <= n_count tokens of the messages.
- "last": Keep the last <= n_count tokens of the messages.
Default is "last".
allow_partial: Whether to split a message if only part of the message can be
included. If ``strategy="last"`` then the last partial contents of a message
are included. If ``strategy="first"`` then the first partial contents of a
message are included.
Default is False.
end_on: The message type to end on. If specified then every message after the
last occurrence of this type is ignored. If ``strategy=="last"`` then this
is done before we attempt to get the last ``max_tokens``. If
@ -520,6 +540,7 @@ def trim_messages(
``max_tokens``. Can be specified as string names (e.g. "system", "human",
"ai", ...) or as BaseMessage classes (e.g. SystemMessage, HumanMessage,
AIMessage, ...). Can be a single type or a list of types.
Default is None.
start_on: The message type to start on. Should only be specified if
``strategy="last"``. If specified then every message before
the first occurrence of this type is ignored. This is done after we trim
@ -528,8 +549,10 @@ def trim_messages(
specified as string names (e.g. "system", "human", "ai", ...) or as
BaseMessage classes (e.g. SystemMessage, HumanMessage, AIMessage, ...). Can
be a single type or a list of types.
Default is None.
include_system: Whether to keep the SystemMessage if there is one at index 0.
Should only be specified if ``strategy="last"``.
Default is False.
text_splitter: Function or ``langchain_text_splitters.TextSplitter`` for
splitting the string contents of a message. Only used if
``allow_partial=True``. If ``strategy="last"`` then the last split tokens