From 05bd2b461818316bbfed55f822599ab2cd5ca653 Mon Sep 17 00:00:00 2001 From: ccurme Date: Fri, 13 Sep 2024 11:26:18 -0400 Subject: [PATCH] core[patch]: retain string type coercion for various IDs (#26429) - `BaseMedia.id` - `BaseMessage.id` - `ToolMessage.tool_call_id` Note: with this change we are actually less restrictive on types for initializing `id` than in pydantic V1-- e.g., pydantic V1 will error if you pass a dict or tuple to `id`). Let me know if we should restrict type coercion. For tool_call_id I just enumerated supported types. --- libs/core/langchain_core/documents/base.py | 9 ++++++++- libs/core/langchain_core/messages/base.py | 9 ++++++++- libs/core/langchain_core/messages/tool.py | 2 +- .../unit_tests/documents/test_document.py | 12 ++++++++++++ libs/core/tests/unit_tests/test_messages.py | 19 +++++++++++++------ 5 files changed, 42 insertions(+), 9 deletions(-) create mode 100644 libs/core/tests/unit_tests/documents/test_document.py diff --git a/libs/core/langchain_core/documents/base.py b/libs/core/langchain_core/documents/base.py index 6609d9471d4..1586322b1b1 100644 --- a/libs/core/langchain_core/documents/base.py +++ b/libs/core/langchain_core/documents/base.py @@ -6,7 +6,7 @@ from io import BufferedReader, BytesIO from pathlib import PurePath from typing import Any, Dict, Generator, List, Literal, Optional, Union, cast -from pydantic import ConfigDict, Field, model_validator +from pydantic import ConfigDict, Field, field_validator, model_validator from langchain_core.load.serializable import Serializable @@ -40,6 +40,13 @@ class BaseMedia(Serializable): metadata: dict = Field(default_factory=dict) """Arbitrary metadata associated with the content.""" + @field_validator("id", mode="before") + def cast_id_to_str(cls, id_value: Any) -> Optional[str]: + if id_value is not None: + return str(id_value) + else: + return id_value + class Blob(BaseMedia): """Blob represents raw data by either reference or value. diff --git a/libs/core/langchain_core/messages/base.py b/libs/core/langchain_core/messages/base.py index 7d843188812..97d95c010da 100644 --- a/libs/core/langchain_core/messages/base.py +++ b/libs/core/langchain_core/messages/base.py @@ -2,7 +2,7 @@ from __future__ import annotations from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union, cast -from pydantic import ConfigDict, Field +from pydantic import ConfigDict, Field, field_validator from langchain_core.load.serializable import Serializable from langchain_core.utils import get_bolded_text @@ -56,6 +56,13 @@ class BaseMessage(Serializable): extra="allow", ) + @field_validator("id", mode="before") + def cast_id_to_str(cls, id_value: Any) -> Optional[str]: + if id_value is not None: + return str(id_value) + else: + return id_value + def __init__( self, content: Union[str, List[Union[str, Dict]]], **kwargs: Any ) -> None: diff --git a/libs/core/langchain_core/messages/tool.py b/libs/core/langchain_core/messages/tool.py index 8dcfccff16f..be1bc675b1a 100644 --- a/libs/core/langchain_core/messages/tool.py +++ b/libs/core/langchain_core/messages/tool.py @@ -118,7 +118,7 @@ class ToolMessage(BaseMessage): pass tool_call_id = values["tool_call_id"] - if isinstance(tool_call_id, UUID): + if isinstance(tool_call_id, (UUID, int, float)): values["tool_call_id"] = str(tool_call_id) return values diff --git a/libs/core/tests/unit_tests/documents/test_document.py b/libs/core/tests/unit_tests/documents/test_document.py new file mode 100644 index 00000000000..e312121bd01 --- /dev/null +++ b/libs/core/tests/unit_tests/documents/test_document.py @@ -0,0 +1,12 @@ +from langchain_core.documents import Document + + +def test_init() -> None: + for doc in [ + Document(page_content="foo"), + Document(page_content="foo", metadata={"a": 1}), + Document(page_content="foo", id=None), + Document(page_content="foo", id="1"), + Document(page_content="foo", id=1), + ]: + assert isinstance(doc, Document) diff --git a/libs/core/tests/unit_tests/test_messages.py b/libs/core/tests/unit_tests/test_messages.py index 225f68451d8..c6c47396df0 100644 --- a/libs/core/tests/unit_tests/test_messages.py +++ b/libs/core/tests/unit_tests/test_messages.py @@ -3,13 +3,13 @@ import uuid from typing import List, Type, Union import pytest -from pydantic import ValidationError from langchain_core.documents import Document from langchain_core.load import dumpd, load from langchain_core.messages import ( AIMessage, AIMessageChunk, + BaseMessage, ChatMessage, ChatMessageChunk, FunctionMessage, @@ -33,6 +33,16 @@ from langchain_core.messages.tool import tool_call_chunk as create_tool_call_chu from langchain_core.utils._merge import merge_lists +def test_message_init() -> None: + for doc in [ + BaseMessage(type="foo", content="bar"), + BaseMessage(type="foo", content="bar", id=None), + BaseMessage(type="foo", content="bar", id="1"), + BaseMessage(type="foo", content="bar", id=1), + ]: + assert isinstance(doc, BaseMessage) + + def test_message_chunks() -> None: assert AIMessageChunk(content="I am", id="ai3") + AIMessageChunk( content=" indeed." @@ -1001,9 +1011,6 @@ def test_tool_message_content() -> None: def test_tool_message_tool_call_id() -> None: ToolMessage("foo", tool_call_id="1") - - # Currently we only handle UUID->str coercion manually. ToolMessage("foo", tool_call_id=uuid.uuid4()) - - with pytest.raises(ValidationError): - ToolMessage("foo", tool_call_id=1) + ToolMessage("foo", tool_call_id=1) + ToolMessage("foo", tool_call_id=1.0)