community: fix ChatEdenAI + EdenAI Tools (#23715)

Fixes for Eden AI Custom tools and ChatEdenAI:
- add missing import in __init__ of chat_models
- add `args_schema` to custom tools. otherwise '__arg1' would sometimes
be passed to the `run` method
- fix IndexError when no human msg is added in ChatEdenAI
This commit is contained in:
KyrianC 2024-07-25 21:19:14 +02:00 committed by GitHub
parent 871bf5a841
commit 0fdbaf4a8d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 55 additions and 14 deletions

View File

@ -51,6 +51,7 @@ if TYPE_CHECKING:
from langchain_community.chat_models.deepinfra import ( from langchain_community.chat_models.deepinfra import (
ChatDeepInfra, ChatDeepInfra,
) )
from langchain_community.chat_models.edenai import ChatEdenAI
from langchain_community.chat_models.ernie import ( from langchain_community.chat_models.ernie import (
ErnieBotChat, ErnieBotChat,
) )
@ -182,6 +183,7 @@ __all__ = [
"ChatOctoAI", "ChatOctoAI",
"ChatDatabricks", "ChatDatabricks",
"ChatDeepInfra", "ChatDeepInfra",
"ChatEdenAI",
"ChatEverlyAI", "ChatEverlyAI",
"ChatFireworks", "ChatFireworks",
"ChatFriendli", "ChatFriendli",
@ -237,6 +239,7 @@ _module_lookup = {
"ChatDatabricks": "langchain_community.chat_models.databricks", "ChatDatabricks": "langchain_community.chat_models.databricks",
"ChatDeepInfra": "langchain_community.chat_models.deepinfra", "ChatDeepInfra": "langchain_community.chat_models.deepinfra",
"ChatEverlyAI": "langchain_community.chat_models.everlyai", "ChatEverlyAI": "langchain_community.chat_models.everlyai",
"ChatEdenAI": "langchain_community.chat_models.edenai",
"ChatFireworks": "langchain_community.chat_models.fireworks", "ChatFireworks": "langchain_community.chat_models.fireworks",
"ChatFriendli": "langchain_community.chat_models.friendli", "ChatFriendli": "langchain_community.chat_models.friendli",
"ChatGooglePalm": "langchain_community.chat_models.google_palm", "ChatGooglePalm": "langchain_community.chat_models.google_palm",

View File

@ -122,8 +122,8 @@ def _format_edenai_messages(messages: List[BaseMessage]) -> Dict[str, Any]:
system = None system = None
formatted_messages = [] formatted_messages = []
human_messages = filter(lambda msg: isinstance(msg, HumanMessage), messages) human_messages = list(filter(lambda msg: isinstance(msg, HumanMessage), messages))
last_human_message = list(human_messages)[-1] if human_messages else "" last_human_message = human_messages[-1] if human_messages else ""
tool_results, other_messages = _extract_edenai_tool_results_from_messages(messages) tool_results, other_messages = _extract_edenai_tool_results_from_messages(messages)
for i, message in enumerate(other_messages): for i, message in enumerate(other_messages):

View File

@ -3,17 +3,21 @@ from __future__ import annotations
import json import json
import logging import logging
import time import time
from typing import List, Optional from typing import List, Optional, Type
import requests import requests
from langchain_core.callbacks import CallbackManagerForToolRun from langchain_core.callbacks import CallbackManagerForToolRun
from langchain_core.pydantic_v1 import validator from langchain_core.pydantic_v1 import BaseModel, Field, HttpUrl, validator
from langchain_community.tools.edenai.edenai_base_tool import EdenaiTool from langchain_community.tools.edenai.edenai_base_tool import EdenaiTool
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class SpeechToTextInput(BaseModel):
query: HttpUrl = Field(description="url of the audio to analyze")
class EdenAiSpeechToTextTool(EdenaiTool): class EdenAiSpeechToTextTool(EdenaiTool):
"""Tool that queries the Eden AI Speech To Text API. """Tool that queries the Eden AI Speech To Text API.
@ -23,7 +27,6 @@ class EdenAiSpeechToTextTool(EdenaiTool):
To use, you should have To use, you should have
the environment variable ``EDENAI_API_KEY`` set with your API token. the environment variable ``EDENAI_API_KEY`` set with your API token.
You can find your token here: https://app.edenai.run/admin/account/settings You can find your token here: https://app.edenai.run/admin/account/settings
""" """
edenai_api_key: Optional[str] = None edenai_api_key: Optional[str] = None
@ -34,6 +37,7 @@ class EdenAiSpeechToTextTool(EdenaiTool):
"Useful for when you have to convert audio to text." "Useful for when you have to convert audio to text."
"Input should be a url to an audio file." "Input should be a url to an audio file."
) )
args_schema: Type[BaseModel] = SpeechToTextInput
is_async: bool = True is_async: bool = True
language: Optional[str] = "en" language: Optional[str] = "en"

View File

@ -1,17 +1,21 @@
from __future__ import annotations from __future__ import annotations
import logging import logging
from typing import Dict, List, Literal, Optional from typing import Dict, List, Literal, Optional, Type
import requests import requests
from langchain_core.callbacks import CallbackManagerForToolRun from langchain_core.callbacks import CallbackManagerForToolRun
from langchain_core.pydantic_v1 import Field, root_validator, validator from langchain_core.pydantic_v1 import BaseModel, Field, root_validator, validator
from langchain_community.tools.edenai.edenai_base_tool import EdenaiTool from langchain_community.tools.edenai.edenai_base_tool import EdenaiTool
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class TextToSpeechInput(BaseModel):
query: str = Field(description="text to generate audio from")
class EdenAiTextToSpeechTool(EdenaiTool): class EdenAiTextToSpeechTool(EdenaiTool):
"""Tool that queries the Eden AI Text to speech API. """Tool that queries the Eden AI Text to speech API.
for api reference check edenai documentation: for api reference check edenai documentation:
@ -30,6 +34,7 @@ class EdenAiTextToSpeechTool(EdenaiTool):
"""the output is a string representing the URL of the audio file, """the output is a string representing the URL of the audio file,
or the path to the downloaded wav file """ or the path to the downloaded wav file """
) )
args_schema: Type[BaseModel] = TextToSpeechInput
language: Optional[str] = "en" language: Optional[str] = "en"
""" """

