mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-26 22:05:29 +00:00
@@ -79,4 +79,4 @@ class InMemoryDocumentIndex(DocumentIndex):
|
|||||||
counts_by_doc.append((document, count))
|
counts_by_doc.append((document, count))
|
||||||
|
|
||||||
counts_by_doc.sort(key=lambda x: x[1], reverse=True)
|
counts_by_doc.sort(key=lambda x: x[1], reverse=True)
|
||||||
return [doc.copy() for doc, count in counts_by_doc[: self.top_k]]
|
return [doc.model_copy() for doc, count in counts_by_doc[: self.top_k]]
|
||||||
|
@@ -237,7 +237,7 @@ def message_to_dict(message: BaseMessage) -> dict:
|
|||||||
Message as a dict. The dict will have a "type" key with the message type
|
Message as a dict. The dict will have a "type" key with the message type
|
||||||
and a "data" key with the message data as a dict.
|
and a "data" key with the message data as a dict.
|
||||||
"""
|
"""
|
||||||
return {"type": message.type, "data": message.dict()}
|
return {"type": message.type, "data": message.model_dump()}
|
||||||
|
|
||||||
|
|
||||||
def messages_to_dict(messages: Sequence[BaseMessage]) -> List[dict]:
|
def messages_to_dict(messages: Sequence[BaseMessage]) -> List[dict]:
|
||||||
|
@@ -521,7 +521,7 @@ def merge_message_runs(
|
|||||||
messages = convert_to_messages(messages)
|
messages = convert_to_messages(messages)
|
||||||
merged: List[BaseMessage] = []
|
merged: List[BaseMessage] = []
|
||||||
for msg in messages:
|
for msg in messages:
|
||||||
curr = msg.copy(deep=True)
|
curr = msg.model_copy(deep=True)
|
||||||
last = merged.pop() if merged else None
|
last = merged.pop() if merged else None
|
||||||
if not last:
|
if not last:
|
||||||
merged.append(curr)
|
merged.append(curr)
|
||||||
@@ -872,7 +872,7 @@ def _first_max_tokens(
|
|||||||
if idx < len(messages) - 1 and partial_strategy:
|
if idx < len(messages) - 1 and partial_strategy:
|
||||||
included_partial = False
|
included_partial = False
|
||||||
if isinstance(messages[idx].content, list):
|
if isinstance(messages[idx].content, list):
|
||||||
excluded = messages[idx].copy(deep=True)
|
excluded = messages[idx].model_copy(deep=True)
|
||||||
num_block = len(excluded.content)
|
num_block = len(excluded.content)
|
||||||
if partial_strategy == "last":
|
if partial_strategy == "last":
|
||||||
excluded.content = list(reversed(excluded.content))
|
excluded.content = list(reversed(excluded.content))
|
||||||
@@ -886,7 +886,7 @@ def _first_max_tokens(
|
|||||||
if included_partial and partial_strategy == "last":
|
if included_partial and partial_strategy == "last":
|
||||||
excluded.content = list(reversed(excluded.content))
|
excluded.content = list(reversed(excluded.content))
|
||||||
if not included_partial:
|
if not included_partial:
|
||||||
excluded = messages[idx].copy(deep=True)
|
excluded = messages[idx].model_copy(deep=True)
|
||||||
if isinstance(excluded.content, list) and any(
|
if isinstance(excluded.content, list) and any(
|
||||||
isinstance(block, str) or block["type"] == "text"
|
isinstance(block, str) or block["type"] == "text"
|
||||||
for block in messages[idx].content
|
for block in messages[idx].content
|
||||||
@@ -977,11 +977,11 @@ _CHUNK_MSG_MAP = {v: k for k, v in _MSG_CHUNK_MAP.items()}
|
|||||||
|
|
||||||
def _msg_to_chunk(message: BaseMessage) -> BaseMessageChunk:
|
def _msg_to_chunk(message: BaseMessage) -> BaseMessageChunk:
|
||||||
if message.__class__ in _MSG_CHUNK_MAP:
|
if message.__class__ in _MSG_CHUNK_MAP:
|
||||||
return _MSG_CHUNK_MAP[message.__class__](**message.dict(exclude={"type"}))
|
return _MSG_CHUNK_MAP[message.__class__](**message.model_dump(exclude={"type"}))
|
||||||
|
|
||||||
for msg_cls, chunk_cls in _MSG_CHUNK_MAP.items():
|
for msg_cls, chunk_cls in _MSG_CHUNK_MAP.items():
|
||||||
if isinstance(message, msg_cls):
|
if isinstance(message, msg_cls):
|
||||||
return chunk_cls(**message.dict(exclude={"type"}))
|
return chunk_cls(**message.model_dump(exclude={"type"}))
|
||||||
|
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Unrecognized message class {message.__class__}. Supported classes are "
|
f"Unrecognized message class {message.__class__}. Supported classes are "
|
||||||
@@ -992,11 +992,11 @@ def _msg_to_chunk(message: BaseMessage) -> BaseMessageChunk:
|
|||||||
def _chunk_to_msg(chunk: BaseMessageChunk) -> BaseMessage:
|
def _chunk_to_msg(chunk: BaseMessageChunk) -> BaseMessage:
|
||||||
if chunk.__class__ in _CHUNK_MSG_MAP:
|
if chunk.__class__ in _CHUNK_MSG_MAP:
|
||||||
return _CHUNK_MSG_MAP[chunk.__class__](
|
return _CHUNK_MSG_MAP[chunk.__class__](
|
||||||
**chunk.dict(exclude={"type", "tool_call_chunks"})
|
**chunk.model_dump(exclude={"type", "tool_call_chunks"})
|
||||||
)
|
)
|
||||||
for chunk_cls, msg_cls in _CHUNK_MSG_MAP.items():
|
for chunk_cls, msg_cls in _CHUNK_MSG_MAP.items():
|
||||||
if isinstance(chunk, chunk_cls):
|
if isinstance(chunk, chunk_cls):
|
||||||
return msg_cls(**chunk.dict(exclude={"type", "tool_call_chunks"}))
|
return msg_cls(**chunk.model_dump(exclude={"type", "tool_call_chunks"}))
|
||||||
|
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Unrecognized message chunk class {chunk.__class__}. Supported classes are "
|
f"Unrecognized message chunk class {chunk.__class__}. Supported classes are "
|
||||||
|
@@ -622,7 +622,7 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
|
|||||||
|
|
||||||
return self.__class__(
|
return self.__class__(
|
||||||
**{
|
**{
|
||||||
**self.dict(),
|
**self.model_dump(),
|
||||||
**{"runnable": new_runnable, "fallbacks": new_fallbacks},
|
**{"runnable": new_runnable, "fallbacks": new_fallbacks},
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
@@ -37,6 +37,7 @@ from pydantic import (
|
|||||||
model_validator,
|
model_validator,
|
||||||
validate_arguments,
|
validate_arguments,
|
||||||
)
|
)
|
||||||
|
from pydantic.v1 import BaseModel as BaseModelV1
|
||||||
from typing_extensions import Annotated
|
from typing_extensions import Annotated
|
||||||
|
|
||||||
from langchain_core._api import deprecated
|
from langchain_core._api import deprecated
|
||||||
@@ -479,10 +480,20 @@ class ChildTool(BaseTool):
|
|||||||
return tool_input
|
return tool_input
|
||||||
else:
|
else:
|
||||||
if input_args is not None:
|
if input_args is not None:
|
||||||
result = input_args.parse_obj(tool_input)
|
if issubclass(input_args, BaseModel):
|
||||||
|
result = input_args.model_validate(tool_input)
|
||||||
|
result_dict = result.model_dump()
|
||||||
|
elif issubclass(input_args, BaseModelV1):
|
||||||
|
result = input_args.parse_obj(tool_input)
|
||||||
|
result_dict = result.dict()
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"args_schema must be a Pydantic BaseModel, "
|
||||||
|
f"got {self.args_schema}"
|
||||||
|
)
|
||||||
return {
|
return {
|
||||||
k: getattr(result, k)
|
k: getattr(result, k)
|
||||||
for k, v in result.dict().items()
|
for k, v in result_dict.items()
|
||||||
if k in tool_input
|
if k in tool_input
|
||||||
}
|
}
|
||||||
return tool_input
|
return tool_input
|
||||||
|
@@ -429,8 +429,8 @@ def test_message_chunk_to_message() -> None:
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
assert message_chunk_to_message(chunk) == expected
|
assert message_chunk_to_message(chunk) == expected
|
||||||
assert AIMessage(**expected.dict()) == expected
|
assert AIMessage(**expected.model_dump()) == expected
|
||||||
assert AIMessageChunk(**chunk.dict()) == chunk
|
assert AIMessageChunk(**chunk.model_dump()) == chunk
|
||||||
|
|
||||||
|
|
||||||
def test_tool_calls_merge() -> None:
|
def test_tool_calls_merge() -> None:
|
||||||
|
Reference in New Issue
Block a user