This commit is contained in:
Bagatur 2024-10-28 08:28:21 -07:00
parent 440c162b8b
commit a1a4e9f4a5
3 changed files with 56 additions and 13 deletions

View File

@ -219,6 +219,9 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
- If False (default), will always use streaming case if available. - If False (default), will always use streaming case if available.
""" """
coerce_input: bool = False
""""""
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
def raise_deprecation(cls, values: dict) -> Any: def raise_deprecation(cls, values: dict) -> Any:
@ -260,11 +263,11 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
def _convert_input(self, input: LanguageModelInput) -> PromptValue: def _convert_input(self, input: LanguageModelInput) -> PromptValue:
if isinstance(input, PromptValue): if isinstance(input, PromptValue):
return input prompt_val = input
elif isinstance(input, str): elif isinstance(input, str):
return StringPromptValue(text=input) prompt_val = StringPromptValue(text=input)
elif isinstance(input, Sequence): elif isinstance(input, Sequence):
return ChatPromptValue(messages=convert_to_messages(input)) prompt_val = ChatPromptValue(messages=convert_to_messages(input))
else: else:
msg = ( msg = (
f"Invalid input type {type(input)}. " f"Invalid input type {type(input)}. "
@ -272,6 +275,15 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
) )
raise ValueError(msg) raise ValueError(msg)
if self.coerce_input:
messages = prompt_val.to_messages()
return ChatPromptValue(messages=self._coerce_messages(messages))
else:
return prompt_val
def _coerce_messages(self, messages: list[BaseMessage]) -> list[BaseMessage]:
return messages
def invoke( def invoke(
self, self,
input: LanguageModelInput, input: LanguageModelInput,

View File

@ -880,6 +880,7 @@ def convert_to_openai_messages(
messages: Union[MessageLikeRepresentation, Sequence[MessageLikeRepresentation]], messages: Union[MessageLikeRepresentation, Sequence[MessageLikeRepresentation]],
*, *,
text_format: Literal["string", "block"] = "string", text_format: Literal["string", "block"] = "string",
coerce: bool = False,
) -> Union[dict, list[dict]]: ) -> Union[dict, list[dict]]:
"""Convert LangChain messages into OpenAI message dicts. """Convert LangChain messages into OpenAI message dicts.
@ -992,8 +993,14 @@ def convert_to_openai_messages(
f"but is missing expected key(s) " f"but is missing expected key(s) "
f"{missing}. Full content block:\n\n{block}" f"{missing}. Full content block:\n\n{block}"
) )
if not coerce:
raise ValueError(err) raise ValueError(err)
content.append({"type": block["type"], "text": block["text"]}) else:
logger.warning(err)
text = str({k: v for k, v in block.items() if k != "type"}) if len(block) > 1 else ""
else:
text = block["text"]
content.append({"type": block["type"], "text": text})
elif block.get("type") == "image_url": elif block.get("type") == "image_url":
if missing := [k for k in ("image_url",) if k not in block]: if missing := [k for k in ("image_url",) if k not in block]:
err = ( err = (
@ -1002,7 +1009,10 @@ def convert_to_openai_messages(
f"but is missing expected key(s) " f"but is missing expected key(s) "
f"{missing}. Full content block:\n\n{block}" f"{missing}. Full content block:\n\n{block}"
) )
if not coerce:
raise ValueError(err) raise ValueError(err)
else:
continue
content.append( content.append(
{"type": "image_url", "image_url": block["image_url"]} {"type": "image_url", "image_url": block["image_url"]}
) )
@ -1021,7 +1031,10 @@ def convert_to_openai_messages(
f"but 'source' is missing expected key(s) " f"but 'source' is missing expected key(s) "
f"{missing}. Full content block:\n\n{block}" f"{missing}. Full content block:\n\n{block}"
) )
if not coerce:
raise ValueError(err) raise ValueError(err)
else:
continue
content.append( content.append(
{ {
"type": "image_url", "type": "image_url",
@ -1044,7 +1057,10 @@ def convert_to_openai_messages(
f"but 'image' is missing expected key(s) " f"but 'image' is missing expected key(s) "
f"{missing}. Full content block:\n\n{block}" f"{missing}. Full content block:\n\n{block}"
) )
if not coerce:
raise ValueError(err) raise ValueError(err)
else:
continue
b64_image = _bytes_to_b64_str(image["source"]["bytes"]) b64_image = _bytes_to_b64_str(image["source"]["bytes"])
content.append( content.append(
{ {
@ -1064,7 +1080,10 @@ def convert_to_openai_messages(
f"but does not have a 'source' or 'image' key. Full " f"but does not have a 'source' or 'image' key. Full "
f"content block:\n\n{block}" f"content block:\n\n{block}"
) )
if not coerce:
raise ValueError(err) raise ValueError(err)
else:
continue
elif block.get("type") == "tool_use": elif block.get("type") == "tool_use":
if missing := [ if missing := [
k for k in ("id", "name", "input") if k not in block k for k in ("id", "name", "input") if k not in block
@ -1075,7 +1094,10 @@ def convert_to_openai_messages(
f"'tool_use', but is missing expected key(s) " f"'tool_use', but is missing expected key(s) "
f"{missing}. Full content block:\n\n{block}" f"{missing}. Full content block:\n\n{block}"
) )
if not coerce:
raise ValueError(err) raise ValueError(err)
else:
continue
if not any( if not any(
tool_call["id"] == block["id"] tool_call["id"] == block["id"]
for tool_call in cast(AIMessage, message).tool_calls for tool_call in cast(AIMessage, message).tool_calls
@ -1101,7 +1123,10 @@ def convert_to_openai_messages(
f"'tool_result', but is missing expected key(s) " f"'tool_result', but is missing expected key(s) "
f"{missing}. Full content block:\n\n{block}" f"{missing}. Full content block:\n\n{block}"
) )
if not coerce:
raise ValueError(msg) raise ValueError(msg)
else:
continue
tool_message = ToolMessage( tool_message = ToolMessage(
block["content"], block["content"],
tool_call_id=block["tool_use_id"], tool_call_id=block["tool_use_id"],
@ -1121,7 +1146,10 @@ def convert_to_openai_messages(
f"but does not have a 'json' key. Full " f"but does not have a 'json' key. Full "
f"content block:\n\n{block}" f"content block:\n\n{block}"
) )
if not coerce:
raise ValueError(msg) raise ValueError(msg)
else:
continue
content.append( content.append(
{"type": "text", "text": json.dumps(block["json"])} {"type": "text", "text": json.dumps(block["json"])}
) )

View File

@ -61,7 +61,7 @@ from langchain_core.messages import (
SystemMessageChunk, SystemMessageChunk,
ToolCall, ToolCall,
ToolMessage, ToolMessage,
ToolMessageChunk, ToolMessageChunk, convert_to_openai_messages, convert_to_messages,
) )
from langchain_core.messages.ai import ( from langchain_core.messages.ai import (
InputTokenDetails, InputTokenDetails,
@ -860,6 +860,9 @@ class BaseChatOpenAI(BaseChatModel):
None, self._create_chat_result, response, generation_info None, self._create_chat_result, response, generation_info
) )
def _coerce_messages(self, messages: list[BaseMessage]) -> list[BaseMessage]:
return convert_to_messages(convert_to_openai_messages(messages))
@property @property
def _identifying_params(self) -> Dict[str, Any]: def _identifying_params(self) -> Dict[str, Any]:
"""Get the identifying parameters.""" """Get the identifying parameters."""