mirror of
https://github.com/hwchase17/langchain.git
synced 2026-02-21 06:33:41 +00:00
core[patch]: retain string type coercion for various IDs (#26429)
Retains string type coercion for the following fields:
- `BaseMedia.id`
- `BaseMessage.id`
- `ToolMessage.tool_call_id`

Note: with this change we are actually less restrictive about the types accepted when initializing `id` than pydantic V1 was (e.g., pydantic V1 will raise an error if you pass a dict or tuple to `id`). Let me know if we should restrict the type coercion. For `tool_call_id` I just enumerated the supported types (`UUID`, `int`, `float`).
This commit is contained in:
@@ -6,7 +6,7 @@ from io import BufferedReader, BytesIO
|
||||
from pathlib import PurePath
|
||||
from typing import Any, Dict, Generator, List, Literal, Optional, Union, cast
|
||||
|
||||
from pydantic import ConfigDict, Field, model_validator
|
||||
from pydantic import ConfigDict, Field, field_validator, model_validator
|
||||
|
||||
from langchain_core.load.serializable import Serializable
|
||||
|
||||
@@ -40,6 +40,13 @@ class BaseMedia(Serializable):
|
||||
metadata: dict = Field(default_factory=dict)
|
||||
"""Arbitrary metadata associated with the content."""
|
||||
|
||||
@field_validator("id", mode="before")
|
||||
def cast_id_to_str(cls, id_value: Any) -> Optional[str]:
|
||||
if id_value is not None:
|
||||
return str(id_value)
|
||||
else:
|
||||
return id_value
|
||||
|
||||
|
||||
class Blob(BaseMedia):
|
||||
"""Blob represents raw data by either reference or value.
|
||||
|
||||
@@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union, cast
|
||||
|
||||
from pydantic import ConfigDict, Field
|
||||
from pydantic import ConfigDict, Field, field_validator
|
||||
|
||||
from langchain_core.load.serializable import Serializable
|
||||
from langchain_core.utils import get_bolded_text
|
||||
@@ -56,6 +56,13 @@ class BaseMessage(Serializable):
|
||||
extra="allow",
|
||||
)
|
||||
|
||||
@field_validator("id", mode="before")
|
||||
def cast_id_to_str(cls, id_value: Any) -> Optional[str]:
|
||||
if id_value is not None:
|
||||
return str(id_value)
|
||||
else:
|
||||
return id_value
|
||||
|
||||
def __init__(
|
||||
self, content: Union[str, List[Union[str, Dict]]], **kwargs: Any
|
||||
) -> None:
|
||||
|
||||
@@ -118,7 +118,7 @@ class ToolMessage(BaseMessage):
|
||||
pass
|
||||
|
||||
tool_call_id = values["tool_call_id"]
|
||||
if isinstance(tool_call_id, UUID):
|
||||
if isinstance(tool_call_id, (UUID, int, float)):
|
||||
values["tool_call_id"] = str(tool_call_id)
|
||||
return values
|
||||
|
||||
|
||||
12
libs/core/tests/unit_tests/documents/test_document.py
Normal file
12
libs/core/tests/unit_tests/documents/test_document.py
Normal file
@@ -0,0 +1,12 @@
|
||||
from langchain_core.documents import Document
|
||||
|
||||
|
||||
def test_init() -> None:
|
||||
for doc in [
|
||||
Document(page_content="foo"),
|
||||
Document(page_content="foo", metadata={"a": 1}),
|
||||
Document(page_content="foo", id=None),
|
||||
Document(page_content="foo", id="1"),
|
||||
Document(page_content="foo", id=1),
|
||||
]:
|
||||
assert isinstance(doc, Document)
|
||||
@@ -3,13 +3,13 @@ import uuid
|
||||
from typing import List, Type, Union
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.load import dumpd, load
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
AIMessageChunk,
|
||||
BaseMessage,
|
||||
ChatMessage,
|
||||
ChatMessageChunk,
|
||||
FunctionMessage,
|
||||
@@ -33,6 +33,16 @@ from langchain_core.messages.tool import tool_call_chunk as create_tool_call_chu
|
||||
from langchain_core.utils._merge import merge_lists
|
||||
|
||||
|
||||
def test_message_init() -> None:
|
||||
for doc in [
|
||||
BaseMessage(type="foo", content="bar"),
|
||||
BaseMessage(type="foo", content="bar", id=None),
|
||||
BaseMessage(type="foo", content="bar", id="1"),
|
||||
BaseMessage(type="foo", content="bar", id=1),
|
||||
]:
|
||||
assert isinstance(doc, BaseMessage)
|
||||
|
||||
|
||||
def test_message_chunks() -> None:
|
||||
assert AIMessageChunk(content="I am", id="ai3") + AIMessageChunk(
|
||||
content=" indeed."
|
||||
@@ -1001,9 +1011,6 @@ def test_tool_message_content() -> None:
|
||||
|
||||
def test_tool_message_tool_call_id() -> None:
|
||||
ToolMessage("foo", tool_call_id="1")
|
||||
|
||||
# Currently we only handle UUID->str coercion manually.
|
||||
ToolMessage("foo", tool_call_id=uuid.uuid4())
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
ToolMessage("foo", tool_call_id=1)
|
||||
ToolMessage("foo", tool_call_id=1)
|
||||
ToolMessage("foo", tool_call_id=1.0)
|
||||
|
||||
Reference in New Issue
Block a user