Mirror of https://github.com/hwchase17/langchain.git
community[minor]: Add tool calls to ChatEdenAI (#22320)
### Description
Add a tools implementation to `ChatEdenAI`:
- `bind_tools()`
- `with_structured_output()`

### Documentation
Updated `docs/docs/integrations/chat/edenai.ipynb`.

### Notes
We don't support streaming with tools yet. If `stream` is called with tools bound, we directly yield the whole message from `generate` (implemented the same way Anthropic did).
This commit is contained in: commit 03178ee74f (parent 9d4350e69a)
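Before the diff itself, a quick orientation sketch of the surface this commit adds, condensed from the updated notebook. This is a minimal sketch, assuming a valid Eden AI key in `EDENAI_API_KEY`; it mirrors the notebook cells rather than adding anything new:

```python
# Minimal sketch of the new tool-calling surface on ChatEdenAI,
# condensed from the updated edenai.ipynb; assumes EDENAI_API_KEY is set.
from langchain_community.chat_models.edenai import ChatEdenAI
from langchain_core.pydantic_v1 import BaseModel, Field


class GetWeather(BaseModel):
    """Get the current weather in a given location"""

    location: str = Field(..., description="The city and state, e.g. San Francisco, CA")


llm = ChatEdenAI(provider="openai", temperature=0.2, max_tokens=500)

# bind_tools(): the returned runnable can emit tool calls for the bound schema.
llm_with_tools = llm.bind_tools([GetWeather])
ai_msg = llm_with_tools.invoke("what is the weather like in San Francisco")
print(ai_msg.tool_calls)  # [{'name': 'GetWeather', 'args': {'location': ...}, 'id': ...}]

# with_structured_output(): tool-calling under the hood, parsed back into the schema.
structured_llm = llm.with_structured_output(GetWeather)
print(structured_llm.invoke("what is the weather like in San Francisco"))
# -> GetWeather(location='San Francisco')
```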
docs/docs/integrations/chat/edenai.ipynb

@@ -246,11 +246,220 @@
    "source": [
     "chain.invoke({\"product\": \"healthy snacks\"})"
    ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Tools\n",
+    "\n",
+    "### bind_tools()\n",
+    "\n",
+    "With `ChatEdenAI.bind_tools`, we can easily pass in Pydantic classes, dict schemas, LangChain tools, or even functions as tools to the model."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 14,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from langchain_core.pydantic_v1 import BaseModel, Field\n",
+    "\n",
+    "llm = ChatEdenAI(provider=\"openai\", temperature=0.2, max_tokens=500)\n",
+    "\n",
+    "\n",
+    "class GetWeather(BaseModel):\n",
+    "    \"\"\"Get the current weather in a given location\"\"\"\n",
+    "\n",
+    "    location: str = Field(..., description=\"The city and state, e.g. San Francisco, CA\")\n",
+    "\n",
+    "\n",
+    "llm_with_tools = llm.bind_tools([GetWeather])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 15,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "AIMessage(content='', response_metadata={'openai': {'status': 'success', 'generated_text': None, 'message': [{'role': 'user', 'message': 'what is the weather like in San Francisco', 'tools': [{'name': 'GetWeather', 'description': 'Get the current weather in a given location', 'parameters': {'type': 'object', 'properties': {'location': {'description': 'The city and state, e.g. San Francisco, CA', 'type': 'string'}}, 'required': ['location']}}], 'tool_calls': None}, {'role': 'assistant', 'message': None, 'tools': None, 'tool_calls': [{'id': 'call_tRpAO7KbQwgTjlka70mCQJdo', 'name': 'GetWeather', 'arguments': '{\"location\":\"San Francisco\"}'}]}], 'cost': 0.000194}}, id='run-5c44c01a-d7bb-4df6-835e-bda596080399-0', tool_calls=[{'name': 'GetWeather', 'args': {'location': 'San Francisco'}, 'id': 'call_tRpAO7KbQwgTjlka70mCQJdo'}])"
+      ]
+     },
+     "execution_count": 15,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "ai_msg = llm_with_tools.invoke(\n",
+    "    \"what is the weather like in San Francisco\",\n",
+    ")\n",
+    "ai_msg"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 17,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "[{'name': 'GetWeather',\n",
+       "  'args': {'location': 'San Francisco'},\n",
+       "  'id': 'call_tRpAO7KbQwgTjlka70mCQJdo'}]"
+      ]
+     },
+     "execution_count": 17,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "ai_msg.tool_calls"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### with_structured_output()\n",
+    "\n",
+    "The BaseChatModel.with_structured_output interface makes it easy to get structured output from chat models. You can use ChatEdenAI.with_structured_output, which uses tool-calling under the hood, to get the model to more reliably return an output in a specific format:\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 18,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "GetWeather(location='San Francisco')"
+      ]
+     },
+     "execution_count": 18,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "structured_llm = llm.with_structured_output(GetWeather)\n",
+    "structured_llm.invoke(\n",
+    "    \"what is the weather like in San Francisco\",\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Passing Tool Results to model\n",
+    "\n",
+    "Here is a full example of how to use a tool. Pass the tool output to the model, and get the result back from the model."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 19,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "'11 + 11 = 22'"
+      ]
+     },
+     "execution_count": 19,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "from langchain_core.messages import HumanMessage, ToolMessage\n",
+    "from langchain_core.tools import tool\n",
+    "\n",
+    "\n",
+    "@tool\n",
+    "def add(a: int, b: int) -> int:\n",
+    "    \"\"\"Adds a and b.\n",
+    "\n",
+    "    Args:\n",
+    "        a: first int\n",
+    "        b: second int\n",
+    "    \"\"\"\n",
+    "    return a + b\n",
+    "\n",
+    "\n",
+    "llm = ChatEdenAI(\n",
+    "    provider=\"openai\",\n",
+    "    max_tokens=1000,\n",
+    "    temperature=0.2,\n",
+    ")\n",
+    "\n",
+    "llm_with_tools = llm.bind_tools([add], tool_choice=\"required\")\n",
+    "\n",
+    "query = \"What is 11 + 11?\"\n",
+    "\n",
+    "messages = [HumanMessage(query)]\n",
+    "ai_msg = llm_with_tools.invoke(messages)\n",
+    "messages.append(ai_msg)\n",
+    "\n",
+    "tool_call = ai_msg.tool_calls[0]\n",
+    "tool_output = add.invoke(tool_call[\"args\"])\n",
+    "\n",
+    "# This appends the result from our tool to the message history\n",
+    "messages.append(ToolMessage(tool_output, tool_call_id=tool_call[\"id\"]))\n",
+    "\n",
+    "llm_with_tools.invoke(messages).content"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Streaming\n",
+    "\n",
+    "Eden AI does not currently support streaming tool calls. Attempting to stream will yield a single final message."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 20,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/home/eden/Projects/edenai-langchain/libs/community/langchain_community/chat_models/edenai.py:603: UserWarning: stream: Tool use is not yet supported in streaming mode.\n",
+      " warnings.warn(\"stream: Tool use is not yet supported in streaming mode.\")\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "[AIMessageChunk(content='', id='run-fae32908-ec48-4ab2-ad96-bb0d0511754f', tool_calls=[{'name': 'add', 'args': {'a': 9, 'b': 9}, 'id': 'call_n0Tm7I9zERWa6UpxCAVCweLN'}], tool_call_chunks=[{'name': 'add', 'args': '{\"a\": 9, \"b\": 9}', 'id': 'call_n0Tm7I9zERWa6UpxCAVCweLN', 'index': 0}])]"
+      ]
+     },
+     "execution_count": 20,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "list(llm_with_tools.stream(\"What's 9 + 9\"))"
+   ]
  }
 ],
 "metadata": {
  "kernelspec": {
-  "display_name": "langchain-pr",
+  "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
Chat-model feature table (file path not shown in this capture):

@@ -96,6 +96,12 @@ CHAT_MODEL_FEAT_TABLE = {
         "package": "langchain-community",
         "link": "/docs/integrations/chat/vllm/",
     },
+    "ChatEdenAI": {
+        "tool_calling": True,
+        "structured_output": True,
+        "package": "langchain-community",
+        "link": "/docs/integrations/chat/edenai/",
+    },
 }
libs/community/langchain_community/chat_models/edenai.py

@@ -1,11 +1,28 @@
 import json
-from typing import Any, AsyncIterator, Dict, Iterator, List, Optional
+import warnings
+from operator import itemgetter
+from typing import (
+    Any,
+    AsyncIterator,
+    Callable,
+    Dict,
+    Iterator,
+    List,
+    Literal,
+    Optional,
+    Sequence,
+    Tuple,
+    Type,
+    Union,
+    cast,
+)

 from aiohttp import ClientSession
 from langchain_core.callbacks import (
     AsyncCallbackManagerForLLMRun,
     CallbackManagerForLLMRun,
 )
+from langchain_core.language_models import LanguageModelInput
 from langchain_core.language_models.chat_models import (
     BaseChatModel,
     agenerate_from_stream,
@@ -15,16 +32,62 @@ from langchain_core.messages import (
     AIMessage,
     AIMessageChunk,
     BaseMessage,
+    HumanMessage,
+    InvalidToolCall,
+    SystemMessage,
+    ToolCall,
+    ToolCallChunk,
+    ToolMessage,
+)
+from langchain_core.output_parsers.base import OutputParserLike
+from langchain_core.output_parsers.openai_tools import (
+    JsonOutputKeyToolsParser,
+    PydanticToolsParser,
 )
 from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
-from langchain_core.pydantic_v1 import Extra, Field, SecretStr, root_validator
+from langchain_core.pydantic_v1 import (
+    BaseModel,
+    Extra,
+    Field,
+    SecretStr,
+    root_validator,
+)
+from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
+from langchain_core.tools import BaseTool
 from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
+from langchain_core.utils.function_calling import convert_to_openai_tool

 from langchain_community.utilities.requests import Requests


+def _result_to_chunked_message(generated_result: ChatResult) -> ChatGenerationChunk:
+    message = generated_result.generations[0].message
+    if isinstance(message, AIMessage) and message.tool_calls is not None:
+        tool_call_chunks = [
+            ToolCallChunk(
+                name=tool_call["name"],
+                args=json.dumps(tool_call["args"]),
+                id=tool_call["id"],
+                index=idx,
+            )
+            for idx, tool_call in enumerate(message.tool_calls)
+        ]
+        message_chunk = AIMessageChunk(
+            content=message.content,
+            tool_call_chunks=tool_call_chunks,
+        )
+        return ChatGenerationChunk(message=message_chunk)
+    else:
+        return cast(ChatGenerationChunk, generated_result.generations[0])
+
+
 def _message_role(type: str) -> str:
-    role_mapping = {"ai": "assistant", "human": "user", "chat": "user"}
+    role_mapping = {
+        "ai": "assistant",
+        "human": "user",
+        "chat": "user",
+        "AIMessageChunk": "assistant",
+    }

     if type in role_mapping:
         return role_mapping[type]
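To make the streaming fallback concrete, here is a small, hand-built illustration of what `_result_to_chunked_message` produces. The `ChatResult` below is constructed manually rather than returned from the API, so treat it as a sketch:

```python
from langchain_core.messages import AIMessage
from langchain_core.outputs import ChatGeneration, ChatResult

# Hand-built stand-in for what _generate would return when a tool is called.
result = ChatResult(
    generations=[
        ChatGeneration(
            message=AIMessage(
                content="",
                tool_calls=[{"name": "add", "args": {"a": 9, "b": 9}, "id": "call_1"}],
            )
        )
    ]
)

chunk = _result_to_chunked_message(result)
# The whole generated message is repackaged as one AIMessageChunk whose
# tool_call_chunks carry the JSON-serialized arguments — exactly what
# _stream yields when tools are bound.
print(chunk.message.tool_call_chunks)
# [{'name': 'add', 'args': '{"a": 9, "b": 9}', 'id': 'call_1', 'index': 0}]
```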
@@ -32,29 +95,120 @@ def _message_role(type: str) -> str:
     raise ValueError(f"Unknown type: {type}")


+def _extract_edenai_tool_results_from_messages(
+    messages: List[BaseMessage],
+) -> Tuple[List[Dict[str, Any]], List[BaseMessage]]:
+    """
+    Get the trailing LangChain tool messages and transform them into edenai tool_results.
+    Returns tool_results and the messages without the extracted tool messages.
+    """
+    tool_results: List[Dict[str, Any]] = []
+    other_messages = messages[:]
+    for msg in reversed(messages):
+        if isinstance(msg, ToolMessage):
+            tool_results = [
+                {"id": msg.tool_call_id, "result": msg.content},
+                *tool_results,
+            ]
+            other_messages.pop()
+        else:
+            break
+    return tool_results, other_messages
+
+
 def _format_edenai_messages(messages: List[BaseMessage]) -> Dict[str, Any]:
     system = None
     formatted_messages = []
-    text = messages[-1].content
-    for i, message in enumerate(messages[:-1]):
-        if message.type == "system":
+
+    human_messages = filter(lambda msg: isinstance(msg, HumanMessage), messages)
+    last_human_message = list(human_messages)[-1] if human_messages else ""
+
+    tool_results, other_messages = _extract_edenai_tool_results_from_messages(messages)
+    for i, message in enumerate(other_messages):
+        if isinstance(message, SystemMessage):
             if i != 0:
                 raise ValueError("System message must be at beginning of message list.")
             system = message.content
-        else:
+        elif isinstance(message, ToolMessage):
+            formatted_messages.append({"role": "tool", "message": message.content})
+        elif message != last_human_message:
             formatted_messages.append(
                 {
                     "role": _message_role(message.type),
                     "message": message.content,
+                    "tool_calls": _format_tool_calls_to_edenai_tool_calls(message),
                 }
             )
+
     return {
-        "text": text,
+        "text": getattr(last_human_message, "content", ""),
         "previous_history": formatted_messages,
         "chatbot_global_action": system,
+        "tool_results": tool_results,
     }
+
+
+def _format_tool_calls_to_edenai_tool_calls(message: BaseMessage) -> List:
+    tool_calls = getattr(message, "tool_calls", [])
+    invalid_tool_calls = getattr(message, "invalid_tool_calls", [])
+    edenai_tool_calls = []
+
+    for invalid_tool_call in invalid_tool_calls:
+        edenai_tool_calls.append(
+            {
+                "arguments": invalid_tool_call.get("args"),
+                "id": invalid_tool_call.get("id"),
+                "name": invalid_tool_call.get("name"),
+            }
+        )
+
+    for tool_call in tool_calls:
+        tool_args = tool_call.get("args", {})
+        try:
+            arguments = json.dumps(tool_args)
+        except TypeError:
+            arguments = str(tool_args)
+        edenai_tool_calls.append(
+            {
+                "arguments": arguments,
+                "id": tool_call["id"],
+                "name": tool_call["name"],
+            }
+        )
+    return edenai_tool_calls
+
+
+def _extract_tool_calls_from_edenai_response(
+    provider_response: Dict[str, Any],
+) -> Tuple[List[ToolCall], List[InvalidToolCall]]:
+    tool_calls = []
+    invalid_tool_calls = []
+
+    message = provider_response.get("message", {})[1]
+
+    if raw_tool_calls := message.get("tool_calls"):
+        for raw_tool_call in raw_tool_calls:
+            try:
+                tool_calls.append(
+                    ToolCall(
+                        name=raw_tool_call["name"],
+                        args=json.loads(raw_tool_call["arguments"]),
+                        id=raw_tool_call["id"],
+                    )
+                )
+            except json.JSONDecodeError as exc:
+                invalid_tool_calls.append(
+                    InvalidToolCall(
+                        name=raw_tool_call.get("name"),
+                        args=raw_tool_call.get("arguments"),
+                        id=raw_tool_call.get("id"),
+                        error=f"Received JSONDecodeError {exc}",
+                    )
+                )
+
+    return tool_calls, invalid_tool_calls
+
+
 class ChatEdenAI(BaseChatModel):
     """`EdenAI` chat large language models.

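For reference, a sketch of the payload `_format_edenai_messages` now assembles; the expected shape follows directly from the function above and the updated unit test at the end of this diff:

```python
from langchain_core.messages import HumanMessage, SystemMessage

payload = _format_edenai_messages(
    [
        SystemMessage(content="Translate the text from English to French"),
        HumanMessage(content="Hello how are you today?"),
    ]
)
# Per the updated unit test, this yields:
# {
#     "text": "Hello how are you today?",
#     "previous_history": [],
#     "chatbot_global_action": "Translate the text from English to French",
#     "tool_results": [],
# }
print(payload)
```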
@@ -179,6 +333,11 @@ class ChatEdenAI(BaseChatModel):
         **kwargs: Any,
     ) -> Iterator[ChatGenerationChunk]:
         """Call out to EdenAI's chat endpoint."""
+        if "available_tools" in kwargs:
+            yield self._stream_with_tools_as_generate(
+                messages, stop=stop, run_manager=run_manager, **kwargs
+            )
+            return
         url = f"{self.edenai_api_url}/text/chat/stream"
         headers = {
             "Authorization": f"Bearer {self._api_key}",
@@ -218,6 +377,11 @@ class ChatEdenAI(BaseChatModel):
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> AsyncIterator[ChatGenerationChunk]:
+        if "available_tools" in kwargs:
+            yield await self._astream_with_tools_as_agenerate(
+                messages, stop=stop, run_manager=run_manager, **kwargs
+            )
+            return
         url = f"{self.edenai_api_url}/text/chat/stream"
         headers = {
             "Authorization": f"Bearer {self._api_key}",
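The async branch mirrors the sync one: with tools bound, `astream` should yield exactly one chunk produced via `_agenerate`. A sketch, reusing the `llm_with_tools` from the orientation example at the top (assumption: a live Eden AI key):

```python
import asyncio


async def main() -> None:
    # With tools bound, astream falls back to a single generated chunk;
    # a UserWarning flags that true tool-call streaming isn't supported yet.
    chunks = [chunk async for chunk in llm_with_tools.astream("What's 9 + 9")]
    assert len(chunks) == 1
    print(chunks[0].tool_calls)


asyncio.run(main())
```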
@@ -253,6 +417,53 @@ class ChatEdenAI(BaseChatModel):
                 )
                 yield cg_chunk

+    def bind_tools(
+        self,
+        tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
+        *,
+        tool_choice: Optional[
+            Union[dict, str, Literal["auto", "none", "required", "any"], bool]
+        ] = None,
+        **kwargs: Any,
+    ) -> Runnable[LanguageModelInput, BaseMessage]:
+        formatted_tools = [convert_to_openai_tool(tool)["function"] for tool in tools]
+        formatted_tool_choice = "required" if tool_choice == "any" else tool_choice
+        return super().bind(
+            available_tools=formatted_tools, tool_choice=formatted_tool_choice, **kwargs
+        )
+
+    def with_structured_output(
+        self,
+        schema: Union[Dict, Type[BaseModel]],
+        *,
+        include_raw: bool = False,
+        **kwargs: Any,
+    ) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]:
+        if kwargs:
+            raise ValueError(f"Received unsupported arguments {kwargs}")
+        llm = self.bind_tools([schema], tool_choice="required")
+        if isinstance(schema, type) and issubclass(schema, BaseModel):
+            output_parser: OutputParserLike = PydanticToolsParser(
+                tools=[schema], first_tool_only=True
+            )
+        else:
+            key_name = convert_to_openai_tool(schema)["function"]["name"]
+            output_parser = JsonOutputKeyToolsParser(
+                key_name=key_name, first_tool_only=True
+            )
+
+        if include_raw:
+            parser_assign = RunnablePassthrough.assign(
+                parsed=itemgetter("raw") | output_parser, parsing_error=lambda _: None
+            )
+            parser_none = RunnablePassthrough.assign(parsed=lambda _: None)
+            parser_with_fallback = parser_assign.with_fallbacks(
+                [parser_none], exception_key="parsing_error"
+            )
+            return RunnableMap(raw=llm) | parser_with_fallback
+        else:
+            return llm | output_parser
+
     def _generate(
         self,
         messages: List[BaseMessage],
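One behavior worth calling out in review: with `include_raw=True`, the fallback wiring above returns a dict instead of raising on parse failures. A sketch, again using the notebook's `GetWeather` schema:

```python
structured_llm = llm.with_structured_output(GetWeather, include_raw=True)
out = structured_llm.invoke("what is the weather like in San Francisco")
# The RunnableMap/fallback combination yields a dict of the form:
# {
#     "raw": AIMessage(..., tool_calls=[...]),        # unparsed model output
#     "parsed": GetWeather(location="San Francisco"),
#     "parsing_error": None,  # populated instead of raising when parsing fails
# }
print(out["parsed"])
```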
@@ -262,6 +473,11 @@ class ChatEdenAI(BaseChatModel):
     ) -> ChatResult:
         """Call out to EdenAI's chat endpoint."""
         if self.streaming:
+            if "available_tools" in kwargs:
+                warnings.warn(
+                    "stream: Tool use is not yet supported in streaming mode."
+                )
+            else:
                 stream_iter = self._stream(
                     messages, stop=stop, run_manager=run_manager, **kwargs
                 )
@@ -273,6 +489,7 @@ class ChatEdenAI(BaseChatModel):
             "User-Agent": self.get_user_agent(),
         }
         formatted_data = _format_edenai_messages(messages=messages)
+
         payload: Dict[str, Any] = {
             "providers": self.provider,
             "max_tokens": self.max_tokens,
@@ -303,10 +520,18 @@ class ChatEdenAI(BaseChatModel):
             err_msg = provider_response.get("error", {}).get("message")
             raise Exception(err_msg)

+        tool_calls, invalid_tool_calls = _extract_tool_calls_from_edenai_response(
+            provider_response
+        )
+
         return ChatResult(
             generations=[
                 ChatGeneration(
-                    message=AIMessage(content=provider_response["generated_text"])
+                    message=AIMessage(
+                        content=provider_response["generated_text"] or "",
+                        tool_calls=tool_calls,
+                        invalid_tool_calls=invalid_tool_calls,
+                    )
                 )
             ],
             llm_output=data,
@@ -320,6 +545,11 @@ class ChatEdenAI(BaseChatModel):
         **kwargs: Any,
     ) -> ChatResult:
         if self.streaming:
+            if "available_tools" in kwargs:
+                warnings.warn(
+                    "stream: Tool use is not yet supported in streaming mode."
+                )
+            else:
                 stream_iter = self._astream(
                     messages, stop=stop, run_manager=run_manager, **kwargs
                 )
@@ -370,3 +600,27 @@ class ChatEdenAI(BaseChatModel):
             ],
             llm_output=data,
         )
+
+    def _stream_with_tools_as_generate(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]],
+        run_manager: Optional[CallbackManagerForLLMRun],
+        **kwargs: Any,
+    ) -> ChatGenerationChunk:
+        warnings.warn("stream: Tool use is not yet supported in streaming mode.")
+        result = self._generate(messages, stop=stop, run_manager=run_manager, **kwargs)
+        return _result_to_chunked_message(result)
+
+    async def _astream_with_tools_as_agenerate(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]],
+        run_manager: Optional[AsyncCallbackManagerForLLMRun],
+        **kwargs: Any,
+    ) -> ChatGenerationChunk:
+        warnings.warn("stream: Tool use is not yet supported in streaming mode.")
+        result = await self._agenerate(
+            messages, stop=stop, run_manager=run_manager, **kwargs
+        )
+        return _result_to_chunked_message(result)
Unit tests for the EdenAI chat model (file path not shown in this capture):

@@ -2,9 +2,15 @@
 from typing import List

 import pytest
-from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
+from langchain_core.messages import (
+    BaseMessage,
+    HumanMessage,
+    SystemMessage,
+    ToolMessage,
+)

 from langchain_community.chat_models.edenai import (
+    _extract_edenai_tool_results_from_messages,
     _format_edenai_messages,
     _message_role,
 )
@@ -22,6 +28,7 @@ from langchain_community.chat_models.edenai import (
                 "text": "Hello how are you today?",
                 "previous_history": [],
                 "chatbot_global_action": "Translate the text from English to French",
+                "tool_results": [],
             },
         )
     ],
@@ -38,3 +45,26 @@ def test_edenai_messages_formatting(messages: List[BaseMessage], expected: str)
 def test_edenai_message_role(role: str, role_response) -> None:  # type: ignore[no-untyped-def]
     role = _message_role(role)
     assert role == role_response
+
+
+def test_extract_edenai_tool_results_mixed_messages() -> None:
+    fake_other_msg = BaseMessage(content="content", type="other message")
+    messages = [
+        fake_other_msg,
+        ToolMessage(tool_call_id="id1", content="result1"),
+        fake_other_msg,
+        ToolMessage(tool_call_id="id2", content="result2"),
+        ToolMessage(tool_call_id="id3", content="result3"),
+    ]
+    expected_tool_results = [
+        {"id": "id2", "result": "result2"},
+        {"id": "id3", "result": "result3"},
+    ]
+    expected_other_messages = [
+        fake_other_msg,
+        ToolMessage(tool_call_id="id1", content="result1"),
+        fake_other_msg,
+    ]
+    tool_results, other_messages = _extract_edenai_tool_results_from_messages(messages)
+    assert tool_results == expected_tool_results
+    assert other_messages == expected_other_messages