View File

@ -1,15 +1,20 @@
from __future__ import annotations from __future__ import annotations
import logging import logging
from typing import Optional from typing import Optional, Type
from langchain_core.callbacks import CallbackManagerForToolRun from langchain_core.callbacks import CallbackManagerForToolRun
from langchain_core.pydantic_v1 import BaseModel, Field, HttpUrl
from langchain_community.tools.edenai.edenai_base_tool import EdenaiTool from langchain_community.tools.edenai.edenai_base_tool import EdenaiTool
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class ExplicitImageInput(BaseModel):
query: HttpUrl = Field(description="url of the image to analyze")
class EdenAiExplicitImageTool(EdenaiTool): class EdenAiExplicitImageTool(EdenaiTool):
"""Tool that queries the Eden AI Explicit image detection. """Tool that queries the Eden AI Explicit image detection.
@ -33,6 +38,7 @@ class EdenAiExplicitImageTool(EdenaiTool):
pornography, violence, gore content, etc.""" pornography, violence, gore content, etc."""
"Input should be the string url of the image ." "Input should be the string url of the image ."
) )
args_schema: Type[BaseModel] = ExplicitImageInput
combine_available: bool = True combine_available: bool = True
feature: str = "image" feature: str = "image"

View File

@ -1,15 +1,20 @@
from __future__ import annotations from __future__ import annotations
import logging import logging
from typing import Optional from typing import Optional, Type
from langchain_core.callbacks import CallbackManagerForToolRun from langchain_core.callbacks import CallbackManagerForToolRun
from langchain_core.pydantic_v1 import BaseModel, Field, HttpUrl
from langchain_community.tools.edenai.edenai_base_tool import EdenaiTool from langchain_community.tools.edenai.edenai_base_tool import EdenaiTool
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class ObjectDetectionInput(BaseModel):
query: HttpUrl = Field(description="url of the image to analyze")
class EdenAiObjectDetectionTool(EdenaiTool): class EdenAiObjectDetectionTool(EdenaiTool):
"""Tool that queries the Eden AI Object detection API. """Tool that queries the Eden AI Object detection API.
@ -30,6 +35,7 @@ class EdenAiObjectDetectionTool(EdenaiTool):
(with bounding boxes) objects in an image """ (with bounding boxes) objects in an image """
"Input should be the string url of the image to identify." "Input should be the string url of the image to identify."
) )
args_schema: Type[BaseModel] = ObjectDetectionInput
show_positions: bool = False show_positions: bool = False

View File

