core[patch]: manually coerce ToolMessage args (#26283)

This commit is contained in:
Bagatur
2024-09-10 15:57:57 -07:00
committed by GitHub
parent fce9322d2e
commit aa9f247803
10 changed files with 905 additions and 214 deletions

View File

@@ -1,3 +1,4 @@
import importlib.util
import logging
from pathlib import Path
from typing import Dict, Iterator, Union
@@ -106,6 +107,13 @@ class BSHTMLLoader(BaseLoader):
self.file_path = file_path
self.open_encoding = open_encoding
if bs_kwargs is None:
if not importlib.util.find_spec("lxml"):
raise ImportError(
"By default BSHTMLLoader uses the 'lxml' package. Please either "
"install it with `pip install -U lxml` or pass in init arg "
"`bs_kwargs={'features': '...'}` to overwrite the default "
"BeautifulSoup kwargs."
)
bs_kwargs = {"features": "lxml"}
self.bs_kwargs = bs_kwargs
self.get_text_separator = get_text_separator

View File

@@ -1,7 +1,8 @@
import json
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
from uuid import UUID
from pydantic import Field
from pydantic import Field, model_validator
from typing_extensions import NotRequired, TypedDict
from langchain_core.messages.base import BaseMessage, BaseMessageChunk, merge_content
@@ -82,15 +83,48 @@ class ToolMessage(BaseMessage):
Default is ["langchain", "schema", "messages"]."""
return ["langchain", "schema", "messages"]
@model_validator(mode="before")
@classmethod
def coerce_args(cls, values: dict) -> dict:
content = values["content"]
if isinstance(content, tuple):
content = list(content)
if not isinstance(content, (str, list)):
try:
values["content"] = str(content)
except ValueError as e:
raise ValueError(
"ToolMessage content should be a string or a list of string/dicts. "
f"Received:\n\n{content=}\n\n which could not be coerced into a "
"string."
) from e
elif isinstance(content, list):
values["content"] = []
for i, x in enumerate(content):
if not isinstance(x, (str, dict)):
try:
values["content"].append(str(x))
except ValueError as e:
raise ValueError(
"ToolMessage content should be a string or a list of "
"string/dicts. Received a list but "
f"element ToolMessage.content[{i}] is not a dict and could "
f"not be coerced to a string.:\n\n{x}"
) from e
else:
values["content"].append(x)
else:
pass
tool_call_id = values["tool_call_id"]
if isinstance(tool_call_id, UUID):
values["tool_call_id"] = str(tool_call_id)
return values
def __init__(
self, content: Union[str, List[Union[str, Dict]]], **kwargs: Any
) -> None:
"""Pass in content as positional arg.
Args:
content: The string contents of the message.
kwargs: Additional fields to pass to the message
"""
super().__init__(content=content, **kwargs)

View File

@@ -1,8 +1,11 @@
import unittest
import uuid
from typing import List, Type, Union
import pytest
from pydantic import ValidationError
from langchain_core.documents import Document
from langchain_core.load import dumpd, load
from langchain_core.messages import (
AIMessage,
@@ -980,3 +983,27 @@ def test_merge_content(
) -> None:
actual = merge_content(first, *others)
assert actual == expected
def test_tool_message_content() -> None:
ToolMessage("foo", tool_call_id="1")
ToolMessage(["foo"], tool_call_id="1")
ToolMessage([{"foo": "bar"}], tool_call_id="1")
assert ToolMessage(("a", "b", "c"), tool_call_id="1").content == ["a", "b", "c"] # type: ignore[arg-type]
assert ToolMessage(5, tool_call_id="1").content == "5" # type: ignore[arg-type]
assert ToolMessage(5.1, tool_call_id="1").content == "5.1" # type: ignore[arg-type]
assert ToolMessage({"foo": "bar"}, tool_call_id="1").content == "{'foo': 'bar'}" # type: ignore[arg-type]
assert (
ToolMessage(Document("foo"), tool_call_id="1").content == "page_content='foo'" # type: ignore[arg-type]
)
def test_tool_message_tool_call_id() -> None:
ToolMessage("foo", tool_call_id="1")
# Currently we only handle UUID->str coercion manually.
ToolMessage("foo", tool_call_id=uuid.uuid4())
with pytest.raises(ValidationError):
ToolMessage("foo", tool_call_id=1)