From a1a4e9f4a535e5d93277b6978d5289c4ede7d287 Mon Sep 17 00:00:00 2001 From: Bagatur Date: Mon, 28 Oct 2024 08:28:21 -0700 Subject: [PATCH] wip --- .../language_models/chat_models.py | 18 ++++++-- libs/core/langchain_core/messages/utils.py | 46 +++++++++++++++---- .../langchain_openai/chat_models/base.py | 5 +- 3 files changed, 56 insertions(+), 13 deletions(-) diff --git a/libs/core/langchain_core/language_models/chat_models.py b/libs/core/langchain_core/language_models/chat_models.py index 39fd11c247f..1c6534b545a 100644 --- a/libs/core/langchain_core/language_models/chat_models.py +++ b/libs/core/langchain_core/language_models/chat_models.py @@ -219,6 +219,9 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC): - If False (default), will always use streaming case if available. """ + coerce_input: bool = False + """""" + @model_validator(mode="before") @classmethod def raise_deprecation(cls, values: dict) -> Any: @@ -260,11 +263,11 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC): def _convert_input(self, input: LanguageModelInput) -> PromptValue: if isinstance(input, PromptValue): - return input + prompt_val = input elif isinstance(input, str): - return StringPromptValue(text=input) + prompt_val = StringPromptValue(text=input) elif isinstance(input, Sequence): - return ChatPromptValue(messages=convert_to_messages(input)) + prompt_val = ChatPromptValue(messages=convert_to_messages(input)) else: msg = ( f"Invalid input type {type(input)}. " @@ -272,6 +275,15 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC): ) raise ValueError(msg) + if self.coerce_input: + messages = prompt_val.to_messages() + return ChatPromptValue(messages=self._coerce_messages(messages)) + else: + return prompt_val + + def _coerce_messages(self, messages: list[BaseMessage]) -> list[BaseMessage]: + return messages + def invoke( self, input: LanguageModelInput, diff --git a/libs/core/langchain_core/messages/utils.py b/libs/core/langchain_core/messages/utils.py index 9bf617af17d..c53b951dc16 100644 --- a/libs/core/langchain_core/messages/utils.py +++ b/libs/core/langchain_core/messages/utils.py @@ -880,6 +880,7 @@ def convert_to_openai_messages( messages: Union[MessageLikeRepresentation, Sequence[MessageLikeRepresentation]], *, text_format: Literal["string", "block"] = "string", + coerce: bool = False, ) -> Union[dict, list[dict]]: """Convert LangChain messages into OpenAI message dicts. @@ -992,8 +993,14 @@ def convert_to_openai_messages( f"but is missing expected key(s) " f"{missing}. Full content block:\n\n{block}" ) - raise ValueError(err) - content.append({"type": block["type"], "text": block["text"]}) + if not coerce: + raise ValueError(err) + else: + logger.warning(err) + text = str({k: v for k, v in block.items() if k != "type"}) if len(block) > 1 else "" + else: + text = block["text"] + content.append({"type": block["type"], "text": text}) elif block.get("type") == "image_url": if missing := [k for k in ("image_url",) if k not in block]: err = ( @@ -1002,7 +1009,10 @@ def convert_to_openai_messages( f"but is missing expected key(s) " f"{missing}. Full content block:\n\n{block}" ) - raise ValueError(err) + if not coerce: + raise ValueError(err) + else: + continue content.append( {"type": "image_url", "image_url": block["image_url"]} ) @@ -1021,7 +1031,10 @@ def convert_to_openai_messages( f"but 'source' is missing expected key(s) " f"{missing}. Full content block:\n\n{block}" ) - raise ValueError(err) + if not coerce: + raise ValueError(err) + else: + continue content.append( { "type": "image_url", @@ -1044,7 +1057,10 @@ def convert_to_openai_messages( f"but 'image' is missing expected key(s) " f"{missing}. Full content block:\n\n{block}" ) - raise ValueError(err) + if not coerce: + raise ValueError(err) + else: + continue b64_image = _bytes_to_b64_str(image["source"]["bytes"]) content.append( { @@ -1064,7 +1080,10 @@ def convert_to_openai_messages( f"but does not have a 'source' or 'image' key. Full " f"content block:\n\n{block}" ) - raise ValueError(err) + if not coerce: + raise ValueError(err) + else: + continue elif block.get("type") == "tool_use": if missing := [ k for k in ("id", "name", "input") if k not in block @@ -1075,7 +1094,10 @@ def convert_to_openai_messages( f"'tool_use', but is missing expected key(s) " f"{missing}. Full content block:\n\n{block}" ) - raise ValueError(err) + if not coerce: + raise ValueError(err) + else: + continue if not any( tool_call["id"] == block["id"] for tool_call in cast(AIMessage, message).tool_calls @@ -1101,7 +1123,10 @@ def convert_to_openai_messages( f"'tool_result', but is missing expected key(s) " f"{missing}. Full content block:\n\n{block}" ) - raise ValueError(msg) + if not coerce: + raise ValueError(msg) + else: + continue tool_message = ToolMessage( block["content"], tool_call_id=block["tool_use_id"], @@ -1121,7 +1146,10 @@ def convert_to_openai_messages( f"but does not have a 'json' key. Full " f"content block:\n\n{block}" ) - raise ValueError(msg) + if not coerce: + raise ValueError(msg) + else: + continue content.append( {"type": "text", "text": json.dumps(block["json"])} ) diff --git a/libs/partners/openai/langchain_openai/chat_models/base.py b/libs/partners/openai/langchain_openai/chat_models/base.py index 52cb1e1912d..f37458b257b 100644 --- a/libs/partners/openai/langchain_openai/chat_models/base.py +++ b/libs/partners/openai/langchain_openai/chat_models/base.py @@ -61,7 +61,7 @@ from langchain_core.messages import ( SystemMessageChunk, ToolCall, ToolMessage, - ToolMessageChunk, + ToolMessageChunk, convert_to_openai_messages, convert_to_messages, ) from langchain_core.messages.ai import ( InputTokenDetails, @@ -860,6 +860,9 @@ class BaseChatOpenAI(BaseChatModel): None, self._create_chat_result, response, generation_info ) + def _coerce_messages(self, messages: list[BaseMessage]) -> list[BaseMessage]: + return convert_to_messages(convert_to_openai_messages(messages)) + @property def _identifying_params(self) -> Dict[str, Any]: """Get the identifying parameters."""