mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-05 04:38:26 +00:00
wip
This commit is contained in:
parent
440c162b8b
commit
a1a4e9f4a5
@ -219,6 +219,9 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
|||||||
- If False (default), will always use streaming case if available.
|
- If False (default), will always use streaming case if available.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
coerce_input: bool = False
|
||||||
|
""""""
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def raise_deprecation(cls, values: dict) -> Any:
|
def raise_deprecation(cls, values: dict) -> Any:
|
||||||
@ -260,11 +263,11 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
|||||||
|
|
||||||
def _convert_input(self, input: LanguageModelInput) -> PromptValue:
|
def _convert_input(self, input: LanguageModelInput) -> PromptValue:
|
||||||
if isinstance(input, PromptValue):
|
if isinstance(input, PromptValue):
|
||||||
return input
|
prompt_val = input
|
||||||
elif isinstance(input, str):
|
elif isinstance(input, str):
|
||||||
return StringPromptValue(text=input)
|
prompt_val = StringPromptValue(text=input)
|
||||||
elif isinstance(input, Sequence):
|
elif isinstance(input, Sequence):
|
||||||
return ChatPromptValue(messages=convert_to_messages(input))
|
prompt_val = ChatPromptValue(messages=convert_to_messages(input))
|
||||||
else:
|
else:
|
||||||
msg = (
|
msg = (
|
||||||
f"Invalid input type {type(input)}. "
|
f"Invalid input type {type(input)}. "
|
||||||
@ -272,6 +275,15 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
|||||||
)
|
)
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
|
|
||||||
|
if self.coerce_input:
|
||||||
|
messages = prompt_val.to_messages()
|
||||||
|
return ChatPromptValue(messages=self._coerce_messages(messages))
|
||||||
|
else:
|
||||||
|
return prompt_val
|
||||||
|
|
||||||
|
def _coerce_messages(self, messages: list[BaseMessage]) -> list[BaseMessage]:
|
||||||
|
return messages
|
||||||
|
|
||||||
def invoke(
|
def invoke(
|
||||||
self,
|
self,
|
||||||
input: LanguageModelInput,
|
input: LanguageModelInput,
|
||||||
|
@ -880,6 +880,7 @@ def convert_to_openai_messages(
|
|||||||
messages: Union[MessageLikeRepresentation, Sequence[MessageLikeRepresentation]],
|
messages: Union[MessageLikeRepresentation, Sequence[MessageLikeRepresentation]],
|
||||||
*,
|
*,
|
||||||
text_format: Literal["string", "block"] = "string",
|
text_format: Literal["string", "block"] = "string",
|
||||||
|
coerce: bool = False,
|
||||||
) -> Union[dict, list[dict]]:
|
) -> Union[dict, list[dict]]:
|
||||||
"""Convert LangChain messages into OpenAI message dicts.
|
"""Convert LangChain messages into OpenAI message dicts.
|
||||||
|
|
||||||
@ -992,8 +993,14 @@ def convert_to_openai_messages(
|
|||||||
f"but is missing expected key(s) "
|
f"but is missing expected key(s) "
|
||||||
f"{missing}. Full content block:\n\n{block}"
|
f"{missing}. Full content block:\n\n{block}"
|
||||||
)
|
)
|
||||||
raise ValueError(err)
|
if not coerce:
|
||||||
content.append({"type": block["type"], "text": block["text"]})
|
raise ValueError(err)
|
||||||
|
else:
|
||||||
|
logger.warning(err)
|
||||||
|
text = str({k: v for k, v in block.items() if k != "type"}) if len(block) > 1 else ""
|
||||||
|
else:
|
||||||
|
text = block["text"]
|
||||||
|
content.append({"type": block["type"], "text": text})
|
||||||
elif block.get("type") == "image_url":
|
elif block.get("type") == "image_url":
|
||||||
if missing := [k for k in ("image_url",) if k not in block]:
|
if missing := [k for k in ("image_url",) if k not in block]:
|
||||||
err = (
|
err = (
|
||||||
@ -1002,7 +1009,10 @@ def convert_to_openai_messages(
|
|||||||
f"but is missing expected key(s) "
|
f"but is missing expected key(s) "
|
||||||
f"{missing}. Full content block:\n\n{block}"
|
f"{missing}. Full content block:\n\n{block}"
|
||||||
)
|
)
|
||||||
raise ValueError(err)
|
if not coerce:
|
||||||
|
raise ValueError(err)
|
||||||
|
else:
|
||||||
|
continue
|
||||||
content.append(
|
content.append(
|
||||||
{"type": "image_url", "image_url": block["image_url"]}
|
{"type": "image_url", "image_url": block["image_url"]}
|
||||||
)
|
)
|
||||||
@ -1021,7 +1031,10 @@ def convert_to_openai_messages(
|
|||||||
f"but 'source' is missing expected key(s) "
|
f"but 'source' is missing expected key(s) "
|
||||||
f"{missing}. Full content block:\n\n{block}"
|
f"{missing}. Full content block:\n\n{block}"
|
||||||
)
|
)
|
||||||
raise ValueError(err)
|
if not coerce:
|
||||||
|
raise ValueError(err)
|
||||||
|
else:
|
||||||
|
continue
|
||||||
content.append(
|
content.append(
|
||||||
{
|
{
|
||||||
"type": "image_url",
|
"type": "image_url",
|
||||||
@ -1044,7 +1057,10 @@ def convert_to_openai_messages(
|
|||||||
f"but 'image' is missing expected key(s) "
|
f"but 'image' is missing expected key(s) "
|
||||||
f"{missing}. Full content block:\n\n{block}"
|
f"{missing}. Full content block:\n\n{block}"
|
||||||
)
|
)
|
||||||
raise ValueError(err)
|
if not coerce:
|
||||||
|
raise ValueError(err)
|
||||||
|
else:
|
||||||
|
continue
|
||||||
b64_image = _bytes_to_b64_str(image["source"]["bytes"])
|
b64_image = _bytes_to_b64_str(image["source"]["bytes"])
|
||||||
content.append(
|
content.append(
|
||||||
{
|
{
|
||||||
@ -1064,7 +1080,10 @@ def convert_to_openai_messages(
|
|||||||
f"but does not have a 'source' or 'image' key. Full "
|
f"but does not have a 'source' or 'image' key. Full "
|
||||||
f"content block:\n\n{block}"
|
f"content block:\n\n{block}"
|
||||||
)
|
)
|
||||||
raise ValueError(err)
|
if not coerce:
|
||||||
|
raise ValueError(err)
|
||||||
|
else:
|
||||||
|
continue
|
||||||
elif block.get("type") == "tool_use":
|
elif block.get("type") == "tool_use":
|
||||||
if missing := [
|
if missing := [
|
||||||
k for k in ("id", "name", "input") if k not in block
|
k for k in ("id", "name", "input") if k not in block
|
||||||
@ -1075,7 +1094,10 @@ def convert_to_openai_messages(
|
|||||||
f"'tool_use', but is missing expected key(s) "
|
f"'tool_use', but is missing expected key(s) "
|
||||||
f"{missing}. Full content block:\n\n{block}"
|
f"{missing}. Full content block:\n\n{block}"
|
||||||
)
|
)
|
||||||
raise ValueError(err)
|
if not coerce:
|
||||||
|
raise ValueError(err)
|
||||||
|
else:
|
||||||
|
continue
|
||||||
if not any(
|
if not any(
|
||||||
tool_call["id"] == block["id"]
|
tool_call["id"] == block["id"]
|
||||||
for tool_call in cast(AIMessage, message).tool_calls
|
for tool_call in cast(AIMessage, message).tool_calls
|
||||||
@ -1101,7 +1123,10 @@ def convert_to_openai_messages(
|
|||||||
f"'tool_result', but is missing expected key(s) "
|
f"'tool_result', but is missing expected key(s) "
|
||||||
f"{missing}. Full content block:\n\n{block}"
|
f"{missing}. Full content block:\n\n{block}"
|
||||||
)
|
)
|
||||||
raise ValueError(msg)
|
if not coerce:
|
||||||
|
raise ValueError(msg)
|
||||||
|
else:
|
||||||
|
continue
|
||||||
tool_message = ToolMessage(
|
tool_message = ToolMessage(
|
||||||
block["content"],
|
block["content"],
|
||||||
tool_call_id=block["tool_use_id"],
|
tool_call_id=block["tool_use_id"],
|
||||||
@ -1121,7 +1146,10 @@ def convert_to_openai_messages(
|
|||||||
f"but does not have a 'json' key. Full "
|
f"but does not have a 'json' key. Full "
|
||||||
f"content block:\n\n{block}"
|
f"content block:\n\n{block}"
|
||||||
)
|
)
|
||||||
raise ValueError(msg)
|
if not coerce:
|
||||||
|
raise ValueError(msg)
|
||||||
|
else:
|
||||||
|
continue
|
||||||
content.append(
|
content.append(
|
||||||
{"type": "text", "text": json.dumps(block["json"])}
|
{"type": "text", "text": json.dumps(block["json"])}
|
||||||
)
|
)
|
||||||
|
@ -61,7 +61,7 @@ from langchain_core.messages import (
|
|||||||
SystemMessageChunk,
|
SystemMessageChunk,
|
||||||
ToolCall,
|
ToolCall,
|
||||||
ToolMessage,
|
ToolMessage,
|
||||||
ToolMessageChunk,
|
ToolMessageChunk, convert_to_openai_messages, convert_to_messages,
|
||||||
)
|
)
|
||||||
from langchain_core.messages.ai import (
|
from langchain_core.messages.ai import (
|
||||||
InputTokenDetails,
|
InputTokenDetails,
|
||||||
@ -860,6 +860,9 @@ class BaseChatOpenAI(BaseChatModel):
|
|||||||
None, self._create_chat_result, response, generation_info
|
None, self._create_chat_result, response, generation_info
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _coerce_messages(self, messages: list[BaseMessage]) -> list[BaseMessage]:
|
||||||
|
return convert_to_messages(convert_to_openai_messages(messages))
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _identifying_params(self) -> Dict[str, Any]:
|
def _identifying_params(self) -> Dict[str, Any]:
|
||||||
"""Get the identifying parameters."""
|
"""Get the identifying parameters."""
|
||||||
|
Loading…
Reference in New Issue
Block a user