Repository: https://github.com/hwchase17/langchain.git
Commit: b94f23883f (parent: 2d031031e3)

move best-effort v1 conversion
@@ -58,9 +58,7 @@ from langchain_core.messages import (
     is_data_content_block,
     message_chunk_to_message,
 )
-from langchain_core.messages import content_blocks as types
 from langchain_core.messages.ai import _LC_ID_PREFIX
-from langchain_core.messages.v1 import AIMessage as AIMessageV1
 from langchain_core.outputs import (
     ChatGeneration,
     ChatGenerationChunk,
@@ -222,23 +220,6 @@ def _format_ls_structured_output(ls_structured_output_format: Optional[dict]) ->
     return ls_structured_output_format_dict
 
 
-def _convert_to_v1(message: AIMessage) -> AIMessageV1:
-    """Best-effort conversion of a V0 AIMessage to V1."""
-    if isinstance(message.content, str):
-        content: list[types.ContentBlock] = []
-        if message.content:
-            content = [{"type": "text", "text": message.content}]
-
-    for tool_call in message.tool_calls:
-        content.append(tool_call)
-
-    return AIMessageV1(
-        content=content,
-        usage_metadata=message.usage_metadata,
-        response_metadata=message.response_metadata,
-    )
-
-
 class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
     """Base class for chat models.
 
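The helper removed here from the chat-model module is re-added in generalized form in the prompt-values hunks further down, so this hunk is pure relocation. Its core job is to flatten a v0 AIMessage into a single list of typed content blocks: non-empty string content becomes one text block, and tool calls are appended after it. A minimal self-contained sketch of that shape, using plain dicts and a hypothetical convert_ai_content_to_blocks helper instead of the branch-specific ContentBlock / AIMessageV1 types:

# Illustrative only: mirrors what the removed helper produced, using plain dicts in
# place of the branch-specific ContentBlock / AIMessageV1 types. The helper name is
# hypothetical, not part of langchain_core.
from typing import Any


def convert_ai_content_to_blocks(
    content: str, tool_calls: list[dict[str, Any]]
) -> list[dict[str, Any]]:
    """Flatten a v0-style AIMessage payload into a single list of content blocks."""
    blocks: list[dict[str, Any]] = []
    if content:
        # Non-empty string content becomes one text block.
        blocks.append({"type": "text", "text": content})
    # Tool calls are appended after the text, unchanged.
    blocks.extend(tool_calls)
    return blocks


blocks = convert_ai_content_to_blocks(
    "Let me check the weather.",
    [{"type": "tool_call", "name": "get_weather", "args": {"city": "Paris"}, "id": "call_1"}],
)
# -> [{"type": "text", "text": "Let me check the weather."},
#     {"type": "tool_call", "name": "get_weather", "args": {"city": "Paris"}, "id": "call_1"}]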
@@ -8,9 +8,9 @@ from __future__ import annotations
 
 from abc import ABC, abstractmethod
 from collections.abc import Sequence
-from typing import Literal, cast
+from typing import Literal, Union, cast
 
-from typing_extensions import TypedDict
+from typing_extensions import TypedDict, overload
 
 from langchain_core.load.serializable import Serializable
 from langchain_core.messages import (
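Union and overload are pulled in for the to_messages overloads added further down: a Literal-typed output_version flag lets a type checker pick the precise return type while a single runtime implementation does the dispatch. A standalone sketch of the pattern, with stand-in MessageV0 / MessageV1 classes that are not part of langchain:

# Standalone illustration of the Literal + @overload pattern enabled by these imports.
from typing import Literal, Union, overload


class MessageV0:  # stand-in for the v0 BaseMessage hierarchy
    ...


class MessageV1:  # stand-in for the branch-specific v1 message classes
    ...


class Prompt:
    @overload
    def to_messages(self, output_version: Literal["v0"] = "v0") -> list[MessageV0]: ...

    @overload
    def to_messages(self, output_version: Literal["v1"]) -> list[MessageV1]: ...

    def to_messages(
        self, output_version: Literal["v0", "v1"] = "v0"
    ) -> Union[list[MessageV0], list[MessageV1]]:
        # Single runtime implementation; the overloads above exist only for type checkers.
        if output_version == "v1":
            return [MessageV1()]
        return [MessageV0()]


msgs_v0 = Prompt().to_messages()      # a type checker infers list[MessageV0]
msgs_v1 = Prompt().to_messages("v1")  # a type checker infers list[MessageV1]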
@@ -19,6 +19,52 @@ from langchain_core.messages import (
     HumanMessage,
     get_buffer_string,
 )
+from langchain_core.messages import content_blocks as types
+from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
+from langchain_core.messages.v1 import AIMessage as AIMessageV1
+from langchain_core.messages.v1 import HumanMessage as HumanMessageV1
+from langchain_core.messages.v1 import MessageV1
+from langchain_core.messages.v1 import SystemMessage as SystemMessageV1
+from langchain_core.messages.v1 import ToolMessage as ToolMessageV1
+
+
+def _convert_to_v1(message: BaseMessage) -> MessageV1:
+    """Best-effort conversion of a V0 AIMessage to V1."""
+    if isinstance(message.content, str):
+        content: list[types.ContentBlock] = []
+        if message.content:
+            content = [{"type": "text", "text": message.content}]
+    else:
+        content = []
+        for block in message.content:
+            if isinstance(block, str):
+                content.append({"type": "text", "text": block})
+            elif isinstance(block, dict):
+                content.append(block)
+            else:
+                pass
+
+    if isinstance(message, HumanMessage):
+        return HumanMessageV1(content=content)
+    if isinstance(message, AIMessage):
+        for tool_call in message.tool_calls:
+            content.append(tool_call)
+        return AIMessageV1(
+            content=content,
+            usage_metadata=message.usage_metadata,
+            response_metadata=message.response_metadata,
+            tool_calls=message.tool_calls,
+        )
+    if isinstance(message, SystemMessage):
+        return SystemMessageV1(content=content)
+    if isinstance(message, ToolMessage):
+        return ToolMessageV1(
+            tool_call_id=message.tool_call_id,
+            content=content,
+            artifact=message.artifact,
+        )
+    error_message = f"Unsupported message type: {type(message)}"
+    raise TypeError(error_message)
 
 
 class PromptValue(Serializable, ABC):
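Relative to the version deleted from the chat-model module above, this copy also handles list-form v0 content (a mix of plain strings and block dicts) and then dispatches on the concrete message class. A plain-data sketch of just the normalization step, with a hypothetical normalize_v0_content helper standing in for the inline logic (the v1 message classes exist only on this branch, so nothing is imported from them):

# Plain-data sketch of the normalization above; normalize_v0_content is a hypothetical
# name, and dicts stand in for the branch-specific ContentBlock types.
from typing import Any, Union


def normalize_v0_content(
    content: Union[str, list[Union[str, dict[str, Any]]]],
) -> list[dict[str, Any]]:
    """Turn v0 message content (a string or a mixed list) into a list of block dicts."""
    blocks: list[dict[str, Any]] = []
    if isinstance(content, str):
        if content:
            blocks.append({"type": "text", "text": content})
    else:
        for block in content:
            if isinstance(block, str):
                blocks.append({"type": "text", "text": block})
            elif isinstance(block, dict):
                blocks.append(block)  # assumed to already be block-shaped; passed through
            # anything else is silently dropped, matching the best-effort intent
    return blocks


print(normalize_v0_content("hi"))
# [{'type': 'text', 'text': 'hi'}]
print(normalize_v0_content(["hi", {"type": "image_url", "image_url": {"url": "https://example.com/x.png"}}]))
# [{'type': 'text', 'text': 'hi'}, {'type': 'image_url', 'image_url': {'url': 'https://example.com/x.png'}}]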
@@ -75,6 +121,26 @@ class StringPromptValue(PromptValue):
         """Return prompt as messages."""
         return [HumanMessage(content=self.text)]
+
+    @overload
+    def to_messages(
+        self, output_version: Literal["v0"] = "v0"
+    ) -> list[BaseMessage]: ...
+
+    @overload
+    def to_messages(self, output_version: Literal["v1"]) -> list[MessageV1]: ...
+
+    def to_messages(
+        self, output_version: Literal["v0", "v1"] = "v0"
+    ) -> Union[list[BaseMessage], list[MessageV1]]:
+        """Return prompt as a list of messages.
+
+        Args:
+            output_version: The output version, either "v0" (default) or "v1".
+        """
+        if output_version == "v1":
+            return [HumanMessageV1(content=self.text)]
+        return [HumanMessage(content=self.text)]
 
 
 class ChatPromptValue(PromptValue):
     """Chat prompt value.
@@ -89,8 +155,24 @@ class ChatPromptValue(PromptValue):
         """Return prompt as string."""
         return get_buffer_string(self.messages)
 
-    def to_messages(self) -> list[BaseMessage]:
-        """Return prompt as a list of messages."""
+    @overload
+    def to_messages(
+        self, output_version: Literal["v0"] = "v0"
+    ) -> list[BaseMessage]: ...
+
+    @overload
+    def to_messages(self, output_version: Literal["v1"]) -> list[MessageV1]: ...
+
+    def to_messages(
+        self, output_version: Literal["v0", "v1"] = "v0"
+    ) -> Union[list[BaseMessage], list[MessageV1]]:
+        """Return prompt as a list of messages.
+
+        Args:
+            output_version: The output version, either "v0" (default) or "v1".
+        """
+        if output_version == "v1":
+            return [_convert_to_v1(m) for m in self.messages]
         return list(self.messages)
 
     @classmethod
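With the conversion co-located with the prompt values, ChatPromptValue can serve either message generation from the same stored v0 messages. A hedged usage sketch: the message and prompt-value constructors below are the released langchain_core API, but the output_version argument (and the v1 messages it returns) exist only on this development branch:

# Hypothetical usage on this branch; released langchain_core's to_messages() takes
# no output_version argument, so the second call is illustrative only.
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.prompt_values import ChatPromptValue

prompt_value = ChatPromptValue(
    messages=[
        HumanMessage(content="What's the weather in Paris?"),
        AIMessage(
            content="Checking now.",
            tool_calls=[{"name": "get_weather", "args": {"city": "Paris"}, "id": "call_1"}],
        ),
    ]
)

v0_messages = prompt_value.to_messages()      # unchanged default: the stored v0 BaseMessage objects
v1_messages = prompt_value.to_messages("v1")  # branch-only: each message run through _convert_to_v1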