Bagatur 2024-10-28 08:28:21 -07:00
parent 440c162b8b
commit a1a4e9f4a5
3 changed files with 56 additions and 13 deletions

libs/core/langchain_core/language_models/chat_models.py

@@ -219,6 +219,9 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
- If False (default), will always use streaming case if available.
"""
coerce_input: bool = False
""""""
@model_validator(mode="before")
@classmethod
def raise_deprecation(cls, values: dict) -> Any:
@@ -260,11 +263,11 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
def _convert_input(self, input: LanguageModelInput) -> PromptValue:
if isinstance(input, PromptValue):
-    return input
+    prompt_val = input
elif isinstance(input, str):
-    return StringPromptValue(text=input)
+    prompt_val = StringPromptValue(text=input)
elif isinstance(input, Sequence):
-    return ChatPromptValue(messages=convert_to_messages(input))
+    prompt_val = ChatPromptValue(messages=convert_to_messages(input))
else:
msg = (
f"Invalid input type {type(input)}. "
@@ -272,6 +275,15 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
)
raise ValueError(msg)
if self.coerce_input:
messages = prompt_val.to_messages()
return ChatPromptValue(messages=self._coerce_messages(messages))
else:
return prompt_val
def _coerce_messages(self, messages: list[BaseMessage]) -> list[BaseMessage]:
return messages
def invoke(
self,
input: LanguageModelInput,
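Taken together, the changes above add a coerce_input switch and a _coerce_messages hook that runs after input conversion and before generation; the base hook is a no-op, so nothing changes unless a subclass opts in. Below is a minimal sketch of how the hook could be exercised, assuming this diff is applied (neither coerce_input nor _coerce_messages exists in released langchain-core); EchoModel and its whitespace-stripping coercion are hypothetical, used only to show the call path:

    from typing import Any, Optional

    from langchain_core.callbacks import CallbackManagerForLLMRun
    from langchain_core.language_models.chat_models import BaseChatModel
    from langchain_core.messages import AIMessage, BaseMessage
    from langchain_core.outputs import ChatGeneration, ChatResult


    class EchoModel(BaseChatModel):
        """Hypothetical chat model that echoes the last input message back."""

        coerce_input: bool = True  # opt in to the coercion path added in this diff

        def _coerce_messages(self, messages: list[BaseMessage]) -> list[BaseMessage]:
            # Example coercion: strip surrounding whitespace from string contents.
            return [
                m.model_copy(update={"content": m.content.strip()})
                if isinstance(m.content, str)
                else m
                for m in messages
            ]

        def _generate(
            self,
            messages: list[BaseMessage],
            stop: Optional[list[str]] = None,
            run_manager: Optional[CallbackManagerForLLMRun] = None,
            **kwargs: Any,
        ) -> ChatResult:
            # Echo the (already coerced) last message back as the model output.
            return ChatResult(
                generations=[ChatGeneration(message=AIMessage(content=messages[-1].content))]
            )

        @property
        def _llm_type(self) -> str:
            return "echo"


    EchoModel().invoke("  hello  ")  # AIMessage content is "hello": the hook ran inside _convert_input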

libs/core/langchain_core/messages/utils.py

@@ -880,6 +880,7 @@ def convert_to_openai_messages(
messages: Union[MessageLikeRepresentation, Sequence[MessageLikeRepresentation]],
*,
text_format: Literal["string", "block"] = "string",
coerce: bool = False,
) -> Union[dict, list[dict]]:
"""Convert LangChain messages into OpenAI message dicts.
@@ -992,8 +993,14 @@ def convert_to_openai_messages(
f"but is missing expected key(s) "
f"{missing}. Full content block:\n\n{block}"
)
+    if not coerce:
+        raise ValueError(err)
-content.append({"type": block["type"], "text": block["text"]})
+    else:
+        logger.warning(err)
+        text = str({k: v for k, v in block.items() if k != "type"}) if len(block) > 1 else ""
+else:
+    text = block["text"]
+content.append({"type": block["type"], "text": text})
elif block.get("type") == "image_url":
if missing := [k for k in ("image_url",) if k not in block]:
err = (
@@ -1002,7 +1009,10 @@
f"but is missing expected key(s) "
f"{missing}. Full content block:\n\n{block}"
)
if not coerce:
raise ValueError(err)
else:
continue
content.append(
{"type": "image_url", "image_url": block["image_url"]}
)
@@ -1021,7 +1031,10 @@
f"but 'source' is missing expected key(s) "
f"{missing}. Full content block:\n\n{block}"
)
if not coerce:
raise ValueError(err)
else:
continue
content.append(
{
"type": "image_url",
@@ -1044,7 +1057,10 @@
f"but 'image' is missing expected key(s) "
f"{missing}. Full content block:\n\n{block}"
)
if not coerce:
raise ValueError(err)
else:
continue
b64_image = _bytes_to_b64_str(image["source"]["bytes"])
content.append(
{
@@ -1064,7 +1080,10 @@
f"but does not have a 'source' or 'image' key. Full "
f"content block:\n\n{block}"
)
if not coerce:
raise ValueError(err)
else:
continue
elif block.get("type") == "tool_use":
if missing := [
k for k in ("id", "name", "input") if k not in block
@@ -1075,7 +1094,10 @@
f"'tool_use', but is missing expected key(s) "
f"{missing}. Full content block:\n\n{block}"
)
if not coerce:
raise ValueError(err)
else:
continue
if not any(
tool_call["id"] == block["id"]
for tool_call in cast(AIMessage, message).tool_calls
@@ -1101,7 +1123,10 @@
f"'tool_result', but is missing expected key(s) "
f"{missing}. Full content block:\n\n{block}"
)
if not coerce:
raise ValueError(msg)
else:
continue
tool_message = ToolMessage(
block["content"],
tool_call_id=block["tool_use_id"],
@@ -1121,7 +1146,10 @@
f"but does not have a 'json' key. Full "
f"content block:\n\n{block}"
)
if not coerce:
raise ValueError(msg)
else:
continue
content.append(
{"type": "text", "text": json.dumps(block["json"])}
)
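The new coerce keyword threads through every malformed-block branch above: instead of raising, the error is logged as a warning and the offending block is skipped (or, for text blocks, best-effort stringified). A short usage sketch, assuming the flag ships as drafted here; the broken block is contrived for illustration:

    from langchain_core.messages import HumanMessage, convert_to_openai_messages

    msg = HumanMessage(
        content=[
            {"type": "text", "text": "Describe the attachment."},
            {"type": "image_url"},  # malformed: the required "image_url" key is missing
        ]
    )

    # Strict default: the malformed block raises a ValueError.
    # convert_to_openai_messages(msg)

    # Lenient mode from this diff: the error is only logged and the malformed block
    # is dropped, so the returned message dict keeps just the text block.
    openai_msg = convert_to_openai_messages(msg, coerce=True)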

libs/partners/openai/langchain_openai/chat_models/base.py

@@ -61,7 +61,7 @@ from langchain_core.messages import (
SystemMessageChunk,
ToolCall,
ToolMessage,
-    ToolMessageChunk,
+    ToolMessageChunk, convert_to_openai_messages, convert_to_messages,
)
from langchain_core.messages.ai import (
InputTokenDetails,
@@ -860,6 +860,9 @@ class BaseChatOpenAI(BaseChatModel):
None, self._create_chat_result, response, generation_info
)
def _coerce_messages(self, messages: list[BaseMessage]) -> list[BaseMessage]:
return convert_to_messages(convert_to_openai_messages(messages))
@property
def _identifying_params(self) -> Dict[str, Any]:
"""Get the identifying parameters."""