mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-26 22:05:29 +00:00
@@ -79,4 +79,4 @@ class InMemoryDocumentIndex(DocumentIndex):
|
||||
counts_by_doc.append((document, count))
|
||||
|
||||
counts_by_doc.sort(key=lambda x: x[1], reverse=True)
|
||||
return [doc.copy() for doc, count in counts_by_doc[: self.top_k]]
|
||||
return [doc.model_copy() for doc, count in counts_by_doc[: self.top_k]]
|
||||
|
@@ -237,7 +237,7 @@ def message_to_dict(message: BaseMessage) -> dict:
|
||||
Message as a dict. The dict will have a "type" key with the message type
|
||||
and a "data" key with the message data as a dict.
|
||||
"""
|
||||
return {"type": message.type, "data": message.dict()}
|
||||
return {"type": message.type, "data": message.model_dump()}
|
||||
|
||||
|
||||
def messages_to_dict(messages: Sequence[BaseMessage]) -> List[dict]:
|
||||
|
@@ -521,7 +521,7 @@ def merge_message_runs(
|
||||
messages = convert_to_messages(messages)
|
||||
merged: List[BaseMessage] = []
|
||||
for msg in messages:
|
||||
curr = msg.copy(deep=True)
|
||||
curr = msg.model_copy(deep=True)
|
||||
last = merged.pop() if merged else None
|
||||
if not last:
|
||||
merged.append(curr)
|
||||
@@ -872,7 +872,7 @@ def _first_max_tokens(
|
||||
if idx < len(messages) - 1 and partial_strategy:
|
||||
included_partial = False
|
||||
if isinstance(messages[idx].content, list):
|
||||
excluded = messages[idx].copy(deep=True)
|
||||
excluded = messages[idx].model_copy(deep=True)
|
||||
num_block = len(excluded.content)
|
||||
if partial_strategy == "last":
|
||||
excluded.content = list(reversed(excluded.content))
|
||||
@@ -886,7 +886,7 @@ def _first_max_tokens(
|
||||
if included_partial and partial_strategy == "last":
|
||||
excluded.content = list(reversed(excluded.content))
|
||||
if not included_partial:
|
||||
excluded = messages[idx].copy(deep=True)
|
||||
excluded = messages[idx].model_copy(deep=True)
|
||||
if isinstance(excluded.content, list) and any(
|
||||
isinstance(block, str) or block["type"] == "text"
|
||||
for block in messages[idx].content
|
||||
@@ -977,11 +977,11 @@ _CHUNK_MSG_MAP = {v: k for k, v in _MSG_CHUNK_MAP.items()}
|
||||
|
||||
def _msg_to_chunk(message: BaseMessage) -> BaseMessageChunk:
|
||||
if message.__class__ in _MSG_CHUNK_MAP:
|
||||
return _MSG_CHUNK_MAP[message.__class__](**message.dict(exclude={"type"}))
|
||||
return _MSG_CHUNK_MAP[message.__class__](**message.model_dump(exclude={"type"}))
|
||||
|
||||
for msg_cls, chunk_cls in _MSG_CHUNK_MAP.items():
|
||||
if isinstance(message, msg_cls):
|
||||
return chunk_cls(**message.dict(exclude={"type"}))
|
||||
return chunk_cls(**message.model_dump(exclude={"type"}))
|
||||
|
||||
raise ValueError(
|
||||
f"Unrecognized message class {message.__class__}. Supported classes are "
|
||||
@@ -992,11 +992,11 @@ def _msg_to_chunk(message: BaseMessage) -> BaseMessageChunk:
|
||||
def _chunk_to_msg(chunk: BaseMessageChunk) -> BaseMessage:
|
||||
if chunk.__class__ in _CHUNK_MSG_MAP:
|
||||
return _CHUNK_MSG_MAP[chunk.__class__](
|
||||
**chunk.dict(exclude={"type", "tool_call_chunks"})
|
||||
**chunk.model_dump(exclude={"type", "tool_call_chunks"})
|
||||
)
|
||||
for chunk_cls, msg_cls in _CHUNK_MSG_MAP.items():
|
||||
if isinstance(chunk, chunk_cls):
|
||||
return msg_cls(**chunk.dict(exclude={"type", "tool_call_chunks"}))
|
||||
return msg_cls(**chunk.model_dump(exclude={"type", "tool_call_chunks"}))
|
||||
|
||||
raise ValueError(
|
||||
f"Unrecognized message chunk class {chunk.__class__}. Supported classes are "
|
||||
|
@@ -622,7 +622,7 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
|
||||
|
||||
return self.__class__(
|
||||
**{
|
||||
**self.dict(),
|
||||
**self.model_dump(),
|
||||
**{"runnable": new_runnable, "fallbacks": new_fallbacks},
|
||||
}
|
||||
)
|
||||
|
@@ -37,6 +37,7 @@ from pydantic import (
|
||||
model_validator,
|
||||
validate_arguments,
|
||||
)
|
||||
from pydantic.v1 import BaseModel as BaseModelV1
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from langchain_core._api import deprecated
|
||||
@@ -479,10 +480,20 @@ class ChildTool(BaseTool):
|
||||
return tool_input
|
||||
else:
|
||||
if input_args is not None:
|
||||
result = input_args.parse_obj(tool_input)
|
||||
if issubclass(input_args, BaseModel):
|
||||
result = input_args.model_validate(tool_input)
|
||||
result_dict = result.model_dump()
|
||||
elif issubclass(input_args, BaseModelV1):
|
||||
result = input_args.parse_obj(tool_input)
|
||||
result_dict = result.dict()
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"args_schema must be a Pydantic BaseModel, "
|
||||
f"got {self.args_schema}"
|
||||
)
|
||||
return {
|
||||
k: getattr(result, k)
|
||||
for k, v in result.dict().items()
|
||||
for k, v in result_dict.items()
|
||||
if k in tool_input
|
||||
}
|
||||
return tool_input
|
||||
|
@@ -429,8 +429,8 @@ def test_message_chunk_to_message() -> None:
|
||||
],
|
||||
)
|
||||
assert message_chunk_to_message(chunk) == expected
|
||||
assert AIMessage(**expected.dict()) == expected
|
||||
assert AIMessageChunk(**chunk.dict()) == chunk
|
||||
assert AIMessage(**expected.model_dump()) == expected
|
||||
assert AIMessageChunk(**chunk.model_dump()) == chunk
|
||||
|
||||
|
||||
def test_tool_calls_merge() -> None:
|
||||
|
Reference in New Issue
Block a user