move best-effort v1 conversion

This commit is contained in:
Chester Curme 2025-07-24 13:31:27 -04:00
parent 2d031031e3
commit b94f23883f
2 changed files with 86 additions and 23 deletions

View File

@ -58,9 +58,7 @@ from langchain_core.messages import (
is_data_content_block, is_data_content_block,
message_chunk_to_message, message_chunk_to_message,
) )
from langchain_core.messages import content_blocks as types
from langchain_core.messages.ai import _LC_ID_PREFIX from langchain_core.messages.ai import _LC_ID_PREFIX
from langchain_core.messages.v1 import AIMessage as AIMessageV1
from langchain_core.outputs import ( from langchain_core.outputs import (
ChatGeneration, ChatGeneration,
ChatGenerationChunk, ChatGenerationChunk,
@ -222,23 +220,6 @@ def _format_ls_structured_output(ls_structured_output_format: Optional[dict]) ->
return ls_structured_output_format_dict return ls_structured_output_format_dict
def _convert_to_v1(message: AIMessage) -> AIMessageV1:
"""Best-effort conversion of a V0 AIMessage to V1."""
if isinstance(message.content, str):
content: list[types.ContentBlock] = []
if message.content:
content = [{"type": "text", "text": message.content}]
for tool_call in message.tool_calls:
content.append(tool_call)
return AIMessageV1(
content=content,
usage_metadata=message.usage_metadata,
response_metadata=message.response_metadata,
)
class BaseChatModel(BaseLanguageModel[BaseMessage], ABC): class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
"""Base class for chat models. """Base class for chat models.

View File

@ -8,9 +8,9 @@ from __future__ import annotations
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Sequence from collections.abc import Sequence
from typing import Literal, cast from typing import Literal, Union, cast
from typing_extensions import TypedDict from typing_extensions import TypedDict, overload
from langchain_core.load.serializable import Serializable from langchain_core.load.serializable import Serializable
from langchain_core.messages import ( from langchain_core.messages import (
@ -19,6 +19,52 @@ from langchain_core.messages import (
HumanMessage, HumanMessage,
get_buffer_string, get_buffer_string,
) )
from langchain_core.messages import content_blocks as types
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
from langchain_core.messages.v1 import AIMessage as AIMessageV1
from langchain_core.messages.v1 import HumanMessage as HumanMessageV1
from langchain_core.messages.v1 import MessageV1
from langchain_core.messages.v1 import SystemMessage as SystemMessageV1
from langchain_core.messages.v1 import ToolMessage as ToolMessageV1
def _convert_to_v1(message: BaseMessage) -> MessageV1:
"""Best-effort conversion of a V0 AIMessage to V1."""
if isinstance(message.content, str):
content: list[types.ContentBlock] = []
if message.content:
content = [{"type": "text", "text": message.content}]
else:
content = []
for block in message.content:
if isinstance(block, str):
content.append({"type": "text", "text": block})
elif isinstance(block, dict):
content.append(block)
else:
pass
if isinstance(message, HumanMessage):
return HumanMessageV1(content=content)
if isinstance(message, AIMessage):
for tool_call in message.tool_calls:
content.append(tool_call)
return AIMessageV1(
content=content,
usage_metadata=message.usage_metadata,
response_metadata=message.response_metadata,
tool_calls=message.tool_calls,
)
if isinstance(message, SystemMessage):
return SystemMessageV1(content=content)
if isinstance(message, ToolMessage):
return ToolMessageV1(
tool_call_id=message.tool_call_id,
content=content,
artifact=message.artifact,
)
error_message = f"Unsupported message type: {type(message)}"
raise TypeError(error_message)
class PromptValue(Serializable, ABC): class PromptValue(Serializable, ABC):
@ -75,6 +121,26 @@ class StringPromptValue(PromptValue):
"""Return prompt as messages.""" """Return prompt as messages."""
return [HumanMessage(content=self.text)] return [HumanMessage(content=self.text)]
@overload
def to_messages(
self, output_version: Literal["v0"] = "v0"
) -> list[BaseMessage]: ...
@overload
def to_messages(self, output_version: Literal["v1"]) -> list[MessageV1]: ...
def to_messages(
self, output_version: Literal["v0", "v1"] = "v0"
) -> Union[list[BaseMessage], list[MessageV1]]:
"""Return prompt as a list of messages.
Args:
output_version: The output version, either "v0" (default) or "v1".
"""
if output_version == "v1":
return [HumanMessageV1(content=self.text)]
return [HumanMessage(content=self.text)]
class ChatPromptValue(PromptValue): class ChatPromptValue(PromptValue):
"""Chat prompt value. """Chat prompt value.
@ -89,8 +155,24 @@ class ChatPromptValue(PromptValue):
"""Return prompt as string.""" """Return prompt as string."""
return get_buffer_string(self.messages) return get_buffer_string(self.messages)
def to_messages(self) -> list[BaseMessage]: @overload
"""Return prompt as a list of messages.""" def to_messages(
self, output_version: Literal["v0"] = "v0"
) -> list[BaseMessage]: ...
@overload
def to_messages(self, output_version: Literal["v1"]) -> list[MessageV1]: ...
def to_messages(
self, output_version: Literal["v0", "v1"] = "v0"
) -> Union[list[BaseMessage], list[MessageV1]]:
"""Return prompt as a list of messages.
Args:
output_version: The output version, either "v0" (default) or "v1".
"""
if output_version == "v1":
return [_convert_to_v1(m) for m in self.messages]
return list(self.messages) return list(self.messages)
@classmethod @classmethod