core[minor], ...: add tool calls message (#18947)

core[minor], langchain[patch], openai[minor], anthropic[minor], fireworks[minor], groq[minor], mistralai[minor]

```python
class ToolCall(TypedDict):
    name: str
    args: Dict[str, Any]
    id: Optional[str]

class InvalidToolCall(TypedDict):
    name: Optional[str]
    args: Optional[str]
    id: Optional[str]
    error: Optional[str]

class ToolCallChunk(TypedDict):
    name: Optional[str]
    args: Optional[str]
    id: Optional[str]
    index: Optional[int]


class AIMessage(BaseMessage):
    ...
    tool_calls: List[ToolCall] = []
    invalid_tool_calls: List[InvalidToolCall] = []
    ...


class AIMessageChunk(AIMessage, BaseMessageChunk):
    ...
    tool_call_chunks: Optional[List[ToolCallChunk]] = None
    ...
```
Important considerations:
- Parsing logic occurs within different providers;
- ~Changing output type is a breaking change for anyone doing explicit
type checking;~
- ~Langsmith rendering will need to be updated:
https://github.com/langchain-ai/langchainplus/pull/3561~
- ~Langserve will need to be updated~
- Adding chunks:
- ~AIMessage + ToolCallsMessage = ToolCallsMessage if either has
non-null .tool_calls.~
- Tool call chunks are appended, merging when having equal values of
`index`.
  - additional_kwargs accumulate the normal way.
- During streaming:
- ~Messages can change types (e.g., from AIMessageChunk to
AIToolCallsMessageChunk)~
- Output parsers parse additional_kwargs (during .invoke they read off
tool calls).

Packages outside of `partners/`:
- https://github.com/langchain-ai/langchain-cohere/pull/7
- https://github.com/langchain-ai/langchain-google/pull/123/files

---------

Co-authored-by: Chester Curme <chester.curme@gmail.com>
This commit is contained in:
Bagatur
2024-04-09 18:41:42 -05:00
committed by GitHub
parent 00552918ac
commit 9514bc4d67
31 changed files with 2347 additions and 389 deletions

View File

@@ -49,6 +49,8 @@ from langchain_core.output_parsers.base import OutputParserLike
from langchain_core.output_parsers.openai_tools import (
JsonOutputKeyToolsParser,
PydanticToolsParser,
make_invalid_tool_call,
parse_tool_call,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
@@ -82,9 +84,31 @@ def _convert_mistral_chat_message_to_message(
content = cast(str, _message["content"])
additional_kwargs: Dict = {}
if tool_calls := _message.get("tool_calls"):
additional_kwargs["tool_calls"] = tool_calls
return AIMessage(content=content, additional_kwargs=additional_kwargs)
tool_calls = []
invalid_tool_calls = []
if raw_tool_calls := _message.get("tool_calls"):
additional_kwargs["tool_calls"] = raw_tool_calls
for raw_tool_call in raw_tool_calls:
try:
parsed: dict = cast(
dict, parse_tool_call(raw_tool_call, return_id=False)
)
tool_calls.append(
{
**parsed,
**{"id": None},
},
)
except Exception as e:
invalid_tool_calls.append(
dict(make_invalid_tool_call(raw_tool_call, str(e)))
)
return AIMessage(
content=content,
additional_kwargs=additional_kwargs,
tool_calls=tool_calls,
invalid_tool_calls=invalid_tool_calls,
)
async def _aiter_sse(
@@ -133,9 +157,27 @@ def _convert_delta_to_message_chunk(
return HumanMessageChunk(content=content)
elif role == "assistant" or default_class == AIMessageChunk:
additional_kwargs: Dict = {}
if tool_calls := _delta.get("tool_calls"):
additional_kwargs["tool_calls"] = tool_calls
return AIMessageChunk(content=content, additional_kwargs=additional_kwargs)
if raw_tool_calls := _delta.get("tool_calls"):
additional_kwargs["tool_calls"] = raw_tool_calls
try:
tool_call_chunks = [
{
"name": rtc["function"].get("name"),
"args": rtc["function"].get("arguments"),
"id": rtc.get("id"),
"index": rtc.get("index"),
}
for rtc in raw_tool_calls
]
except KeyError:
pass
else:
tool_call_chunks = []
return AIMessageChunk(
content=content,
additional_kwargs=additional_kwargs,
tool_call_chunks=tool_call_chunks,
)
elif role == "system" or default_class == SystemMessageChunk:
return SystemMessageChunk(content=content)
elif role or default_class == ChatMessageChunk:
@@ -163,7 +205,7 @@ def _convert_message_to_mistral_chat_message(
for tc in message.additional_kwargs["tool_calls"]
]
else:
tool_calls = None
tool_calls = []
return {
"role": "assistant",
"content": message.content,