core[patch]: resolve warnings (#26157)

Resolve a batch of warnings
This commit is contained in:
ccurme
2024-09-06 15:00:53 -04:00
committed by GitHub
parent 1b77063c88
commit c27703a10f
6 changed files with 25 additions and 14 deletions

View File

@@ -79,4 +79,4 @@ class InMemoryDocumentIndex(DocumentIndex):
counts_by_doc.append((document, count))
counts_by_doc.sort(key=lambda x: x[1], reverse=True)
return [doc.copy() for doc, count in counts_by_doc[: self.top_k]]
return [doc.model_copy() for doc, count in counts_by_doc[: self.top_k]]

View File

@@ -237,7 +237,7 @@ def message_to_dict(message: BaseMessage) -> dict:
Message as a dict. The dict will have a "type" key with the message type
and a "data" key with the message data as a dict.
"""
return {"type": message.type, "data": message.dict()}
return {"type": message.type, "data": message.model_dump()}
def messages_to_dict(messages: Sequence[BaseMessage]) -> List[dict]:

View File

@@ -521,7 +521,7 @@ def merge_message_runs(
messages = convert_to_messages(messages)
merged: List[BaseMessage] = []
for msg in messages:
curr = msg.copy(deep=True)
curr = msg.model_copy(deep=True)
last = merged.pop() if merged else None
if not last:
merged.append(curr)
@@ -872,7 +872,7 @@ def _first_max_tokens(
if idx < len(messages) - 1 and partial_strategy:
included_partial = False
if isinstance(messages[idx].content, list):
excluded = messages[idx].copy(deep=True)
excluded = messages[idx].model_copy(deep=True)
num_block = len(excluded.content)
if partial_strategy == "last":
excluded.content = list(reversed(excluded.content))
@@ -886,7 +886,7 @@ def _first_max_tokens(
if included_partial and partial_strategy == "last":
excluded.content = list(reversed(excluded.content))
if not included_partial:
excluded = messages[idx].copy(deep=True)
excluded = messages[idx].model_copy(deep=True)
if isinstance(excluded.content, list) and any(
isinstance(block, str) or block["type"] == "text"
for block in messages[idx].content
@@ -977,11 +977,11 @@ _CHUNK_MSG_MAP = {v: k for k, v in _MSG_CHUNK_MAP.items()}
def _msg_to_chunk(message: BaseMessage) -> BaseMessageChunk:
if message.__class__ in _MSG_CHUNK_MAP:
return _MSG_CHUNK_MAP[message.__class__](**message.dict(exclude={"type"}))
return _MSG_CHUNK_MAP[message.__class__](**message.model_dump(exclude={"type"}))
for msg_cls, chunk_cls in _MSG_CHUNK_MAP.items():
if isinstance(message, msg_cls):
return chunk_cls(**message.dict(exclude={"type"}))
return chunk_cls(**message.model_dump(exclude={"type"}))
raise ValueError(
f"Unrecognized message class {message.__class__}. Supported classes are "
@@ -992,11 +992,11 @@ def _msg_to_chunk(message: BaseMessage) -> BaseMessageChunk:
def _chunk_to_msg(chunk: BaseMessageChunk) -> BaseMessage:
if chunk.__class__ in _CHUNK_MSG_MAP:
return _CHUNK_MSG_MAP[chunk.__class__](
**chunk.dict(exclude={"type", "tool_call_chunks"})
**chunk.model_dump(exclude={"type", "tool_call_chunks"})
)
for chunk_cls, msg_cls in _CHUNK_MSG_MAP.items():
if isinstance(chunk, chunk_cls):
return msg_cls(**chunk.dict(exclude={"type", "tool_call_chunks"}))
return msg_cls(**chunk.model_dump(exclude={"type", "tool_call_chunks"}))
raise ValueError(
f"Unrecognized message chunk class {chunk.__class__}. Supported classes are "

View File

@@ -622,7 +622,7 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
return self.__class__(
**{
**self.dict(),
**self.model_dump(),
**{"runnable": new_runnable, "fallbacks": new_fallbacks},
}
)

View File

@@ -37,6 +37,7 @@ from pydantic import (
model_validator,
validate_arguments,
)
from pydantic.v1 import BaseModel as BaseModelV1
from typing_extensions import Annotated
from langchain_core._api import deprecated
@@ -479,10 +480,20 @@ class ChildTool(BaseTool):
return tool_input
else:
if input_args is not None:
result = input_args.parse_obj(tool_input)
if issubclass(input_args, BaseModel):
result = input_args.model_validate(tool_input)
result_dict = result.model_dump()
elif issubclass(input_args, BaseModelV1):
result = input_args.parse_obj(tool_input)
result_dict = result.dict()
else:
raise NotImplementedError(
"args_schema must be a Pydantic BaseModel, "
f"got {self.args_schema}"
)
return {
k: getattr(result, k)
for k, v in result.dict().items()
for k, v in result_dict.items()
if k in tool_input
}
return tool_input

View File

@@ -429,8 +429,8 @@ def test_message_chunk_to_message() -> None:
],
)
assert message_chunk_to_message(chunk) == expected
assert AIMessage(**expected.dict()) == expected
assert AIMessageChunk(**chunk.dict()) == chunk
assert AIMessage(**expected.model_dump()) == expected
assert AIMessageChunk(**chunk.model_dump()) == chunk
def test_tool_calls_merge() -> None: