core[patch]: retain string type coercion for various IDs (#26429)

- `BaseMedia.id`
- `BaseMessage.id`
- `ToolMessage.tool_call_id`

Note: with this change we are actually less restrictive on types for
initializing `id` than in pydantic V1-- e.g., pydantic V1 will error if
you pass a dict or tuple to `id`). Let me know if we should restrict
type coercion. For tool_call_id I just enumerated supported types.
This commit is contained in:
ccurme
2024-09-13 11:26:18 -04:00
committed by GitHub
parent 608c4a4327
commit 05bd2b4618
5 changed files with 42 additions and 9 deletions

View File

@@ -6,7 +6,7 @@ from io import BufferedReader, BytesIO
from pathlib import PurePath
from typing import Any, Dict, Generator, List, Literal, Optional, Union, cast
from pydantic import ConfigDict, Field, model_validator
from pydantic import ConfigDict, Field, field_validator, model_validator
from langchain_core.load.serializable import Serializable
@@ -40,6 +40,13 @@ class BaseMedia(Serializable):
metadata: dict = Field(default_factory=dict)
"""Arbitrary metadata associated with the content."""
@field_validator("id", mode="before")
def cast_id_to_str(cls, id_value: Any) -> Optional[str]:
if id_value is not None:
return str(id_value)
else:
return id_value
class Blob(BaseMedia):
"""Blob represents raw data by either reference or value.

View File

@@ -2,7 +2,7 @@ from __future__ import annotations
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union, cast
from pydantic import ConfigDict, Field
from pydantic import ConfigDict, Field, field_validator
from langchain_core.load.serializable import Serializable
from langchain_core.utils import get_bolded_text
@@ -56,6 +56,13 @@ class BaseMessage(Serializable):
extra="allow",
)
@field_validator("id", mode="before")
def cast_id_to_str(cls, id_value: Any) -> Optional[str]:
if id_value is not None:
return str(id_value)
else:
return id_value
def __init__(
self, content: Union[str, List[Union[str, Dict]]], **kwargs: Any
) -> None:

View File

@@ -118,7 +118,7 @@ class ToolMessage(BaseMessage):
pass
tool_call_id = values["tool_call_id"]
if isinstance(tool_call_id, UUID):
if isinstance(tool_call_id, (UUID, int, float)):
values["tool_call_id"] = str(tool_call_id)
return values

View File

@@ -0,0 +1,12 @@
from langchain_core.documents import Document
def test_init() -> None:
for doc in [
Document(page_content="foo"),
Document(page_content="foo", metadata={"a": 1}),
Document(page_content="foo", id=None),
Document(page_content="foo", id="1"),
Document(page_content="foo", id=1),
]:
assert isinstance(doc, Document)

View File

@@ -3,13 +3,13 @@ import uuid
from typing import List, Type, Union
import pytest
from pydantic import ValidationError
from langchain_core.documents import Document
from langchain_core.load import dumpd, load
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
BaseMessage,
ChatMessage,
ChatMessageChunk,
FunctionMessage,
@@ -33,6 +33,16 @@ from langchain_core.messages.tool import tool_call_chunk as create_tool_call_chu
from langchain_core.utils._merge import merge_lists
def test_message_init() -> None:
for doc in [
BaseMessage(type="foo", content="bar"),
BaseMessage(type="foo", content="bar", id=None),
BaseMessage(type="foo", content="bar", id="1"),
BaseMessage(type="foo", content="bar", id=1),
]:
assert isinstance(doc, BaseMessage)
def test_message_chunks() -> None:
assert AIMessageChunk(content="I am", id="ai3") + AIMessageChunk(
content=" indeed."
@@ -1001,9 +1011,6 @@ def test_tool_message_content() -> None:
def test_tool_message_tool_call_id() -> None:
ToolMessage("foo", tool_call_id="1")
# Currently we only handle UUID->str coercion manually.
ToolMessage("foo", tool_call_id=uuid.uuid4())
with pytest.raises(ValidationError):
ToolMessage("foo", tool_call_id=1)
ToolMessage("foo", tool_call_id=1)
ToolMessage("foo", tool_call_id=1.0)