mirror of
https://github.com/hwchase17/langchain.git
synced 2026-06-09 10:17:00 +00:00
core[minor], ...: add tool calls message (#18947)
core[minor], langchain[patch], openai[minor], anthropic[minor], fireworks[minor], groq[minor], mistralai[minor]
```python
class ToolCall(TypedDict):
name: str
args: Dict[str, Any]
id: Optional[str]
class InvalidToolCall(TypedDict):
name: Optional[str]
args: Optional[str]
id: Optional[str]
error: Optional[str]
class ToolCallChunk(TypedDict):
name: Optional[str]
args: Optional[str]
id: Optional[str]
index: Optional[int]
class AIMessage(BaseMessage):
...
tool_calls: List[ToolCall] = []
invalid_tool_calls: List[InvalidToolCall] = []
...
class AIMessageChunk(AIMessage, BaseMessageChunk):
...
tool_call_chunks: Optional[List[ToolCallChunk]] = None
...
```
Important considerations:
- Parsing logic occurs within different providers;
- ~Changing output type is a breaking change for anyone doing explicit
type checking;~
- ~Langsmith rendering will need to be updated:
https://github.com/langchain-ai/langchainplus/pull/3561~
- ~Langserve will need to be updated~
- Adding chunks:
- ~AIMessage + ToolCallsMessage = ToolCallsMessage if either has
non-null .tool_calls.~
- Tool call chunks are appended, merging when having equal values of
`index`.
- additional_kwargs accumulate the normal way.
- During streaming:
- ~Messages can change types (e.g., from AIMessageChunk to
AIToolCallsMessageChunk)~
- Output parsers parse additional_kwargs (during .invoke they read off
tool calls).
Packages outside of `partners/`:
- https://github.com/langchain-ai/langchain-cohere/pull/7
- https://github.com/langchain-ai/langchain-google/pull/123/files
---------
Co-authored-by: Chester Curme <chester.curme@gmail.com>
This commit is contained in:
@@ -49,6 +49,8 @@ from langchain_core.output_parsers.base import OutputParserLike
|
||||
from langchain_core.output_parsers.openai_tools import (
|
||||
JsonOutputKeyToolsParser,
|
||||
PydanticToolsParser,
|
||||
make_invalid_tool_call,
|
||||
parse_tool_call,
|
||||
)
|
||||
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
|
||||
@@ -82,9 +84,31 @@ def _convert_mistral_chat_message_to_message(
|
||||
content = cast(str, _message["content"])
|
||||
|
||||
additional_kwargs: Dict = {}
|
||||
if tool_calls := _message.get("tool_calls"):
|
||||
additional_kwargs["tool_calls"] = tool_calls
|
||||
return AIMessage(content=content, additional_kwargs=additional_kwargs)
|
||||
tool_calls = []
|
||||
invalid_tool_calls = []
|
||||
if raw_tool_calls := _message.get("tool_calls"):
|
||||
additional_kwargs["tool_calls"] = raw_tool_calls
|
||||
for raw_tool_call in raw_tool_calls:
|
||||
try:
|
||||
parsed: dict = cast(
|
||||
dict, parse_tool_call(raw_tool_call, return_id=False)
|
||||
)
|
||||
tool_calls.append(
|
||||
{
|
||||
**parsed,
|
||||
**{"id": None},
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
invalid_tool_calls.append(
|
||||
dict(make_invalid_tool_call(raw_tool_call, str(e)))
|
||||
)
|
||||
return AIMessage(
|
||||
content=content,
|
||||
additional_kwargs=additional_kwargs,
|
||||
tool_calls=tool_calls,
|
||||
invalid_tool_calls=invalid_tool_calls,
|
||||
)
|
||||
|
||||
|
||||
async def _aiter_sse(
|
||||
@@ -133,9 +157,27 @@ def _convert_delta_to_message_chunk(
|
||||
return HumanMessageChunk(content=content)
|
||||
elif role == "assistant" or default_class == AIMessageChunk:
|
||||
additional_kwargs: Dict = {}
|
||||
if tool_calls := _delta.get("tool_calls"):
|
||||
additional_kwargs["tool_calls"] = tool_calls
|
||||
return AIMessageChunk(content=content, additional_kwargs=additional_kwargs)
|
||||
if raw_tool_calls := _delta.get("tool_calls"):
|
||||
additional_kwargs["tool_calls"] = raw_tool_calls
|
||||
try:
|
||||
tool_call_chunks = [
|
||||
{
|
||||
"name": rtc["function"].get("name"),
|
||||
"args": rtc["function"].get("arguments"),
|
||||
"id": rtc.get("id"),
|
||||
"index": rtc.get("index"),
|
||||
}
|
||||
for rtc in raw_tool_calls
|
||||
]
|
||||
except KeyError:
|
||||
pass
|
||||
else:
|
||||
tool_call_chunks = []
|
||||
return AIMessageChunk(
|
||||
content=content,
|
||||
additional_kwargs=additional_kwargs,
|
||||
tool_call_chunks=tool_call_chunks,
|
||||
)
|
||||
elif role == "system" or default_class == SystemMessageChunk:
|
||||
return SystemMessageChunk(content=content)
|
||||
elif role or default_class == ChatMessageChunk:
|
||||
@@ -163,7 +205,7 @@ def _convert_message_to_mistral_chat_message(
|
||||
for tc in message.additional_kwargs["tool_calls"]
|
||||
]
|
||||
else:
|
||||
tool_calls = None
|
||||
tool_calls = []
|
||||
return {
|
||||
"role": "assistant",
|
||||
"content": message.content,
|
||||
|
||||
Reference in New Issue
Block a user