Compare commits

...

2 Commits

Author SHA1 Message Date
Bagatur
a3d8113a8b fmt 2024-04-19 18:46:18 -07:00
Bagatur
a691bc51bb rfc: standardize input messages 2024-04-19 18:40:32 -07:00
2 changed files with 92 additions and 2 deletions

View File

@@ -38,7 +38,11 @@ from langchain_core.messages import (
AnyMessage,
BaseMessage,
BaseMessageChunk,
FunctionMessage,
HumanMessage,
SystemMessage,
ToolCall,
ToolMessage,
convert_to_messages,
message_chunk_to_message,
)
@@ -131,13 +135,25 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
"""Get the output type for this runnable."""
return AnyMessage
@property
def _standardize_input_messages(self) -> bool:
return False
def _convert_input(self, input: LanguageModelInput) -> PromptValue:
if isinstance(input, PromptValue):
if isinstance(input, StringPromptValue):
return input
elif isinstance(input, str):
return StringPromptValue(text=input)
elif isinstance(input, ChatPromptValue):
messages = input.messages
if self._standardize_input_messages:
messages = _standardize_messages(messages)
return ChatPromptValue(messages=messages)
elif isinstance(input, Sequence):
return ChatPromptValue(messages=convert_to_messages(input))
messages = convert_to_messages(input)
if self._standardize_input_messages:
messages = _standardize_messages(messages)
return ChatPromptValue(messages=messages)
else:
raise ValueError(
f"Invalid input type {type(input)}. "
@@ -961,3 +977,73 @@ def _gen_info_and_msg_metadata(
**(generation.generation_info or {}),
**generation.message.response_metadata,
}
def _standardize_messages(
messages: Sequence[BaseMessage],
) -> List[Union[SystemMessage, HumanMessage, AIMessage, ToolMessage]]:
"""Convert sequence of messages to a canonical form.
- Convert function message to tool message.
- Convert tool_result content blocks to tool messages.
- Convert tool_use content blocks to AIMessage.tool_calls.
"""
standardized = []
for msg in messages:
if isinstance(msg, FunctionMessage):
standardized.append(
ToolMessage(
msg.content, tool_call_id=msg.name + "_" + str(uuid.uuid4())
)
)
elif isinstance(msg.content, list):
msg_blocks = []
for block in msg.content:
if block.get("type") == "tool_result":
if msg_blocks:
standardized.append(
msg.copy(
update={"content": msg_blocks, "id": str(uuid.uuid4())}
)
)
standardized.append(
ToolMessage(
block.get("content", ""),
tool_call_id=block.get("tool_use_id", str(uuid.uuid4())),
id=str(uuid.uuid4()),
)
)
msg_blocks = []
elif block.get("type") == "tool_use":
if msg_blocks:
standardized.append(
msg.copy(
update={"content": msg_blocks, "id": str(uuid.uuid4())}
)
)
tool_call = ToolCall(
name=block.get("name", ""),
args=block.get("input", ""),
id=block.get("id", str(uuid.uuid4())),
)
if standardized and isinstance(standardized[-1], AIMessage):
standardized[-1].tool_calls.append(tool_call)
else:
standardized.append(
AIMessage("", tool_calls=[tool_call], id=str(uuid.uuid4()))
)
msg_blocks = []
else:
msg_blocks.append(block)
if msg_blocks and msg_blocks != msg.content:
standardized.append(
msg.copy(update={"content": msg_blocks, "id": str(uuid.uuid4())})
)
elif msg_blocks:
standardized.append(msg.copy(deep=True))
else:
pass
else:
standardized.append(msg.copy(deep=True))
return standardized

View File

@@ -338,6 +338,10 @@ class ChatOpenAI(BaseChatModel):
"""Return whether this model can be serialized by Langchain."""
return True
@property
def _standardize_input_messages(self) -> bool:
return True
client: Any = Field(default=None, exclude=True) #: :meta private:
async_client: Any = Field(default=None, exclude=True) #: :meta private:
model_name: str = Field(default="gpt-3.5-turbo", alias="model")