mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-26 22:05:29 +00:00
core[patch]: manually coerce ToolMessage args (#26283)
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
import importlib.util
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Dict, Iterator, Union
|
||||
@@ -106,6 +107,13 @@ class BSHTMLLoader(BaseLoader):
|
||||
self.file_path = file_path
|
||||
self.open_encoding = open_encoding
|
||||
if bs_kwargs is None:
|
||||
if not importlib.util.find_spec("lxml"):
|
||||
raise ImportError(
|
||||
"By default BSHTMLLoader uses the 'lxml' package. Please either "
|
||||
"install it with `pip install -U lxml` or pass in init arg "
|
||||
"`bs_kwargs={'features': '...'}` to overwrite the default "
|
||||
"BeautifulSoup kwargs."
|
||||
)
|
||||
bs_kwargs = {"features": "lxml"}
|
||||
self.bs_kwargs = bs_kwargs
|
||||
self.get_text_separator = get_text_separator
|
||||
|
@@ -1,7 +1,8 @@
|
||||
import json
|
||||
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import Field
|
||||
from pydantic import Field, model_validator
|
||||
from typing_extensions import NotRequired, TypedDict
|
||||
|
||||
from langchain_core.messages.base import BaseMessage, BaseMessageChunk, merge_content
|
||||
@@ -82,15 +83,48 @@ class ToolMessage(BaseMessage):
|
||||
Default is ["langchain", "schema", "messages"]."""
|
||||
return ["langchain", "schema", "messages"]
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def coerce_args(cls, values: dict) -> dict:
|
||||
content = values["content"]
|
||||
if isinstance(content, tuple):
|
||||
content = list(content)
|
||||
|
||||
if not isinstance(content, (str, list)):
|
||||
try:
|
||||
values["content"] = str(content)
|
||||
except ValueError as e:
|
||||
raise ValueError(
|
||||
"ToolMessage content should be a string or a list of string/dicts. "
|
||||
f"Received:\n\n{content=}\n\n which could not be coerced into a "
|
||||
"string."
|
||||
) from e
|
||||
elif isinstance(content, list):
|
||||
values["content"] = []
|
||||
for i, x in enumerate(content):
|
||||
if not isinstance(x, (str, dict)):
|
||||
try:
|
||||
values["content"].append(str(x))
|
||||
except ValueError as e:
|
||||
raise ValueError(
|
||||
"ToolMessage content should be a string or a list of "
|
||||
"string/dicts. Received a list but "
|
||||
f"element ToolMessage.content[{i}] is not a dict and could "
|
||||
f"not be coerced to a string.:\n\n{x}"
|
||||
) from e
|
||||
else:
|
||||
values["content"].append(x)
|
||||
else:
|
||||
pass
|
||||
|
||||
tool_call_id = values["tool_call_id"]
|
||||
if isinstance(tool_call_id, UUID):
|
||||
values["tool_call_id"] = str(tool_call_id)
|
||||
return values
|
||||
|
||||
def __init__(
|
||||
self, content: Union[str, List[Union[str, Dict]]], **kwargs: Any
|
||||
) -> None:
|
||||
"""Pass in content as positional arg.
|
||||
|
||||
Args:
|
||||
content: The string contents of the message.
|
||||
kwargs: Additional fields to pass to the message
|
||||
"""
|
||||
super().__init__(content=content, **kwargs)
|
||||
|
||||
|
||||
|
@@ -1,8 +1,11 @@
|
||||
import unittest
|
||||
import uuid
|
||||
from typing import List, Type, Union
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.load import dumpd, load
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
@@ -980,3 +983,27 @@ def test_merge_content(
|
||||
) -> None:
|
||||
actual = merge_content(first, *others)
|
||||
assert actual == expected
|
||||
|
||||
|
||||
def test_tool_message_content() -> None:
|
||||
ToolMessage("foo", tool_call_id="1")
|
||||
ToolMessage(["foo"], tool_call_id="1")
|
||||
ToolMessage([{"foo": "bar"}], tool_call_id="1")
|
||||
|
||||
assert ToolMessage(("a", "b", "c"), tool_call_id="1").content == ["a", "b", "c"] # type: ignore[arg-type]
|
||||
assert ToolMessage(5, tool_call_id="1").content == "5" # type: ignore[arg-type]
|
||||
assert ToolMessage(5.1, tool_call_id="1").content == "5.1" # type: ignore[arg-type]
|
||||
assert ToolMessage({"foo": "bar"}, tool_call_id="1").content == "{'foo': 'bar'}" # type: ignore[arg-type]
|
||||
assert (
|
||||
ToolMessage(Document("foo"), tool_call_id="1").content == "page_content='foo'" # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
|
||||
def test_tool_message_tool_call_id() -> None:
|
||||
ToolMessage("foo", tool_call_id="1")
|
||||
|
||||
# Currently we only handle UUID->str coercion manually.
|
||||
ToolMessage("foo", tool_call_id=uuid.uuid4())
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
ToolMessage("foo", tool_call_id=1)
|
||||
|
Reference in New Issue
Block a user