mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-05 19:15:44 +00:00
move best-effort v1 conversion
This commit is contained in:
parent
2d031031e3
commit
b94f23883f
@ -58,9 +58,7 @@ from langchain_core.messages import (
|
||||
is_data_content_block,
|
||||
message_chunk_to_message,
|
||||
)
|
||||
from langchain_core.messages import content_blocks as types
|
||||
from langchain_core.messages.ai import _LC_ID_PREFIX
|
||||
from langchain_core.messages.v1 import AIMessage as AIMessageV1
|
||||
from langchain_core.outputs import (
|
||||
ChatGeneration,
|
||||
ChatGenerationChunk,
|
||||
@ -222,23 +220,6 @@ def _format_ls_structured_output(ls_structured_output_format: Optional[dict]) ->
|
||||
return ls_structured_output_format_dict
|
||||
|
||||
|
||||
def _convert_to_v1(message: AIMessage) -> AIMessageV1:
|
||||
"""Best-effort conversion of a V0 AIMessage to V1."""
|
||||
if isinstance(message.content, str):
|
||||
content: list[types.ContentBlock] = []
|
||||
if message.content:
|
||||
content = [{"type": "text", "text": message.content}]
|
||||
|
||||
for tool_call in message.tool_calls:
|
||||
content.append(tool_call)
|
||||
|
||||
return AIMessageV1(
|
||||
content=content,
|
||||
usage_metadata=message.usage_metadata,
|
||||
response_metadata=message.response_metadata,
|
||||
)
|
||||
|
||||
|
||||
class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
"""Base class for chat models.
|
||||
|
||||
|
@ -8,9 +8,9 @@ from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Sequence
|
||||
from typing import Literal, cast
|
||||
from typing import Literal, Union, cast
|
||||
|
||||
from typing_extensions import TypedDict
|
||||
from typing_extensions import TypedDict, overload
|
||||
|
||||
from langchain_core.load.serializable import Serializable
|
||||
from langchain_core.messages import (
|
||||
@ -19,6 +19,52 @@ from langchain_core.messages import (
|
||||
HumanMessage,
|
||||
get_buffer_string,
|
||||
)
|
||||
from langchain_core.messages import content_blocks as types
|
||||
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
|
||||
from langchain_core.messages.v1 import AIMessage as AIMessageV1
|
||||
from langchain_core.messages.v1 import HumanMessage as HumanMessageV1
|
||||
from langchain_core.messages.v1 import MessageV1
|
||||
from langchain_core.messages.v1 import SystemMessage as SystemMessageV1
|
||||
from langchain_core.messages.v1 import ToolMessage as ToolMessageV1
|
||||
|
||||
|
||||
def _convert_to_v1(message: BaseMessage) -> MessageV1:
|
||||
"""Best-effort conversion of a V0 AIMessage to V1."""
|
||||
if isinstance(message.content, str):
|
||||
content: list[types.ContentBlock] = []
|
||||
if message.content:
|
||||
content = [{"type": "text", "text": message.content}]
|
||||
else:
|
||||
content = []
|
||||
for block in message.content:
|
||||
if isinstance(block, str):
|
||||
content.append({"type": "text", "text": block})
|
||||
elif isinstance(block, dict):
|
||||
content.append(block)
|
||||
else:
|
||||
pass
|
||||
|
||||
if isinstance(message, HumanMessage):
|
||||
return HumanMessageV1(content=content)
|
||||
if isinstance(message, AIMessage):
|
||||
for tool_call in message.tool_calls:
|
||||
content.append(tool_call)
|
||||
return AIMessageV1(
|
||||
content=content,
|
||||
usage_metadata=message.usage_metadata,
|
||||
response_metadata=message.response_metadata,
|
||||
tool_calls=message.tool_calls,
|
||||
)
|
||||
if isinstance(message, SystemMessage):
|
||||
return SystemMessageV1(content=content)
|
||||
if isinstance(message, ToolMessage):
|
||||
return ToolMessageV1(
|
||||
tool_call_id=message.tool_call_id,
|
||||
content=content,
|
||||
artifact=message.artifact,
|
||||
)
|
||||
error_message = f"Unsupported message type: {type(message)}"
|
||||
raise TypeError(error_message)
|
||||
|
||||
|
||||
class PromptValue(Serializable, ABC):
|
||||
@ -75,6 +121,26 @@ class StringPromptValue(PromptValue):
|
||||
"""Return prompt as messages."""
|
||||
return [HumanMessage(content=self.text)]
|
||||
|
||||
@overload
|
||||
def to_messages(
|
||||
self, output_version: Literal["v0"] = "v0"
|
||||
) -> list[BaseMessage]: ...
|
||||
|
||||
@overload
|
||||
def to_messages(self, output_version: Literal["v1"]) -> list[MessageV1]: ...
|
||||
|
||||
def to_messages(
|
||||
self, output_version: Literal["v0", "v1"] = "v0"
|
||||
) -> Union[list[BaseMessage], list[MessageV1]]:
|
||||
"""Return prompt as a list of messages.
|
||||
|
||||
Args:
|
||||
output_version: The output version, either "v0" (default) or "v1".
|
||||
"""
|
||||
if output_version == "v1":
|
||||
return [HumanMessageV1(content=self.text)]
|
||||
return [HumanMessage(content=self.text)]
|
||||
|
||||
|
||||
class ChatPromptValue(PromptValue):
|
||||
"""Chat prompt value.
|
||||
@ -89,8 +155,24 @@ class ChatPromptValue(PromptValue):
|
||||
"""Return prompt as string."""
|
||||
return get_buffer_string(self.messages)
|
||||
|
||||
def to_messages(self) -> list[BaseMessage]:
|
||||
"""Return prompt as a list of messages."""
|
||||
@overload
|
||||
def to_messages(
|
||||
self, output_version: Literal["v0"] = "v0"
|
||||
) -> list[BaseMessage]: ...
|
||||
|
||||
@overload
|
||||
def to_messages(self, output_version: Literal["v1"]) -> list[MessageV1]: ...
|
||||
|
||||
def to_messages(
|
||||
self, output_version: Literal["v0", "v1"] = "v0"
|
||||
) -> Union[list[BaseMessage], list[MessageV1]]:
|
||||
"""Return prompt as a list of messages.
|
||||
|
||||
Args:
|
||||
output_version: The output version, either "v0" (default) or "v1".
|
||||
"""
|
||||
if output_version == "v1":
|
||||
return [_convert_to_v1(m) for m in self.messages]
|
||||
return list(self.messages)
|
||||
|
||||
@classmethod
|
||||
|
Loading…
Reference in New Issue
Block a user