mirror of
https://github.com/hwchase17/langchain.git
synced 2026-06-09 10:17:00 +00:00
core[minor], ...: add tool calls message (#18947)
core[minor], langchain[patch], openai[minor], anthropic[minor], fireworks[minor], groq[minor], mistralai[minor]
```python
class ToolCall(TypedDict):
name: str
args: Dict[str, Any]
id: Optional[str]
class InvalidToolCall(TypedDict):
name: Optional[str]
args: Optional[str]
id: Optional[str]
error: Optional[str]
class ToolCallChunk(TypedDict):
name: Optional[str]
args: Optional[str]
id: Optional[str]
index: Optional[int]
class AIMessage(BaseMessage):
...
tool_calls: List[ToolCall] = []
invalid_tool_calls: List[InvalidToolCall] = []
...
class AIMessageChunk(AIMessage, BaseMessageChunk):
...
tool_call_chunks: Optional[List[ToolCallChunk]] = None
...
```
Important considerations:
- Parsing logic occurs within different providers;
- ~Changing output type is a breaking change for anyone doing explicit
type checking;~
- ~Langsmith rendering will need to be updated:
https://github.com/langchain-ai/langchainplus/pull/3561~
- ~Langserve will need to be updated~
- Adding chunks:
- ~AIMessage + ToolCallsMessage = ToolCallsMessage if either has
non-null .tool_calls.~
- Tool call chunks are appended, merging when having equal values of
`index`.
- additional_kwargs accumulate the normal way.
- During streaming:
- ~Messages can change types (e.g., from AIMessageChunk to
AIToolCallsMessageChunk)~
- Output parsers parse additional_kwargs (during .invoke they read off
tool calls).
Packages outside of `partners/`:
- https://github.com/langchain-ai/langchain-cohere/pull/7
- https://github.com/langchain-ai/langchain-google/pull/123/files
---------
Co-authored-by: Chester Curme <chester.curme@gmail.com>
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import warnings
|
||||
@@ -54,7 +55,7 @@ from langchain_core.utils import (
|
||||
)
|
||||
from langchain_core.utils.function_calling import convert_to_openai_tool
|
||||
|
||||
from langchain_anthropic.output_parsers import ToolsOutputParser
|
||||
from langchain_anthropic.output_parsers import ToolsOutputParser, extract_tool_calls
|
||||
|
||||
_message_type_lookups = {
|
||||
"human": "user",
|
||||
@@ -347,7 +348,24 @@ class ChatAnthropic(BaseChatModel):
|
||||
result = self._generate(
|
||||
messages, stop=stop, run_manager=run_manager, **kwargs
|
||||
)
|
||||
yield cast(ChatGenerationChunk, result.generations[0])
|
||||
message = result.generations[0].message
|
||||
if isinstance(message, AIMessage) and message.tool_calls is not None:
|
||||
tool_call_chunks = [
|
||||
{
|
||||
"name": tool_call["name"],
|
||||
"args": json.dumps(tool_call["args"]),
|
||||
"id": tool_call["id"],
|
||||
"index": idx,
|
||||
}
|
||||
for idx, tool_call in enumerate(message.tool_calls)
|
||||
]
|
||||
message_chunk = AIMessageChunk(
|
||||
content=message.content,
|
||||
tool_call_chunks=tool_call_chunks,
|
||||
)
|
||||
yield ChatGenerationChunk(message=message_chunk)
|
||||
else:
|
||||
yield cast(ChatGenerationChunk, result.generations[0])
|
||||
return
|
||||
with self._client.messages.stream(**params) as stream:
|
||||
for text in stream.text_stream:
|
||||
@@ -369,7 +387,24 @@ class ChatAnthropic(BaseChatModel):
|
||||
result = await self._agenerate(
|
||||
messages, stop=stop, run_manager=run_manager, **kwargs
|
||||
)
|
||||
yield cast(ChatGenerationChunk, result.generations[0])
|
||||
message = result.generations[0].message
|
||||
if isinstance(message, AIMessage) and message.tool_calls is not None:
|
||||
tool_call_chunks = [
|
||||
{
|
||||
"name": tool_call["name"],
|
||||
"args": json.dumps(tool_call["args"]),
|
||||
"id": tool_call["id"],
|
||||
"index": idx,
|
||||
}
|
||||
for idx, tool_call in enumerate(message.tool_calls)
|
||||
]
|
||||
message_chunk = AIMessageChunk(
|
||||
content=message.content,
|
||||
tool_call_chunks=tool_call_chunks,
|
||||
)
|
||||
yield ChatGenerationChunk(message=message_chunk)
|
||||
else:
|
||||
yield cast(ChatGenerationChunk, result.generations[0])
|
||||
return
|
||||
async with self._async_client.messages.stream(**params) as stream:
|
||||
async for text in stream.text_stream:
|
||||
@@ -386,6 +421,12 @@ class ChatAnthropic(BaseChatModel):
|
||||
}
|
||||
if len(content) == 1 and content[0]["type"] == "text":
|
||||
msg = AIMessage(content=content[0]["text"])
|
||||
elif any(block["type"] == "tool_use" for block in content):
|
||||
tool_calls = extract_tool_calls(content)
|
||||
msg = AIMessage(
|
||||
content=content,
|
||||
tool_calls=tool_calls,
|
||||
)
|
||||
else:
|
||||
msg = AIMessage(content=content)
|
||||
return ChatResult(
|
||||
|
||||
@@ -1,18 +1,11 @@
|
||||
from typing import Any, List, Optional, Type, TypedDict, cast
|
||||
from typing import Any, List, Optional, Type
|
||||
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langchain_core.messages import ToolCall
|
||||
from langchain_core.output_parsers import BaseGenerationOutputParser
|
||||
from langchain_core.outputs import ChatGeneration, Generation
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
|
||||
|
||||
class _ToolCall(TypedDict):
|
||||
name: str
|
||||
args: dict
|
||||
id: str
|
||||
index: int
|
||||
|
||||
|
||||
class ToolsOutputParser(BaseGenerationOutputParser):
|
||||
first_tool_only: bool = False
|
||||
args_only: bool = False
|
||||
@@ -33,7 +26,19 @@ class ToolsOutputParser(BaseGenerationOutputParser):
|
||||
"""
|
||||
if not result or not isinstance(result[0], ChatGeneration):
|
||||
return None if self.first_tool_only else []
|
||||
tool_calls: List = _extract_tool_calls(result[0].message)
|
||||
message = result[0].message
|
||||
if isinstance(message.content, str):
|
||||
tool_calls: List = []
|
||||
else:
|
||||
content: List = message.content
|
||||
_tool_calls = [dict(tc) for tc in extract_tool_calls(content)]
|
||||
# Map tool call id to index
|
||||
id_to_index = {
|
||||
block["id"]: i
|
||||
for i, block in enumerate(content)
|
||||
if block["type"] == "tool_use"
|
||||
}
|
||||
tool_calls = [{**tc, "index": id_to_index[tc["id"]]} for tc in _tool_calls]
|
||||
if self.pydantic_schemas:
|
||||
tool_calls = [self._pydantic_parse(tc) for tc in tool_calls]
|
||||
elif self.args_only:
|
||||
@@ -44,23 +49,21 @@ class ToolsOutputParser(BaseGenerationOutputParser):
|
||||
if self.first_tool_only:
|
||||
return tool_calls[0] if tool_calls else None
|
||||
else:
|
||||
return tool_calls
|
||||
return [tool_call for tool_call in tool_calls]
|
||||
|
||||
def _pydantic_parse(self, tool_call: _ToolCall) -> BaseModel:
|
||||
def _pydantic_parse(self, tool_call: dict) -> BaseModel:
|
||||
cls_ = {schema.__name__: schema for schema in self.pydantic_schemas or []}[
|
||||
tool_call["name"]
|
||||
]
|
||||
return cls_(**tool_call["args"])
|
||||
|
||||
|
||||
def _extract_tool_calls(msg: BaseMessage) -> List[_ToolCall]:
|
||||
if isinstance(msg.content, str):
|
||||
return []
|
||||
def extract_tool_calls(content: List[dict]) -> List[ToolCall]:
|
||||
tool_calls = []
|
||||
for i, block in enumerate(cast(List[dict], msg.content)):
|
||||
for block in content:
|
||||
if block["type"] != "tool_use":
|
||||
continue
|
||||
tool_calls.append(
|
||||
_ToolCall(name=block["name"], args=block["input"], id=block["id"], index=i)
|
||||
ToolCall(name=block["name"], args=block["input"], id=block["id"])
|
||||
)
|
||||
return tool_calls
|
||||
|
||||
Reference in New Issue
Block a user