community[minor]: Add tools calls to ChatEdenAI (#22320)

### Description  
Add tools implementation to `ChatEdenAI`:
- `bind_tools()`
- `with_structured_output()`

### Documentation 
Updated `docs/docs/integrations/chat/edenai.ipynb`

### Notes
We don´t support stream with tools as of yet. If stream is called with
tools we directly yield the whole message from `generate` (implemented
the same way as Anthropic did).
This commit is contained in:
KyrianC
2024-06-04 19:29:28 +02:00
committed by GitHub
parent 9d4350e69a
commit 03178ee74f
4 changed files with 518 additions and 19 deletions

View File

@@ -2,9 +2,15 @@
from typing import List
import pytest
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
from langchain_core.messages import (
BaseMessage,
HumanMessage,
SystemMessage,
ToolMessage,
)
from langchain_community.chat_models.edenai import (
_extract_edenai_tool_results_from_messages,
_format_edenai_messages,
_message_role,
)
@@ -22,6 +28,7 @@ from langchain_community.chat_models.edenai import (
"text": "Hello how are you today?",
"previous_history": [],
"chatbot_global_action": "Translate the text from English to French",
"tool_results": [],
},
)
],
@@ -38,3 +45,26 @@ def test_edenai_messages_formatting(messages: List[BaseMessage], expected: str)
def test_edenai_message_role(role: str, role_response) -> None: # type: ignore[no-untyped-def]
role = _message_role(role)
assert role == role_response
def test_extract_edenai_tool_results_mixed_messages() -> None:
fake_other_msg = BaseMessage(content="content", type="other message")
messages = [
fake_other_msg,
ToolMessage(tool_call_id="id1", content="result1"),
fake_other_msg,
ToolMessage(tool_call_id="id2", content="result2"),
ToolMessage(tool_call_id="id3", content="result3"),
]
expected_tool_results = [
{"id": "id2", "result": "result2"},
{"id": "id3", "result": "result3"},
]
expected_other_messages = [
fake_other_msg,
ToolMessage(tool_call_id="id1", content="result1"),
fake_other_msg,
]
tool_results, other_messages = _extract_edenai_tool_results_from_messages(messages)
assert tool_results == expected_tool_results
assert other_messages == expected_other_messages