diff --git a/docs/docs/integrations/chat/edenai.ipynb b/docs/docs/integrations/chat/edenai.ipynb index 4837fa6fefe..2c8f96aed2c 100644 --- a/docs/docs/integrations/chat/edenai.ipynb +++ b/docs/docs/integrations/chat/edenai.ipynb @@ -246,11 +246,220 @@ "source": [ "chain.invoke({\"product\": \"healthy snacks\"})" ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Tools\n", + "\n", + "### bind_tools()\n", + "\n", + "With `ChatEdenAI.bind_tools`, we can easily pass in Pydantic classes, dict schemas, LangChain tools, or even functions as tools to the model." + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + "from langchain_core.pydantic_v1 import BaseModel, Field\n", + "\n", + "llm = ChatEdenAI(provider=\"openai\", temperature=0.2, max_tokens=500)\n", + "\n", + "\n", + "class GetWeather(BaseModel):\n", + " \"\"\"Get the current weather in a given location\"\"\"\n", + "\n", + " location: str = Field(..., description=\"The city and state, e.g. San Francisco, CA\")\n", + "\n", + "\n", + "llm_with_tools = llm.bind_tools([GetWeather])" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "AIMessage(content='', response_metadata={'openai': {'status': 'success', 'generated_text': None, 'message': [{'role': 'user', 'message': 'what is the weather like in San Francisco', 'tools': [{'name': 'GetWeather', 'description': 'Get the current weather in a given location', 'parameters': {'type': 'object', 'properties': {'location': {'description': 'The city and state, e.g. San Francisco, CA', 'type': 'string'}}, 'required': ['location']}}], 'tool_calls': None}, {'role': 'assistant', 'message': None, 'tools': None, 'tool_calls': [{'id': 'call_tRpAO7KbQwgTjlka70mCQJdo', 'name': 'GetWeather', 'arguments': '{\"location\":\"San Francisco\"}'}]}], 'cost': 0.000194}}, id='run-5c44c01a-d7bb-4df6-835e-bda596080399-0', tool_calls=[{'name': 'GetWeather', 'args': {'location': 'San Francisco'}, 'id': 'call_tRpAO7KbQwgTjlka70mCQJdo'}])" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ai_msg = llm_with_tools.invoke(\n", + " \"what is the weather like in San Francisco\",\n", + ")\n", + "ai_msg" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[{'name': 'GetWeather',\n", + " 'args': {'location': 'San Francisco'},\n", + " 'id': 'call_tRpAO7KbQwgTjlka70mCQJdo'}]" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ai_msg.tool_calls" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### with_structured_output()\n", + "\n", + "The BaseChatModel.with_structured_output interface makes it easy to get structured output from chat models. 
You can use ChatEdenAI.with_structured_output, which uses tool-calling under the hood, to get the model to more reliably return an output in a specific format:\n" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "GetWeather(location='San Francisco')" + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "structured_llm = llm.with_structured_output(GetWeather)\n", + "structured_llm.invoke(\n", + " \"what is the weather like in San Francisco\",\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Passing tool results to the model\n", + "\n", + "Here is a full example of a tool-calling round trip: the model requests a tool call, we run the tool, pass its output back to the model, and get the final answer." + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'11 + 11 = 22'" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from langchain_core.messages import HumanMessage, ToolMessage\n", + "from langchain_core.tools import tool\n", + "\n", + "\n", + "@tool\n", + "def add(a: int, b: int) -> int:\n", + " \"\"\"Adds a and b.\n", + "\n", + " Args:\n", + " a: first int\n", + " b: second int\n", + " \"\"\"\n", + " return a + b\n", + "\n", + "\n", + "llm = ChatEdenAI(\n", + " provider=\"openai\",\n", + " max_tokens=1000,\n", + " temperature=0.2,\n", + ")\n", + "\n", + "llm_with_tools = llm.bind_tools([add], tool_choice=\"required\")\n", + "\n", + "query = \"What is 11 + 11?\"\n", + "\n", + "messages = [HumanMessage(query)]\n", + "ai_msg = llm_with_tools.invoke(messages)\n", + "messages.append(ai_msg)\n", + "\n", + "tool_call = ai_msg.tool_calls[0]\n", + "tool_output = add.invoke(tool_call[\"args\"])\n", + "\n", + "# Append the tool's result to the message history\n", + "messages.append(ToolMessage(tool_output, tool_call_id=tool_call[\"id\"]))\n", + "\n", + "llm_with_tools.invoke(messages).content" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Streaming\n", + "\n", + "Eden AI does not currently support streaming tool calls. Attempting to stream will yield a single final message."
+ ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/eden/Projects/edenai-langchain/libs/community/langchain_community/chat_models/edenai.py:603: UserWarning: stream: Tool use is not yet supported in streaming mode.\n", + " warnings.warn(\"stream: Tool use is not yet supported in streaming mode.\")\n" + ] + }, + { + "data": { + "text/plain": [ + "[AIMessageChunk(content='', id='run-fae32908-ec48-4ab2-ad96-bb0d0511754f', tool_calls=[{'name': 'add', 'args': {'a': 9, 'b': 9}, 'id': 'call_n0Tm7I9zERWa6UpxCAVCweLN'}], tool_call_chunks=[{'name': 'add', 'args': '{\"a\": 9, \"b\": 9}', 'id': 'call_n0Tm7I9zERWa6UpxCAVCweLN', 'index': 0}])]" + ] + }, + "execution_count": 24, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "list(llm_with_tools.stream(\"What's 9 + 9\"))" + ] } ], "metadata": { "kernelspec": { - "display_name": "langchain-pr", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, diff --git a/docs/scripts/model_feat_table.py b/docs/scripts/model_feat_table.py index 3b047e55097..59d96f98753 100644 --- a/docs/scripts/model_feat_table.py +++ b/docs/scripts/model_feat_table.py @@ -96,6 +96,12 @@ CHAT_MODEL_FEAT_TABLE = { "package": "langchain-community", "link": "/docs/integrations/chat/vllm/", }, + "ChatEdenAI": { + "tool_calling": True, + "structured_output": True, + "package": "langchain-community", + "link": "/docs/integrations/chat/edenai/", + }, } diff --git a/libs/community/langchain_community/chat_models/edenai.py b/libs/community/langchain_community/chat_models/edenai.py index a4252b804c6..28dcb2347c5 100644 --- a/libs/community/langchain_community/chat_models/edenai.py +++ b/libs/community/langchain_community/chat_models/edenai.py @@ -1,11 +1,28 @@ import json -from typing import Any, AsyncIterator, Dict, Iterator, List, Optional +import warnings +from operator import itemgetter +from typing import ( + Any, + AsyncIterator, + Callable, + Dict, + Iterator, + List, + Literal, + Optional, + Sequence, + Tuple, + Type, + Union, + cast, +) from aiohttp import ClientSession from langchain_core.callbacks import ( AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun, ) +from langchain_core.language_models import LanguageModelInput from langchain_core.language_models.chat_models import ( BaseChatModel, agenerate_from_stream, @@ -15,16 +32,62 @@ from langchain_core.messages import ( AIMessage, AIMessageChunk, BaseMessage, + HumanMessage, + InvalidToolCall, + SystemMessage, + ToolCall, + ToolCallChunk, + ToolMessage, +) +from langchain_core.output_parsers.base import OutputParserLike +from langchain_core.output_parsers.openai_tools import ( + JsonOutputKeyToolsParser, + PydanticToolsParser, ) from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult -from langchain_core.pydantic_v1 import Extra, Field, SecretStr, root_validator +from langchain_core.pydantic_v1 import ( + BaseModel, + Extra, + Field, + SecretStr, + root_validator, +) +from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough +from langchain_core.tools import BaseTool from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env +from langchain_core.utils.function_calling import convert_to_openai_tool from langchain_community.utilities.requests import Requests +def _result_to_chunked_message(generated_result: ChatResult) -> ChatGenerationChunk: + message = 
generated_result.generations[0].message + if isinstance(message, AIMessage) and message.tool_calls is not None: + tool_call_chunks = [ + ToolCallChunk( + name=tool_call["name"], + args=json.dumps(tool_call["args"]), + id=tool_call["id"], + index=idx, + ) + for idx, tool_call in enumerate(message.tool_calls) + ] + message_chunk = AIMessageChunk( + content=message.content, + tool_call_chunks=tool_call_chunks, + ) + return ChatGenerationChunk(message=message_chunk) + else: + return cast(ChatGenerationChunk, generated_result.generations[0]) + + def _message_role(type: str) -> str: - role_mapping = {"ai": "assistant", "human": "user", "chat": "user"} + role_mapping = { + "ai": "assistant", + "human": "user", + "chat": "user", + "AIMessageChunk": "assistant", + } if type in role_mapping: return role_mapping[type] @@ -32,29 +95,120 @@ def _message_role(type: str) -> str: raise ValueError(f"Unknown type: {type}") +def _extract_edenai_tool_results_from_messages( + messages: List[BaseMessage], +) -> Tuple[List[Dict[str, Any]], List[BaseMessage]]: + """ + Get the last langchain tools messages to transform them into edenai tool_results + Returns tool_results and messages without the extracted tool messages + """ + tool_results: List[Dict[str, Any]] = [] + other_messages = messages[:] + for msg in reversed(messages): + if isinstance(msg, ToolMessage): + tool_results = [ + {"id": msg.tool_call_id, "result": msg.content}, + *tool_results, + ] + other_messages.pop() + else: + break + return tool_results, other_messages + + def _format_edenai_messages(messages: List[BaseMessage]) -> Dict[str, Any]: system = None formatted_messages = [] - text = messages[-1].content - for i, message in enumerate(messages[:-1]): - if message.type == "system": + + human_messages = filter(lambda msg: isinstance(msg, HumanMessage), messages) + last_human_message = list(human_messages)[-1] if human_messages else "" + + tool_results, other_messages = _extract_edenai_tool_results_from_messages(messages) + for i, message in enumerate(other_messages): + if isinstance(message, SystemMessage): if i != 0: raise ValueError("System message must be at beginning of message list.") system = message.content - else: + elif isinstance(message, ToolMessage): + formatted_messages.append({"role": "tool", "message": message.content}) + elif message != last_human_message: formatted_messages.append( { "role": _message_role(message.type), "message": message.content, + "tool_calls": _format_tool_calls_to_edenai_tool_calls(message), } ) + return { - "text": text, + "text": getattr(last_human_message, "content", ""), "previous_history": formatted_messages, "chatbot_global_action": system, + "tool_results": tool_results, } +def _format_tool_calls_to_edenai_tool_calls(message: BaseMessage) -> List: + tool_calls = getattr(message, "tool_calls", []) + invalid_tool_calls = getattr(message, "invalid_tool_calls", []) + edenai_tool_calls = [] + + for invalid_tool_call in invalid_tool_calls: + edenai_tool_calls.append( + { + "arguments": invalid_tool_call.get("args"), + "id": invalid_tool_call.get("id"), + "name": invalid_tool_call.get("name"), + } + ) + + for tool_call in tool_calls: + tool_args = tool_call.get("args", {}) + try: + arguments = json.dumps(tool_args) + except TypeError: + arguments = str(tool_args) + edenai_tool_calls.append( + { + "arguments": arguments, + "id": tool_call["id"], + "name": tool_call["name"], + } + ) + return edenai_tool_calls + + +def _extract_tool_calls_from_edenai_response( + provider_response: Dict[str, Any], +) -> 
Tuple[List[ToolCall], List[InvalidToolCall]]: + tool_calls = [] + invalid_tool_calls = [] + + message = provider_response.get("message", {})[1] + + if raw_tool_calls := message.get("tool_calls"): + for raw_tool_call in raw_tool_calls: + try: + tool_calls.append( + ToolCall( + name=raw_tool_call["name"], + args=json.loads(raw_tool_call["arguments"]), + id=raw_tool_call["id"], + ) + ) + except json.JSONDecodeError as exc: + invalid_tool_calls.append( + InvalidToolCall( + name=raw_tool_call.get("name"), + args=raw_tool_call.get("arguments"), + id=raw_tool_call.get("id"), + error=f"Received JSONDecodeError {exc}", + ) + ) + + return tool_calls, invalid_tool_calls + + class ChatEdenAI(BaseChatModel): """`EdenAI` chat large language models. @@ -179,6 +333,11 @@ class ChatEdenAI(BaseChatModel): **kwargs: Any, ) -> Iterator[ChatGenerationChunk]: """Call out to EdenAI's chat endpoint.""" + if "available_tools" in kwargs: + yield self._stream_with_tools_as_generate( + messages, stop=stop, run_manager=run_manager, **kwargs + ) + return url = f"{self.edenai_api_url}/text/chat/stream" headers = { "Authorization": f"Bearer {self._api_key}", @@ -218,6 +377,11 @@ class ChatEdenAI(BaseChatModel): run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, **kwargs: Any, ) -> AsyncIterator[ChatGenerationChunk]: + if "available_tools" in kwargs: + yield await self._astream_with_tools_as_agenerate( + messages, stop=stop, run_manager=run_manager, **kwargs + ) + return url = f"{self.edenai_api_url}/text/chat/stream" headers = { "Authorization": f"Bearer {self._api_key}", @@ -253,6 +417,53 @@ class ChatEdenAI(BaseChatModel): ) yield cg_chunk + def bind_tools( + self, + tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]], + *, + tool_choice: Optional[ + Union[dict, str, Literal["auto", "none", "required", "any"], bool] + ] = None, + **kwargs: Any, + ) -> Runnable[LanguageModelInput, BaseMessage]: + formatted_tools = [convert_to_openai_tool(tool)["function"] for tool in tools] + formatted_tool_choice = "required" if tool_choice == "any" else tool_choice + return super().bind( + available_tools=formatted_tools, tool_choice=formatted_tool_choice, **kwargs + ) + + def with_structured_output( + self, + schema: Union[Dict, Type[BaseModel]], + *, + include_raw: bool = False, + **kwargs: Any, + ) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]: + if kwargs: + raise ValueError(f"Received unsupported arguments {kwargs}") + llm = self.bind_tools([schema], tool_choice="required") + if isinstance(schema, type) and issubclass(schema, BaseModel): + output_parser: OutputParserLike = PydanticToolsParser( + tools=[schema], first_tool_only=True + ) + else: + key_name = convert_to_openai_tool(schema)["function"]["name"] + output_parser = JsonOutputKeyToolsParser( + key_name=key_name, first_tool_only=True + ) + + if include_raw: + parser_assign = RunnablePassthrough.assign( + parsed=itemgetter("raw") | output_parser, parsing_error=lambda _: None + ) + parser_none = RunnablePassthrough.assign(parsed=lambda _: None) + parser_with_fallback = parser_assign.with_fallbacks( + [parser_none], exception_key="parsing_error" + ) + return RunnableMap(raw=llm) | parser_with_fallback + else: + return llm | output_parser + def _generate( self, messages: List[BaseMessage], @@ -262,10 +473,15 @@ class ChatEdenAI(BaseChatModel): ) -> ChatResult: """Call out to EdenAI's chat endpoint.""" if self.streaming: - stream_iter = self._stream( - messages, stop=stop, run_manager=run_manager, **kwargs - ) - return 
generate_from_stream(stream_iter) + if "available_tools" in kwargs: + warnings.warn( + "stream: Tool use is not yet supported in streaming mode." + ) + else: + stream_iter = self._stream( + messages, stop=stop, run_manager=run_manager, **kwargs + ) + return generate_from_stream(stream_iter) url = f"{self.edenai_api_url}/text/chat" headers = { @@ -273,6 +489,7 @@ class ChatEdenAI(BaseChatModel): "User-Agent": self.get_user_agent(), } formatted_data = _format_edenai_messages(messages=messages) + payload: Dict[str, Any] = { "providers": self.provider, "max_tokens": self.max_tokens, @@ -303,10 +520,18 @@ class ChatEdenAI(BaseChatModel): err_msg = provider_response.get("error", {}).get("message") raise Exception(err_msg) + tool_calls, invalid_tool_calls = _extract_tool_calls_from_edenai_response( + provider_response + ) + return ChatResult( generations=[ ChatGeneration( - message=AIMessage(content=provider_response["generated_text"]) + message=AIMessage( + content=provider_response["generated_text"] or "", + tool_calls=tool_calls, + invalid_tool_calls=invalid_tool_calls, + ) ) ], llm_output=data, @@ -320,10 +545,15 @@ class ChatEdenAI(BaseChatModel): **kwargs: Any, ) -> ChatResult: if self.streaming: - stream_iter = self._astream( - messages, stop=stop, run_manager=run_manager, **kwargs - ) - return await agenerate_from_stream(stream_iter) + if "available_tools" in kwargs: + warnings.warn( + "stream: Tool use is not yet supported in streaming mode." + ) + else: + stream_iter = self._astream( + messages, stop=stop, run_manager=run_manager, **kwargs + ) + return await agenerate_from_stream(stream_iter) url = f"{self.edenai_api_url}/text/chat" headers = { @@ -370,3 +600,27 @@ class ChatEdenAI(BaseChatModel): ], llm_output=data, ) + + def _stream_with_tools_as_generate( + self, + messages: List[BaseMessage], + stop: Optional[List[str]], + run_manager: Optional[CallbackManagerForLLMRun], + **kwargs: Any, + ) -> ChatGenerationChunk: + warnings.warn("stream: Tool use is not yet supported in streaming mode.") + result = self._generate(messages, stop=stop, run_manager=run_manager, **kwargs) + return _result_to_chunked_message(result) + + async def _astream_with_tools_as_agenerate( + self, + messages: List[BaseMessage], + stop: Optional[List[str]], + run_manager: Optional[AsyncCallbackManagerForLLMRun], + **kwargs: Any, + ) -> ChatGenerationChunk: + warnings.warn("stream: Tool use is not yet supported in streaming mode.") + result = await self._agenerate( + messages, stop=stop, run_manager=run_manager, **kwargs + ) + return _result_to_chunked_message(result) diff --git a/libs/community/tests/unit_tests/chat_models/test_edenai.py b/libs/community/tests/unit_tests/chat_models/test_edenai.py index dfafc5af988..aa4edf29332 100644 --- a/libs/community/tests/unit_tests/chat_models/test_edenai.py +++ b/libs/community/tests/unit_tests/chat_models/test_edenai.py @@ -2,9 +2,15 @@ from typing import List import pytest -from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage +from langchain_core.messages import ( + BaseMessage, + HumanMessage, + SystemMessage, + ToolMessage, +) from langchain_community.chat_models.edenai import ( + _extract_edenai_tool_results_from_messages, _format_edenai_messages, _message_role, ) @@ -22,6 +28,7 @@ from langchain_community.chat_models.edenai import ( "text": "Hello how are you today?", "previous_history": [], "chatbot_global_action": "Translate the text from English to French", + "tool_results": [], }, ) ], @@ -38,3 +45,26 @@ def 
test_edenai_messages_formatting(messages: List[BaseMessage], expected: str) def test_edenai_message_role(role: str, role_response) -> None: # type: ignore[no-untyped-def] role = _message_role(role) assert role == role_response + + +def test_extract_edenai_tool_results_mixed_messages() -> None: + fake_other_msg = BaseMessage(content="content", type="other message") + messages = [ + fake_other_msg, + ToolMessage(tool_call_id="id1", content="result1"), + fake_other_msg, + ToolMessage(tool_call_id="id2", content="result2"), + ToolMessage(tool_call_id="id3", content="result3"), + ] + expected_tool_results = [ + {"id": "id2", "result": "result2"}, + {"id": "id3", "result": "result3"}, + ] + expected_other_messages = [ + fake_other_msg, + ToolMessage(tool_call_id="id1", content="result1"), + fake_other_msg, + ] + tool_results, other_messages = _extract_edenai_tool_results_from_messages(messages) + assert tool_results == expected_tool_results + assert other_messages == expected_other_messages
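For quick reference, below is a minimal, illustrative sketch of the request payload that `_format_edenai_messages` builds for a conversation containing a tool call and its result, based on the helpers added in this diff. The message contents and the `call_1` id are invented for the example, the exact Eden AI wire format is not guaranteed here, and the assertion assumes a `langchain-community` checkout that includes this change.

```python
from langchain_community.chat_models.edenai import _format_edenai_messages
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage

# A short conversation: the user asks a question, the assistant requests a
# tool call, and a ToolMessage carries the tool's result back.
messages = [
    HumanMessage("What is 11 + 11?"),
    AIMessage(
        "",
        tool_calls=[{"name": "add", "args": {"a": 11, "b": 11}, "id": "call_1"}],
    ),
    ToolMessage("22", tool_call_id="call_1"),
]

# Per the helpers above: the last HumanMessage becomes `text`, trailing
# ToolMessages are lifted into `tool_results`, and the remaining messages go
# into `previous_history` with their tool-call arguments serialized as JSON.
assert _format_edenai_messages(messages) == {
    "text": "What is 11 + 11?",
    "previous_history": [
        {
            "role": "assistant",
            "message": "",
            "tool_calls": [
                {"arguments": '{"a": 11, "b": 11}', "id": "call_1", "name": "add"}
            ],
        }
    ],
    "chatbot_global_action": None,
    "tool_results": [{"id": "call_1", "result": "22"}],
}
```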