mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-04 12:18:24 +00:00
wip
This commit is contained in:
parent
440c162b8b
commit
a1a4e9f4a5
@ -219,6 +219,9 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
- If False (default), will always use streaming case if available.
|
||||
"""
|
||||
|
||||
coerce_input: bool = False
|
||||
""""""
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def raise_deprecation(cls, values: dict) -> Any:
|
||||
@ -260,11 +263,11 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
|
||||
def _convert_input(self, input: LanguageModelInput) -> PromptValue:
|
||||
if isinstance(input, PromptValue):
|
||||
return input
|
||||
prompt_val = input
|
||||
elif isinstance(input, str):
|
||||
return StringPromptValue(text=input)
|
||||
prompt_val = StringPromptValue(text=input)
|
||||
elif isinstance(input, Sequence):
|
||||
return ChatPromptValue(messages=convert_to_messages(input))
|
||||
prompt_val = ChatPromptValue(messages=convert_to_messages(input))
|
||||
else:
|
||||
msg = (
|
||||
f"Invalid input type {type(input)}. "
|
||||
@ -272,6 +275,15 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
)
|
||||
raise ValueError(msg)
|
||||
|
||||
if self.coerce_input:
|
||||
messages = prompt_val.to_messages()
|
||||
return ChatPromptValue(messages=self._coerce_messages(messages))
|
||||
else:
|
||||
return prompt_val
|
||||
|
||||
def _coerce_messages(self, messages: list[BaseMessage]) -> list[BaseMessage]:
|
||||
return messages
|
||||
|
||||
def invoke(
|
||||
self,
|
||||
input: LanguageModelInput,
|
||||
|
@ -880,6 +880,7 @@ def convert_to_openai_messages(
|
||||
messages: Union[MessageLikeRepresentation, Sequence[MessageLikeRepresentation]],
|
||||
*,
|
||||
text_format: Literal["string", "block"] = "string",
|
||||
coerce: bool = False,
|
||||
) -> Union[dict, list[dict]]:
|
||||
"""Convert LangChain messages into OpenAI message dicts.
|
||||
|
||||
@ -992,8 +993,14 @@ def convert_to_openai_messages(
|
||||
f"but is missing expected key(s) "
|
||||
f"{missing}. Full content block:\n\n{block}"
|
||||
)
|
||||
if not coerce:
|
||||
raise ValueError(err)
|
||||
content.append({"type": block["type"], "text": block["text"]})
|
||||
else:
|
||||
logger.warning(err)
|
||||
text = str({k: v for k, v in block.items() if k != "type"}) if len(block) > 1 else ""
|
||||
else:
|
||||
text = block["text"]
|
||||
content.append({"type": block["type"], "text": text})
|
||||
elif block.get("type") == "image_url":
|
||||
if missing := [k for k in ("image_url",) if k not in block]:
|
||||
err = (
|
||||
@ -1002,7 +1009,10 @@ def convert_to_openai_messages(
|
||||
f"but is missing expected key(s) "
|
||||
f"{missing}. Full content block:\n\n{block}"
|
||||
)
|
||||
if not coerce:
|
||||
raise ValueError(err)
|
||||
else:
|
||||
continue
|
||||
content.append(
|
||||
{"type": "image_url", "image_url": block["image_url"]}
|
||||
)
|
||||
@ -1021,7 +1031,10 @@ def convert_to_openai_messages(
|
||||
f"but 'source' is missing expected key(s) "
|
||||
f"{missing}. Full content block:\n\n{block}"
|
||||
)
|
||||
if not coerce:
|
||||
raise ValueError(err)
|
||||
else:
|
||||
continue
|
||||
content.append(
|
||||
{
|
||||
"type": "image_url",
|
||||
@ -1044,7 +1057,10 @@ def convert_to_openai_messages(
|
||||
f"but 'image' is missing expected key(s) "
|
||||
f"{missing}. Full content block:\n\n{block}"
|
||||
)
|
||||
if not coerce:
|
||||
raise ValueError(err)
|
||||
else:
|
||||
continue
|
||||
b64_image = _bytes_to_b64_str(image["source"]["bytes"])
|
||||
content.append(
|
||||
{
|
||||
@ -1064,7 +1080,10 @@ def convert_to_openai_messages(
|
||||
f"but does not have a 'source' or 'image' key. Full "
|
||||
f"content block:\n\n{block}"
|
||||
)
|
||||
if not coerce:
|
||||
raise ValueError(err)
|
||||
else:
|
||||
continue
|
||||
elif block.get("type") == "tool_use":
|
||||
if missing := [
|
||||
k for k in ("id", "name", "input") if k not in block
|
||||
@ -1075,7 +1094,10 @@ def convert_to_openai_messages(
|
||||
f"'tool_use', but is missing expected key(s) "
|
||||
f"{missing}. Full content block:\n\n{block}"
|
||||
)
|
||||
if not coerce:
|
||||
raise ValueError(err)
|
||||
else:
|
||||
continue
|
||||
if not any(
|
||||
tool_call["id"] == block["id"]
|
||||
for tool_call in cast(AIMessage, message).tool_calls
|
||||
@ -1101,7 +1123,10 @@ def convert_to_openai_messages(
|
||||
f"'tool_result', but is missing expected key(s) "
|
||||
f"{missing}. Full content block:\n\n{block}"
|
||||
)
|
||||
if not coerce:
|
||||
raise ValueError(msg)
|
||||
else:
|
||||
continue
|
||||
tool_message = ToolMessage(
|
||||
block["content"],
|
||||
tool_call_id=block["tool_use_id"],
|
||||
@ -1121,7 +1146,10 @@ def convert_to_openai_messages(
|
||||
f"but does not have a 'json' key. Full "
|
||||
f"content block:\n\n{block}"
|
||||
)
|
||||
if not coerce:
|
||||
raise ValueError(msg)
|
||||
else:
|
||||
continue
|
||||
content.append(
|
||||
{"type": "text", "text": json.dumps(block["json"])}
|
||||
)
|
||||
|
@ -61,7 +61,7 @@ from langchain_core.messages import (
|
||||
SystemMessageChunk,
|
||||
ToolCall,
|
||||
ToolMessage,
|
||||
ToolMessageChunk,
|
||||
ToolMessageChunk, convert_to_openai_messages, convert_to_messages,
|
||||
)
|
||||
from langchain_core.messages.ai import (
|
||||
InputTokenDetails,
|
||||
@ -860,6 +860,9 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
None, self._create_chat_result, response, generation_info
|
||||
)
|
||||
|
||||
def _coerce_messages(self, messages: list[BaseMessage]) -> list[BaseMessage]:
|
||||
return convert_to_messages(convert_to_openai_messages(messages))
|
||||
|
||||
@property
|
||||
def _identifying_params(self) -> Dict[str, Any]:
|
||||
"""Get the identifying parameters."""
|
||||
|
Loading…
Reference in New Issue
Block a user