From 30fdc2dbe70ac6281162c1a024471b38f88bf4ba Mon Sep 17 00:00:00 2001 From: Leonid Ganeline Date: Wed, 3 Jul 2024 08:25:00 -0700 Subject: [PATCH] core: docstrings `messages` (#23788) Added missed docstrings. Formatted docstrings to the consistent form. --- libs/core/langchain_core/messages/ai.py | 48 +++++++++++++++--- libs/core/langchain_core/messages/base.py | 50 ++++++++++++++++--- libs/core/langchain_core/messages/chat.py | 13 +++-- libs/core/langchain_core/messages/function.py | 10 ++-- libs/core/langchain_core/messages/human.py | 17 +++++-- libs/core/langchain_core/messages/modifier.py | 13 ++++- libs/core/langchain_core/messages/system.py | 16 ++++-- libs/core/langchain_core/messages/tool.py | 11 +++- libs/core/langchain_core/messages/utils.py | 47 ++++++++++++----- 9 files changed, 183 insertions(+), 42 deletions(-) diff --git a/libs/core/langchain_core/messages/ai.py b/libs/core/langchain_core/messages/ai.py index 1311549e9c7..693d4545db8 100644 --- a/libs/core/langchain_core/messages/ai.py +++ b/libs/core/langchain_core/messages/ai.py @@ -73,17 +73,27 @@ class AIMessage(BaseMessage): """ type: Literal["ai"] = "ai" - """The type of the message (used for deserialization).""" + """The type of the message (used for deserialization). Defaults to "ai".""" def __init__( self, content: Union[str, List[Union[str, Dict]]], **kwargs: Any ) -> None: - """Pass in content as positional arg.""" + """Pass in content as positional arg. + + Args: + content: The content of the message. + **kwargs: Additional arguments to pass to the parent class. + """ super().__init__(content=content, **kwargs) @classmethod def get_lc_namespace(cls) -> List[str]: - """Get the namespace of the langchain object.""" + """Get the namespace of the langchain object. + + Returns: + The namespace of the langchain object. + Defaults to ["langchain", "schema", "messages"]. + """ return ["langchain", "schema", "messages"] @property @@ -117,7 +127,15 @@ class AIMessage(BaseMessage): return values def pretty_repr(self, html: bool = False) -> str: - """Return a pretty representation of the message.""" + """Return a pretty representation of the message. + + Args: + html: Whether to return an HTML-formatted string. + Defaults to False. + + Returns: + A pretty representation of the message. + """ base = super().pretty_repr(html=html) lines = [] @@ -157,14 +175,21 @@ class AIMessageChunk(AIMessage, BaseMessageChunk): # Ignoring mypy re-assignment here since we're overriding the value # to make sure that the chunk variant can be discriminated from the # non-chunk variant. - type: Literal["AIMessageChunk"] = "AIMessageChunk" # type: ignore[assignment] + type: Literal["AIMessageChunk"] = "AIMessageChunk" # type: ignore + """The type of the message (used for deserialization). + Defaults to "AIMessageChunk".""" tool_call_chunks: List[ToolCallChunk] = [] """If provided, tool call chunks associated with the message.""" @classmethod def get_lc_namespace(cls) -> List[str]: - """Get the namespace of the langchain object.""" + """Get the namespace of the langchain object. + + Returns: + The namespace of the langchain object. + Defaults to ["langchain", "schema", "messages"]. + """ return ["langchain", "schema", "messages"] @property @@ -177,6 +202,17 @@ class AIMessageChunk(AIMessage, BaseMessageChunk): @root_validator(pre=False, skip_on_failure=True) def init_tool_calls(cls, values: dict) -> dict: + """Initialize tool calls from tool call chunks. + + Args: + values: The values to validate. + + Returns: + The values with tool calls initialized. + + Raises: + ValueError: If the tool call chunks are malformed. + """ if not values["tool_call_chunks"]: if values["tool_calls"]: values["tool_call_chunks"] = [ diff --git a/libs/core/langchain_core/messages/base.py b/libs/core/langchain_core/messages/base.py index 117a0358c3b..d187d4a20bf 100644 --- a/libs/core/langchain_core/messages/base.py +++ b/libs/core/langchain_core/messages/base.py @@ -57,17 +57,29 @@ class BaseMessage(Serializable): def __init__( self, content: Union[str, List[Union[str, Dict]]], **kwargs: Any ) -> None: - """Pass in content as positional arg.""" + """Pass in content as positional arg. + + Args: + content: The string contents of the message. + **kwargs: Additional fields to pass to the + """ super().__init__(content=content, **kwargs) @classmethod def is_lc_serializable(cls) -> bool: - """Return whether this class is serializable.""" + """Return whether this class is serializable. This is used to determine + whether the class should be included in the langchain schema. + + Returns: + True if the class is serializable, False otherwise. + """ return True @classmethod def get_lc_namespace(cls) -> List[str]: - """Get the namespace of the langchain object.""" + """Get the namespace of the langchain object. + Default is ["langchain", "schema", "messages"]. + """ return ["langchain", "schema", "messages"] def __add__(self, other: Any) -> ChatPromptTemplate: @@ -78,6 +90,15 @@ class BaseMessage(Serializable): return prompt + other def pretty_repr(self, html: bool = False) -> str: + """Get a pretty representation of the message. + + Args: + html: Whether to format the message as HTML. If True, the message will be + formatted with HTML tags. Default is False. + + Returns: + A pretty representation of the message. + """ title = get_msg_title_repr(self.type.title() + " Message", bold=html) # TODO: handle non-string content. if self.name is not None: @@ -95,8 +116,8 @@ def merge_content( """Merge two message contents. Args: - first_content: The first content. - second_content: The second content. + first_content: The first content. Can be a string or a list. + second_content: The second content. Can be a string or a list. Returns: The merged content. @@ -133,7 +154,9 @@ class BaseMessageChunk(BaseMessage): @classmethod def get_lc_namespace(cls) -> List[str]: - """Get the namespace of the langchain object.""" + """Get the namespace of the langchain object. + Default is ["langchain", "schema", "messages"]. + """ return ["langchain", "schema", "messages"] def __add__(self, other: Any) -> BaseMessageChunk: # type: ignore @@ -142,6 +165,16 @@ class BaseMessageChunk(BaseMessage): This functionality is useful to combine message chunks yielded from a streaming model into a complete message. + Args: + other: Another message chunk to concatenate with this one. + + Returns: + A new message chunk that is the concatenation of this message chunk + and the other message chunk. + + Raises: + TypeError: If the other object is not a message chunk. + For example, `AIMessageChunk(content="Hello") + AIMessageChunk(content=" World")` @@ -177,7 +210,8 @@ def message_to_dict(message: BaseMessage) -> dict: message: Message to convert. Returns: - Message as a dict. + Message as a dict. The dict will have a "type" key with the message type + and a "data" key with the message data as a dict. """ return {"type": message.type, "data": message.dict()} @@ -199,7 +233,7 @@ def get_msg_title_repr(title: str, *, bold: bool = False) -> str: Args: title: The title. - bold: Whether to bold the title. + bold: Whether to bold the title. Default is False. Returns: The title representation. diff --git a/libs/core/langchain_core/messages/chat.py b/libs/core/langchain_core/messages/chat.py index 59e36f004fe..dbbe05a2b89 100644 --- a/libs/core/langchain_core/messages/chat.py +++ b/libs/core/langchain_core/messages/chat.py @@ -15,11 +15,13 @@ class ChatMessage(BaseMessage): """The speaker / role of the Message.""" type: Literal["chat"] = "chat" - """The type of the message (used during serialization).""" + """The type of the message (used during serialization). Defaults to "chat".""" @classmethod def get_lc_namespace(cls) -> List[str]: - """Get the namespace of the langchain object.""" + """Get the namespace of the langchain object. + Default is ["langchain", "schema", "messages"]. + """ return ["langchain", "schema", "messages"] @@ -33,11 +35,14 @@ class ChatMessageChunk(ChatMessage, BaseMessageChunk): # to make sure that the chunk variant can be discriminated from the # non-chunk variant. type: Literal["ChatMessageChunk"] = "ChatMessageChunk" # type: ignore - """The type of the message (used during serialization).""" + """The type of the message (used during serialization). + Defaults to "ChatMessageChunk".""" @classmethod def get_lc_namespace(cls) -> List[str]: - """Get the namespace of the langchain object.""" + """Get the namespace of the langchain object. + Default is ["langchain", "schema", "messages"]. + """ return ["langchain", "schema", "messages"] def __add__(self, other: Any) -> BaseMessageChunk: # type: ignore diff --git a/libs/core/langchain_core/messages/function.py b/libs/core/langchain_core/messages/function.py index 21a910470a9..5625a3214da 100644 --- a/libs/core/langchain_core/messages/function.py +++ b/libs/core/langchain_core/messages/function.py @@ -23,11 +23,12 @@ class FunctionMessage(BaseMessage): """The name of the function that was executed.""" type: Literal["function"] = "function" - """The type of the message (used for serialization).""" + """The type of the message (used for serialization). Defaults to "function".""" @classmethod def get_lc_namespace(cls) -> List[str]: - """Get the namespace of the langchain object.""" + """Get the namespace of the langchain object. + Default is ["langchain", "schema", "messages"].""" return ["langchain", "schema", "messages"] @@ -41,10 +42,13 @@ class FunctionMessageChunk(FunctionMessage, BaseMessageChunk): # to make sure that the chunk variant can be discriminated from the # non-chunk variant. type: Literal["FunctionMessageChunk"] = "FunctionMessageChunk" # type: ignore[assignment] + """The type of the message (used for serialization). + Defaults to "FunctionMessageChunk".""" @classmethod def get_lc_namespace(cls) -> List[str]: - """Get the namespace of the langchain object.""" + """Get the namespace of the langchain object. + Default is ["langchain", "schema", "messages"].""" return ["langchain", "schema", "messages"] def __add__(self, other: Any) -> BaseMessageChunk: # type: ignore diff --git a/libs/core/langchain_core/messages/human.py b/libs/core/langchain_core/messages/human.py index 1c930e5dc99..18b37e4c041 100644 --- a/libs/core/langchain_core/messages/human.py +++ b/libs/core/langchain_core/messages/human.py @@ -32,19 +32,27 @@ class HumanMessage(BaseMessage): """Use to denote that a message is part of an example conversation. At the moment, this is ignored by most models. Usage is discouraged. + Defaults to False. """ type: Literal["human"] = "human" + """The type of the message (used for serialization). Defaults to "human".""" @classmethod def get_lc_namespace(cls) -> List[str]: - """Get the namespace of the langchain object.""" + """Get the namespace of the langchain object. + Default is ["langchain", "schema", "messages"].""" return ["langchain", "schema", "messages"] def __init__( self, content: Union[str, List[Union[str, Dict]]], **kwargs: Any ) -> None: - """Pass in content as positional arg.""" + """Pass in content as positional arg. + + Args: + content: The string contents of the message. + **kwargs: Additional fields to pass to the message. + """ super().__init__(content=content, **kwargs) @@ -58,8 +66,11 @@ class HumanMessageChunk(HumanMessage, BaseMessageChunk): # to make sure that the chunk variant can be discriminated from the # non-chunk variant. type: Literal["HumanMessageChunk"] = "HumanMessageChunk" # type: ignore[assignment] + """The type of the message (used for serialization). + Defaults to "HumanMessageChunk".""" @classmethod def get_lc_namespace(cls) -> List[str]: - """Get the namespace of the langchain object.""" + """Get the namespace of the langchain object. + Default is ["langchain", "schema", "messages"].""" return ["langchain", "schema", "messages"] diff --git a/libs/core/langchain_core/messages/modifier.py b/libs/core/langchain_core/messages/modifier.py index 3ea753892d8..7eff055aa02 100644 --- a/libs/core/langchain_core/messages/modifier.py +++ b/libs/core/langchain_core/messages/modifier.py @@ -9,8 +9,18 @@ class RemoveMessage(BaseMessage): """Message responsible for deleting other messages.""" type: Literal["remove"] = "remove" + """The type of the message (used for serialization). Defaults to "remove".""" def __init__(self, id: str, **kwargs: Any) -> None: + """Create a RemoveMessage. + + Args: + id: The ID of the message to remove. + **kwargs: Additional fields to pass to the message. + + Raises: + ValueError: If the 'content' field is passed in kwargs. + """ if kwargs.pop("content", None): raise ValueError("RemoveMessage does not support 'content' field.") @@ -18,7 +28,8 @@ class RemoveMessage(BaseMessage): @classmethod def get_lc_namespace(cls) -> List[str]: - """Get the namespace of the langchain object.""" + """Get the namespace of the langchain object. + Default is ["langchain", "schema", "messages"].""" return ["langchain", "schema", "messages"] diff --git a/libs/core/langchain_core/messages/system.py b/libs/core/langchain_core/messages/system.py index a2ab0adf982..3e3255a9994 100644 --- a/libs/core/langchain_core/messages/system.py +++ b/libs/core/langchain_core/messages/system.py @@ -30,16 +30,23 @@ class SystemMessage(BaseMessage): """ type: Literal["system"] = "system" + """The type of the message (used for serialization). Defaults to "system".""" @classmethod def get_lc_namespace(cls) -> List[str]: - """Get the namespace of the langchain object.""" + """Get the namespace of the langchain object. + Default is ["langchain", "schema", "messages"].""" return ["langchain", "schema", "messages"] def __init__( self, content: Union[str, List[Union[str, Dict]]], **kwargs: Any ) -> None: - """Pass in content as positional arg.""" + """Pass in content as positional arg. + + Args: + content: The string contents of the message. + **kwargs: Additional fields to pass to the message. + """ super().__init__(content=content, **kwargs) @@ -53,8 +60,11 @@ class SystemMessageChunk(SystemMessage, BaseMessageChunk): # to make sure that the chunk variant can be discriminated from the # non-chunk variant. type: Literal["SystemMessageChunk"] = "SystemMessageChunk" # type: ignore[assignment] + """The type of the message (used for serialization). + Defaults to "SystemMessageChunk".""" @classmethod def get_lc_namespace(cls) -> List[str]: - """Get the namespace of the langchain object.""" + """Get the namespace of the langchain object. + Default is ["langchain", "schema", "messages"].""" return ["langchain", "schema", "messages"] diff --git a/libs/core/langchain_core/messages/tool.py b/libs/core/langchain_core/messages/tool.py index 9d91e97aab0..9855f743ef8 100644 --- a/libs/core/langchain_core/messages/tool.py +++ b/libs/core/langchain_core/messages/tool.py @@ -37,16 +37,23 @@ class ToolMessage(BaseMessage): # """Whether the tool errored.""" type: Literal["tool"] = "tool" + """The type of the message (used for serialization). Defaults to "tool".""" @classmethod def get_lc_namespace(cls) -> List[str]: - """Get the namespace of the langchain object.""" + """Get the namespace of the langchain object. + Default is ["langchain", "schema", "messages"].""" return ["langchain", "schema", "messages"] def __init__( self, content: Union[str, List[Union[str, Dict]]], **kwargs: Any ) -> None: - """Pass in content as positional arg.""" + """Pass in content as positional arg. + + Args: + content: The string contents of the message. + **kwargs: Additional fields to pass to the message + """ super().__init__(content=content, **kwargs) diff --git a/libs/core/langchain_core/messages/utils.py b/libs/core/langchain_core/messages/utils.py index edda11d1554..ce04ec8a86d 100644 --- a/libs/core/langchain_core/messages/utils.py +++ b/libs/core/langchain_core/messages/utils.py @@ -6,6 +6,7 @@ Some examples of what you can do with these functions include: * Convert messages from dicts to Message objects (deserialization) * Filter messages from a list of messages based on name, type or id etc. """ + from __future__ import annotations import inspect @@ -53,11 +54,15 @@ def get_buffer_string( Args: messages: Messages to be converted to strings. human_prefix: The prefix to prepend to contents of HumanMessages. - ai_prefix: THe prefix to prepend to contents of AIMessages. + Default is "Human". + ai_prefix: THe prefix to prepend to contents of AIMessages. Default is "AI". Returns: A single string concatenation of all input messages. + Raises: + ValueError: If an unsupported message type is encountered. + Example: .. code-block:: python @@ -173,11 +178,20 @@ def _create_message_from_message_type( """Create a message from a message type and content string. Args: - message_type: str the type of the message (e.g., "human", "ai", etc.) - content: str the content string. + message_type: (str) the type of the message (e.g., "human", "ai", etc.). + content: (str) the content string. + name: (str) the name of the message. Default is None. + tool_call_id: (str) the tool call id. Default is None. + tool_calls: (List[Dict[str, Any]]) the tool calls. Default is None. + id: (str) the id of the message. Default is None. + **additional_kwargs: (Dict[str, Any]) additional keyword arguments. Returns: a message of the appropriate type. + + Raises: + ValueError: if the message type is not one of "human", "user", "ai", + "assistant", "system", "function", or "tool". """ kwargs: Dict[str, Any] = {} if name is not None: @@ -203,7 +217,7 @@ def _create_message_from_message_type( else: raise ValueError( f"Unexpected message type: {message_type}. Use one of 'human'," - f" 'user', 'ai', 'assistant', or 'system'." + f" 'user', 'ai', 'assistant', 'function', 'tool', or 'system'." ) return message @@ -220,10 +234,14 @@ def _convert_to_message(message: MessageLikeRepresentation) -> BaseMessage: - string: shorthand for ("human", template); e.g., "{user_input}" Args: - message: a representation of a message in one of the supported formats + message: a representation of a message in one of the supported formats. Returns: - an instance of a message or a message template + an instance of a message or a message template. + + Raises: + NotImplementedError: if the message type is not supported. + ValueError: if the message dict does not contain the required keys. """ if isinstance(message, BaseMessage): _message = message @@ -315,16 +333,16 @@ def filter_messages( Args: messages: Sequence Message-like objects to filter. - include_names: Message names to include. - exclude_names: Messages names to exclude. + include_names: Message names to include. Default is None. + exclude_names: Messages names to exclude. Default is None. include_types: Message types to include. Can be specified as string names (e.g. "system", "human", "ai", ...) or as BaseMessage classes (e.g. - SystemMessage, HumanMessage, AIMessage, ...). + SystemMessage, HumanMessage, AIMessage, ...). Default is None. exclude_types: Message types to exclude. Can be specified as string names (e.g. "system", "human", "ai", ...) or as BaseMessage classes (e.g. - SystemMessage, HumanMessage, AIMessage, ...). - include_ids: Message IDs to include. - exclude_ids: Message IDs to exclude. + SystemMessage, HumanMessage, AIMessage, ...). Default is None. + include_ids: Message IDs to include. Default is None. + exclude_ids: Message IDs to exclude. Default is None. Returns: A list of Messages that meets at least one of the incl_* conditions and none @@ -509,10 +527,12 @@ def trim_messages( strategy: Strategy for trimming. - "first": Keep the first <= n_count tokens of the messages. - "last": Keep the last <= n_count tokens of the messages. + Default is "last". allow_partial: Whether to split a message if only part of the message can be included. If ``strategy="last"`` then the last partial contents of a message are included. If ``strategy="first"`` then the first partial contents of a message are included. + Default is False. end_on: The message type to end on. If specified then every message after the last occurrence of this type is ignored. If ``strategy=="last"`` then this is done before we attempt to get the last ``max_tokens``. If @@ -520,6 +540,7 @@ def trim_messages( ``max_tokens``. Can be specified as string names (e.g. "system", "human", "ai", ...) or as BaseMessage classes (e.g. SystemMessage, HumanMessage, AIMessage, ...). Can be a single type or a list of types. + Default is None. start_on: The message type to start on. Should only be specified if ``strategy="last"``. If specified then every message before the first occurrence of this type is ignored. This is done after we trim @@ -528,8 +549,10 @@ def trim_messages( specified as string names (e.g. "system", "human", "ai", ...) or as BaseMessage classes (e.g. SystemMessage, HumanMessage, AIMessage, ...). Can be a single type or a list of types. + Default is None. include_system: Whether to keep the SystemMessage if there is one at index 0. Should only be specified if ``strategy="last"``. + Default is False. text_splitter: Function or ``langchain_text_splitters.TextSplitter`` for splitting the string contents of a message. Only used if ``allow_partial=True``. If ``strategy="last"`` then the last split tokens