Compare commits

...

3 Commits

Author SHA1 Message Date
Bagatur
770031de82 Merge branch 'master' into bagatur/coerce_input 2024-10-28 13:09:19 -07:00
Bagatur
28b3598744 wip 2024-10-28 13:09:15 -07:00
Bagatur
a1a4e9f4a5 wip 2024-10-28 08:28:21 -07:00
3 changed files with 79 additions and 18 deletions

View File

@@ -219,6 +219,9 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
- If False (default), will always use streaming case if available.
"""
coerce_input: bool = False
""""""
@model_validator(mode="before")
@classmethod
def raise_deprecation(cls, values: dict) -> Any:
@@ -260,11 +263,11 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
def _convert_input(self, input: LanguageModelInput) -> PromptValue:
if isinstance(input, PromptValue):
return input
prompt_val = input
elif isinstance(input, str):
return StringPromptValue(text=input)
prompt_val = StringPromptValue(text=input)
elif isinstance(input, Sequence):
return ChatPromptValue(messages=convert_to_messages(input))
prompt_val = ChatPromptValue(messages=convert_to_messages(input))
else:
msg = (
f"Invalid input type {type(input)}. "
@@ -272,6 +275,15 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
)
raise ValueError(msg)
if self.coerce_input:
messages = prompt_val.to_messages()
return ChatPromptValue(messages=self._coerce_messages(messages))
else:
return prompt_val
def _coerce_messages(self, messages: list[BaseMessage]) -> list[BaseMessage]:
    """Hook for coercing messages into a provider-supported form.

    Called by ``_convert_input`` when ``coerce_input`` is True. The base
    implementation is a no-op passthrough; provider subclasses (e.g. the
    OpenAI chat model in this changeset) override it to drop or rewrite
    content blocks the provider cannot accept.

    Args:
        messages: The messages produced from the user-supplied input.

    Returns:
        The (possibly coerced) messages; unchanged here.
    """
    return messages
def invoke(
self,
input: LanguageModelInput,

View File

@@ -12,6 +12,7 @@ from __future__ import annotations
import base64
import inspect
import json
import logging
from collections.abc import Iterable, Sequence
from functools import partial
from typing import (
@@ -46,6 +47,9 @@ if TYPE_CHECKING:
from langchain_core.runnables.base import Runnable
logger = logging.getLogger(__name__)
def _get_type(v: Any) -> str:
"""Get the type associated with the object for serialization purposes."""
if isinstance(v, dict) and "type" in v:
@@ -880,6 +884,7 @@ def convert_to_openai_messages(
messages: Union[MessageLikeRepresentation, Sequence[MessageLikeRepresentation]],
*,
text_format: Literal["string", "block"] = "string",
coerce: bool = False,
) -> Union[dict, list[dict]]:
"""Convert LangChain messages into OpenAI message dicts.
@@ -992,8 +997,17 @@ def convert_to_openai_messages(
f"but is missing expected key(s) "
f"{missing}. Full content block:\n\n{block}"
)
raise ValueError(err)
content.append({"type": block["type"], "text": block["text"]})
_raise_or_warn(err, coerce)
text = (
_try_json_dumps(
{k: v for k, v in block.items() if k != "type"}
)
if len(block) > 1
else ""
)
else:
text = block["text"]
content.append({"type": block["type"], "text": text})
elif block.get("type") == "image_url":
if missing := [k for k in ("image_url",) if k not in block]:
err = (
@@ -1002,7 +1016,8 @@ def convert_to_openai_messages(
f"but is missing expected key(s) "
f"{missing}. Full content block:\n\n{block}"
)
raise ValueError(err)
_raise_or_warn(err, coerce)
continue
content.append(
{"type": "image_url", "image_url": block["image_url"]}
)
@@ -1021,7 +1036,8 @@ def convert_to_openai_messages(
f"but 'source' is missing expected key(s) "
f"{missing}. Full content block:\n\n{block}"
)
raise ValueError(err)
_raise_or_warn(err, coerce)
continue
content.append(
{
"type": "image_url",
@@ -1044,7 +1060,8 @@ def convert_to_openai_messages(
f"but 'image' is missing expected key(s) "
f"{missing}. Full content block:\n\n{block}"
)
raise ValueError(err)
_raise_or_warn(err, coerce)
continue
b64_image = _bytes_to_b64_str(image["source"]["bytes"])
content.append(
{
@@ -1064,7 +1081,8 @@ def convert_to_openai_messages(
f"but does not have a 'source' or 'image' key. Full "
f"content block:\n\n{block}"
)
raise ValueError(err)
_raise_or_warn(err, coerce)
continue
elif block.get("type") == "tool_use":
if missing := [
k for k in ("id", "name", "input") if k not in block
@@ -1075,7 +1093,8 @@ def convert_to_openai_messages(
f"'tool_use', but is missing expected key(s) "
f"{missing}. Full content block:\n\n{block}"
)
raise ValueError(err)
_raise_or_warn(err, coerce)
continue
if not any(
tool_call["id"] == block["id"]
for tool_call in cast(AIMessage, message).tool_calls
@@ -1101,7 +1120,8 @@ def convert_to_openai_messages(
f"'tool_result', but is missing expected key(s) "
f"{missing}. Full content block:\n\n{block}"
)
raise ValueError(msg)
_raise_or_warn(err, coerce)
continue
tool_message = ToolMessage(
block["content"],
tool_call_id=block["tool_use_id"],
@@ -1121,7 +1141,8 @@ def convert_to_openai_messages(
f"but does not have a 'json' key. Full "
f"content block:\n\n{block}"
)
raise ValueError(msg)
_raise_or_warn(msg, coerce)
continue
content.append(
{"type": "text", "text": json.dumps(block["json"])}
)
@@ -1139,8 +1160,16 @@ def convert_to_openai_messages(
f"messages[{i}].content[{j}]['guard_content']['text'] "
f"key. Full content block:\n\n{block}"
)
raise ValueError(msg)
text = block["guard_content"]["text"]
_raise_or_warn(msg, coerce)
text = (
_try_json_dumps(
{k: v for k, v in block.items() if k != "type"}
)
if len(block) > 1
else ""
)
else:
text = block["guard_content"]["text"]
if isinstance(text, dict):
text = text["text"]
content.append({"type": "text", "text": text})
@@ -1155,14 +1184,16 @@ def convert_to_openai_messages(
f"'media' but does not have key(s) {missing}. Full "
f"content block:\n\n{block}"
)
raise ValueError(err)
_raise_or_warn(err, coerce)
continue
if "image" not in block["mime_type"]:
err = (
f"OpenAI messages can only support text and image data."
f" Received content block with media of type:"
f" {block['mime_type']}"
)
raise ValueError(err)
_raise_or_warn(err, coerce)
continue
b64_image = _bytes_to_b64_str(block["data"])
content.append(
{
@@ -1181,7 +1212,8 @@ def convert_to_openai_messages(
f"Anthropic, Bedrock Converse, or VertexAI format. Full "
f"content block:\n\n{block}"
)
raise ValueError(err)
_raise_or_warn(err, coerce)
continue
if text_format == "string" and not any(
block["type"] != "text" for block in content
):
@@ -1404,3 +1436,17 @@ def _convert_to_openai_tool_calls(tool_calls: list[ToolCall]) -> list[dict]:
}
for tool_call in tool_calls
]
def _try_json_dumps(o: Any) -> str:
try:
return json.dumps(o)
except Exception:
return str(o)
def _raise_or_warn(msg: str, silence: bool) -> None:
if not silence:
raise ValueError(msg)
else:
logger.warning(msg)

View File

@@ -61,7 +61,7 @@ from langchain_core.messages import (
SystemMessageChunk,
ToolCall,
ToolMessage,
ToolMessageChunk,
ToolMessageChunk, convert_to_openai_messages, convert_to_messages,
)
from langchain_core.messages.ai import (
InputTokenDetails,
@@ -860,6 +860,9 @@ class BaseChatOpenAI(BaseChatModel):
None, self._create_chat_result, response, generation_info
)
def _coerce_messages(self, messages: list[BaseMessage]) -> list[BaseMessage]:
    """Coerce messages into an OpenAI-compatible form.

    Round-trips through the OpenAI dict format with ``coerce=True`` so that
    unsupported content blocks are rewritten (warned about) rather than
    raising, then converts back to LangChain messages.
    """
    as_openai_dicts = convert_to_openai_messages(messages, coerce=True)
    return convert_to_messages(as_openai_dicts)
@property
def _identifying_params(self) -> Dict[str, Any]:
"""Get the identifying parameters."""