move best-effort v1 conversion

This commit is contained in:
Chester Curme 2025-07-24 13:31:27 -04:00
parent 2d031031e3
commit b94f23883f
2 changed files with 86 additions and 23 deletions

View File

@ -58,9 +58,7 @@ from langchain_core.messages import (
is_data_content_block,
message_chunk_to_message,
)
from langchain_core.messages import content_blocks as types
from langchain_core.messages.ai import _LC_ID_PREFIX
from langchain_core.messages.v1 import AIMessage as AIMessageV1
from langchain_core.outputs import (
ChatGeneration,
ChatGenerationChunk,
@ -222,23 +220,6 @@ def _format_ls_structured_output(ls_structured_output_format: Optional[dict]) ->
return ls_structured_output_format_dict
def _convert_to_v1(message: AIMessage) -> AIMessageV1:
"""Best-effort conversion of a V0 AIMessage to V1."""
if isinstance(message.content, str):
content: list[types.ContentBlock] = []
if message.content:
content = [{"type": "text", "text": message.content}]
for tool_call in message.tool_calls:
content.append(tool_call)
return AIMessageV1(
content=content,
usage_metadata=message.usage_metadata,
response_metadata=message.response_metadata,
)
class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
"""Base class for chat models.

View File

@ -8,9 +8,9 @@ from __future__ import annotations
from abc import ABC, abstractmethod
from collections.abc import Sequence
from typing import Literal, cast
from typing import Literal, Union, cast
from typing_extensions import TypedDict
from typing_extensions import TypedDict, overload
from langchain_core.load.serializable import Serializable
from langchain_core.messages import (
@ -19,6 +19,52 @@ from langchain_core.messages import (
HumanMessage,
get_buffer_string,
)
from langchain_core.messages import content_blocks as types
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
from langchain_core.messages.v1 import AIMessage as AIMessageV1
from langchain_core.messages.v1 import HumanMessage as HumanMessageV1
from langchain_core.messages.v1 import MessageV1
from langchain_core.messages.v1 import SystemMessage as SystemMessageV1
from langchain_core.messages.v1 import ToolMessage as ToolMessageV1
def _convert_to_v1(message: BaseMessage) -> MessageV1:
"""Best-effort conversion of a V0 AIMessage to V1."""
if isinstance(message.content, str):
content: list[types.ContentBlock] = []
if message.content:
content = [{"type": "text", "text": message.content}]
else:
content = []
for block in message.content:
if isinstance(block, str):
content.append({"type": "text", "text": block})
elif isinstance(block, dict):
content.append(block)
else:
pass
if isinstance(message, HumanMessage):
return HumanMessageV1(content=content)
if isinstance(message, AIMessage):
for tool_call in message.tool_calls:
content.append(tool_call)
return AIMessageV1(
content=content,
usage_metadata=message.usage_metadata,
response_metadata=message.response_metadata,
tool_calls=message.tool_calls,
)
if isinstance(message, SystemMessage):
return SystemMessageV1(content=content)
if isinstance(message, ToolMessage):
return ToolMessageV1(
tool_call_id=message.tool_call_id,
content=content,
artifact=message.artifact,
)
error_message = f"Unsupported message type: {type(message)}"
raise TypeError(error_message)
class PromptValue(Serializable, ABC):
@ -75,6 +121,26 @@ class StringPromptValue(PromptValue):
"""Return prompt as messages."""
return [HumanMessage(content=self.text)]
@overload
def to_messages(
self, output_version: Literal["v0"] = "v0"
) -> list[BaseMessage]: ...
@overload
def to_messages(self, output_version: Literal["v1"]) -> list[MessageV1]: ...
def to_messages(
self, output_version: Literal["v0", "v1"] = "v0"
) -> Union[list[BaseMessage], list[MessageV1]]:
"""Return prompt as a list of messages.
Args:
output_version: The output version, either "v0" (default) or "v1".
"""
if output_version == "v1":
return [HumanMessageV1(content=self.text)]
return [HumanMessage(content=self.text)]
class ChatPromptValue(PromptValue):
"""Chat prompt value.
@ -89,8 +155,24 @@ class ChatPromptValue(PromptValue):
"""Return prompt as string."""
return get_buffer_string(self.messages)
def to_messages(self) -> list[BaseMessage]:
"""Return prompt as a list of messages."""
@overload
def to_messages(
self, output_version: Literal["v0"] = "v0"
) -> list[BaseMessage]: ...
@overload
def to_messages(self, output_version: Literal["v1"]) -> list[MessageV1]: ...
def to_messages(
self, output_version: Literal["v0", "v1"] = "v0"
) -> Union[list[BaseMessage], list[MessageV1]]:
"""Return prompt as a list of messages.
Args:
output_version: The output version, either "v0" (default) or "v1".
"""
if output_version == "v1":
return [_convert_to_v1(m) for m in self.messages]
return list(self.messages)
@classmethod