community[minor]: Add tools calls to ChatEdenAI (#22320)

### Description  
Add tools implementation to `ChatEdenAI`:
- `bind_tools()`
- `with_structured_output()`

### Documentation 
Updated `docs/docs/integrations/chat/edenai.ipynb`

### Notes
We don´t support stream with tools as of yet. If stream is called with
tools we directly yield the whole message from `generate` (implemented
the same way as Anthropic did).
This commit is contained in:
KyrianC
2024-06-04 19:29:28 +02:00
committed by GitHub
parent 9d4350e69a
commit 03178ee74f
4 changed files with 518 additions and 19 deletions

View File

@@ -1,11 +1,28 @@
import json
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional
import warnings
from operator import itemgetter
from typing import (
Any,
AsyncIterator,
Callable,
Dict,
Iterator,
List,
Literal,
Optional,
Sequence,
Tuple,
Type,
Union,
cast,
)
from aiohttp import ClientSession
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models import LanguageModelInput
from langchain_core.language_models.chat_models import (
BaseChatModel,
agenerate_from_stream,
@@ -15,16 +32,62 @@ from langchain_core.messages import (
AIMessage,
AIMessageChunk,
BaseMessage,
HumanMessage,
InvalidToolCall,
SystemMessage,
ToolCall,
ToolCallChunk,
ToolMessage,
)
from langchain_core.output_parsers.base import OutputParserLike
from langchain_core.output_parsers.openai_tools import (
JsonOutputKeyToolsParser,
PydanticToolsParser,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import Extra, Field, SecretStr, root_validator
from langchain_core.pydantic_v1 import (
BaseModel,
Extra,
Field,
SecretStr,
root_validator,
)
from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
from langchain_core.tools import BaseTool
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
from langchain_core.utils.function_calling import convert_to_openai_tool
from langchain_community.utilities.requests import Requests
def _result_to_chunked_message(generated_result: ChatResult) -> ChatGenerationChunk:
message = generated_result.generations[0].message
if isinstance(message, AIMessage) and message.tool_calls is not None:
tool_call_chunks = [
ToolCallChunk(
name=tool_call["name"],
args=json.dumps(tool_call["args"]),
id=tool_call["id"],
index=idx,
)
for idx, tool_call in enumerate(message.tool_calls)
]
message_chunk = AIMessageChunk(
content=message.content,
tool_call_chunks=tool_call_chunks,
)
return ChatGenerationChunk(message=message_chunk)
else:
return cast(ChatGenerationChunk, generated_result.generations[0])
def _message_role(type: str) -> str:
role_mapping = {"ai": "assistant", "human": "user", "chat": "user"}
role_mapping = {
"ai": "assistant",
"human": "user",
"chat": "user",
"AIMessageChunk": "assistant",
}
if type in role_mapping:
return role_mapping[type]
@@ -32,29 +95,120 @@ def _message_role(type: str) -> str:
raise ValueError(f"Unknown type: {type}")
def _extract_edenai_tool_results_from_messages(
messages: List[BaseMessage],
) -> Tuple[List[Dict[str, Any]], List[BaseMessage]]:
"""
Get the last langchain tools messages to transform them into edenai tool_results
Returns tool_results and messages without the extracted tool messages
"""
tool_results: List[Dict[str, Any]] = []
other_messages = messages[:]
for msg in reversed(messages):
if isinstance(msg, ToolMessage):
tool_results = [
{"id": msg.tool_call_id, "result": msg.content},
*tool_results,
]
other_messages.pop()
else:
break
return tool_results, other_messages
def _format_edenai_messages(messages: List[BaseMessage]) -> Dict[str, Any]:
system = None
formatted_messages = []
text = messages[-1].content
for i, message in enumerate(messages[:-1]):
if message.type == "system":
human_messages = filter(lambda msg: isinstance(msg, HumanMessage), messages)
last_human_message = list(human_messages)[-1] if human_messages else ""
tool_results, other_messages = _extract_edenai_tool_results_from_messages(messages)
for i, message in enumerate(other_messages):
if isinstance(message, SystemMessage):
if i != 0:
raise ValueError("System message must be at beginning of message list.")
system = message.content
else:
elif isinstance(message, ToolMessage):
formatted_messages.append({"role": "tool", "message": message.content})
elif message != last_human_message:
formatted_messages.append(
{
"role": _message_role(message.type),
"message": message.content,
"tool_calls": _format_tool_calls_to_edenai_tool_calls(message),
}
)
return {
"text": text,
"text": getattr(last_human_message, "content", ""),
"previous_history": formatted_messages,
"chatbot_global_action": system,
"tool_results": tool_results,
}
def _format_tool_calls_to_edenai_tool_calls(message: BaseMessage) -> List:
tool_calls = getattr(message, "tool_calls", [])
invalid_tool_calls = getattr(message, "invalid_tool_calls", [])
edenai_tool_calls = []
for invalid_tool_call in invalid_tool_calls:
edenai_tool_calls.append(
{
"arguments": invalid_tool_call.get("args"),
"id": invalid_tool_call.get("id"),
"name": invalid_tool_call.get("name"),
}
)
for tool_call in tool_calls:
tool_args = tool_call.get("args", {})
try:
arguments = json.dumps(tool_args)
except TypeError:
arguments = str(tool_args)
edenai_tool_calls.append(
{
"arguments": arguments,
"id": tool_call["id"],
"name": tool_call["name"],
}
)
return edenai_tool_calls
def _extract_tool_calls_from_edenai_response(
provider_response: Dict[str, Any],
) -> Tuple[List[ToolCall], List[InvalidToolCall]]:
tool_calls = []
invalid_tool_calls = []
message = provider_response.get("message", {})[1]
if raw_tool_calls := message.get("tool_calls"):
for raw_tool_call in raw_tool_calls:
try:
tool_calls.append(
ToolCall(
name=raw_tool_call["name"],
args=json.loads(raw_tool_call["arguments"]),
id=raw_tool_call["id"],
)
)
except json.JSONDecodeError as exc:
invalid_tool_calls.append(
InvalidToolCall(
name=raw_tool_call.get("name"),
args=raw_tool_call.get("arguments"),
id=raw_tool_call.get("id"),
error=f"Received JSONDecodeError {exc}",
)
)
return tool_calls, invalid_tool_calls
class ChatEdenAI(BaseChatModel):
"""`EdenAI` chat large language models.
@@ -179,6 +333,11 @@ class ChatEdenAI(BaseChatModel):
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
"""Call out to EdenAI's chat endpoint."""
if "available_tools" in kwargs:
yield self._stream_with_tools_as_generate(
messages, stop=stop, run_manager=run_manager, **kwargs
)
return
url = f"{self.edenai_api_url}/text/chat/stream"
headers = {
"Authorization": f"Bearer {self._api_key}",
@@ -218,6 +377,11 @@ class ChatEdenAI(BaseChatModel):
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> AsyncIterator[ChatGenerationChunk]:
if "available_tools" in kwargs:
yield await self._astream_with_tools_as_agenerate(
messages, stop=stop, run_manager=run_manager, **kwargs
)
return
url = f"{self.edenai_api_url}/text/chat/stream"
headers = {
"Authorization": f"Bearer {self._api_key}",
@@ -253,6 +417,53 @@ class ChatEdenAI(BaseChatModel):
)
yield cg_chunk
def bind_tools(
self,
tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
*,
tool_choice: Optional[
Union[dict, str, Literal["auto", "none", "required", "any"], bool]
] = None,
**kwargs: Any,
) -> Runnable[LanguageModelInput, BaseMessage]:
formatted_tools = [convert_to_openai_tool(tool)["function"] for tool in tools]
formatted_tool_choice = "required" if tool_choice == "any" else tool_choice
return super().bind(
available_tools=formatted_tools, tool_choice=formatted_tool_choice, **kwargs
)
def with_structured_output(
self,
schema: Union[Dict, Type[BaseModel]],
*,
include_raw: bool = False,
**kwargs: Any,
) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]:
if kwargs:
raise ValueError(f"Received unsupported arguments {kwargs}")
llm = self.bind_tools([schema], tool_choice="required")
if isinstance(schema, type) and issubclass(schema, BaseModel):
output_parser: OutputParserLike = PydanticToolsParser(
tools=[schema], first_tool_only=True
)
else:
key_name = convert_to_openai_tool(schema)["function"]["name"]
output_parser = JsonOutputKeyToolsParser(
key_name=key_name, first_tool_only=True
)
if include_raw:
parser_assign = RunnablePassthrough.assign(
parsed=itemgetter("raw") | output_parser, parsing_error=lambda _: None
)
parser_none = RunnablePassthrough.assign(parsed=lambda _: None)
parser_with_fallback = parser_assign.with_fallbacks(
[parser_none], exception_key="parsing_error"
)
return RunnableMap(raw=llm) | parser_with_fallback
else:
return llm | output_parser
def _generate(
self,
messages: List[BaseMessage],
@@ -262,10 +473,15 @@ class ChatEdenAI(BaseChatModel):
) -> ChatResult:
"""Call out to EdenAI's chat endpoint."""
if self.streaming:
stream_iter = self._stream(
messages, stop=stop, run_manager=run_manager, **kwargs
)
return generate_from_stream(stream_iter)
if "available_tools" in kwargs:
warnings.warn(
"stream: Tool use is not yet supported in streaming mode."
)
else:
stream_iter = self._stream(
messages, stop=stop, run_manager=run_manager, **kwargs
)
return generate_from_stream(stream_iter)
url = f"{self.edenai_api_url}/text/chat"
headers = {
@@ -273,6 +489,7 @@ class ChatEdenAI(BaseChatModel):
"User-Agent": self.get_user_agent(),
}
formatted_data = _format_edenai_messages(messages=messages)
payload: Dict[str, Any] = {
"providers": self.provider,
"max_tokens": self.max_tokens,
@@ -303,10 +520,18 @@ class ChatEdenAI(BaseChatModel):
err_msg = provider_response.get("error", {}).get("message")
raise Exception(err_msg)
tool_calls, invalid_tool_calls = _extract_tool_calls_from_edenai_response(
provider_response
)
return ChatResult(
generations=[
ChatGeneration(
message=AIMessage(content=provider_response["generated_text"])
message=AIMessage(
content=provider_response["generated_text"] or "",
tool_calls=tool_calls,
invalid_tool_calls=invalid_tool_calls,
)
)
],
llm_output=data,
@@ -320,10 +545,15 @@ class ChatEdenAI(BaseChatModel):
**kwargs: Any,
) -> ChatResult:
if self.streaming:
stream_iter = self._astream(
messages, stop=stop, run_manager=run_manager, **kwargs
)
return await agenerate_from_stream(stream_iter)
if "available_tools" in kwargs:
warnings.warn(
"stream: Tool use is not yet supported in streaming mode."
)
else:
stream_iter = self._astream(
messages, stop=stop, run_manager=run_manager, **kwargs
)
return await agenerate_from_stream(stream_iter)
url = f"{self.edenai_api_url}/text/chat"
headers = {
@@ -370,3 +600,27 @@ class ChatEdenAI(BaseChatModel):
],
llm_output=data,
)
def _stream_with_tools_as_generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]],
run_manager: Optional[CallbackManagerForLLMRun],
**kwargs: Any,
) -> ChatGenerationChunk:
warnings.warn("stream: Tool use is not yet supported in streaming mode.")
result = self._generate(messages, stop=stop, run_manager=run_manager, **kwargs)
return _result_to_chunked_message(result)
async def _astream_with_tools_as_agenerate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]],
run_manager: Optional[AsyncCallbackManagerForLLMRun],
**kwargs: Any,
) -> ChatGenerationChunk:
warnings.warn("stream: Tool use is not yet supported in streaming mode.")
result = await self._agenerate(
messages, stop=stop, run_manager=run_manager, **kwargs
)
return _result_to_chunked_message(result)