@ -1,15 +1,20 @@
from __future__ import annotations from __future__ import annotations
import logging import logging
from typing import Optional from typing import Optional, Type
from langchain_core.callbacks import CallbackManagerForToolRun from langchain_core.callbacks import CallbackManagerForToolRun
from langchain_core.pydantic_v1 import BaseModel, Field, HttpUrl
from langchain_community.tools.edenai.edenai_base_tool import EdenaiTool from langchain_community.tools.edenai.edenai_base_tool import EdenaiTool
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class IDParsingInput(BaseModel):
query: HttpUrl = Field(description="url of the document to parse")
class EdenAiParsingIDTool(EdenaiTool): class EdenAiParsingIDTool(EdenaiTool):
"""Tool that queries the Eden AI Identity parsing API. """Tool that queries the Eden AI Identity parsing API.
@ -29,6 +34,7 @@ class EdenAiParsingIDTool(EdenaiTool):
"Useful for when you have to extract information from an ID Document " "Useful for when you have to extract information from an ID Document "
"Input should be the string url of the document to parse." "Input should be the string url of the document to parse."
) )
args_schema: Type[BaseModel] = IDParsingInput
feature: str = "ocr" feature: str = "ocr"
subfeature: str = "identity_parser" subfeature: str = "identity_parser"

View File

@ -1,15 +1,20 @@
from __future__ import annotations from __future__ import annotations
import logging import logging
from typing import Optional from typing import Optional, Type
from langchain_core.callbacks import CallbackManagerForToolRun from langchain_core.callbacks import CallbackManagerForToolRun
from langchain_core.pydantic_v1 import BaseModel, Field, HttpUrl
from langchain_community.tools.edenai.edenai_base_tool import EdenaiTool from langchain_community.tools.edenai.edenai_base_tool import EdenaiTool
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class InvoiceParsingInput(BaseModel):
query: HttpUrl = Field(description="url of the document to parse")
class EdenAiParsingInvoiceTool(EdenaiTool): class EdenAiParsingInvoiceTool(EdenaiTool):
"""Tool that queries the Eden AI Invoice parsing API. """Tool that queries the Eden AI Invoice parsing API.
@ -23,7 +28,6 @@ class EdenAiParsingInvoiceTool(EdenaiTool):
""" """
name: str = "edenai_invoice_parsing" name: str = "edenai_invoice_parsing"
description: str = ( description: str = (
"A wrapper around edenai Services invoice parsing. " "A wrapper around edenai Services invoice parsing. "
"""Useful for when you have to extract information from """Useful for when you have to extract information from
@ -33,6 +37,7 @@ class EdenAiParsingInvoiceTool(EdenaiTool):
in a structured format to automate the invoice processing """ in a structured format to automate the invoice processing """
"Input should be the string url of the document to parse." "Input should be the string url of the document to parse."
) )
args_schema: Type[BaseModel] = InvoiceParsingInput
language: Optional[str] = None language: Optional[str] = None
""" """

View File

@ -1,15 +1,20 @@
from __future__ import annotations from __future__ import annotations
import logging import logging
from typing import Optional from typing import Optional, Type
from langchain_core.callbacks import CallbackManagerForToolRun from langchain_core.callbacks import CallbackManagerForToolRun
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_community.tools.edenai.edenai_base_tool import EdenaiTool from langchain_community.tools.edenai.edenai_base_tool import EdenaiTool
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class TextModerationInput(BaseModel):
query: str = Field(description="Text to moderate")
class EdenAiTextModerationTool(EdenaiTool): class EdenAiTextModerationTool(EdenaiTool):
"""Tool that queries the Eden AI Explicit text detection. """Tool that queries the Eden AI Explicit text detection.
@ -23,7 +28,6 @@ class EdenAiTextModerationTool(EdenaiTool):
""" """
name: str = "edenai_explicit_content_detection_text" name: str = "edenai_explicit_content_detection_text"
description: str = ( description: str = (
"A wrapper around edenai Services explicit content detection for text. " "A wrapper around edenai Services explicit content detection for text. "
"""Useful for when you have to scan text for offensive, """Useful for when you have to scan text for offensive,
@ -44,6 +48,7 @@ class EdenAiTextModerationTool(EdenaiTool):
""" """
"Input should be a string." "Input should be a string."
) )
args_schema: Type[BaseModel] = TextModerationInput
language: str language: str

View File

@ -11,6 +11,7 @@ EXPECTED_ALL = [
"ChatDatabricks", "ChatDatabricks",
"ChatDeepInfra", "ChatDeepInfra",
"ChatEverlyAI", "ChatEverlyAI",
"ChatEdenAI",
"ChatFireworks", "ChatFireworks",
"ChatFriendli", "ChatFriendli",
"ChatGooglePalm", "ChatGooglePalm",