From 0fdbaf4a8dd0923662122b9f80809fa2d37e4bc6 Mon Sep 17 00:00:00 2001 From: KyrianC <67210837+KyrianC@users.noreply.github.com> Date: Thu, 25 Jul 2024 21:19:14 +0200 Subject: [PATCH] community: fix ChatEdenAI + EdenAI Tools (#23715) Fixes for Eden AI Custom tools and ChatEdenAI: - add missing import in __init__ of chat_models - add `args_schema` to custom tools. otherwise '__arg1' would sometimes be passed to the `run` method - fix IndexError when no human msg is added in ChatEdenAI --- .../langchain_community/chat_models/__init__.py | 3 +++ .../langchain_community/chat_models/edenai.py | 4 ++-- .../tools/edenai/audio_speech_to_text.py | 10 +++++++--- .../tools/edenai/audio_text_to_speech.py | 9 +++++++-- .../tools/edenai/image_explicitcontent.py | 8 +++++++- .../tools/edenai/image_objectdetection.py | 8 +++++++- .../tools/edenai/ocr_identityparser.py | 8 +++++++- .../tools/edenai/ocr_invoiceparser.py | 9 +++++++-- .../tools/edenai/text_moderation.py | 9 +++++++-- .../tests/unit_tests/chat_models/test_imports.py | 1 + 10 files changed, 55 insertions(+), 14 deletions(-) diff --git a/libs/community/langchain_community/chat_models/__init__.py b/libs/community/langchain_community/chat_models/__init__.py index af25b60184d..3d0d47878f4 100644 --- a/libs/community/langchain_community/chat_models/__init__.py +++ b/libs/community/langchain_community/chat_models/__init__.py @@ -51,6 +51,7 @@ if TYPE_CHECKING: from langchain_community.chat_models.deepinfra import ( ChatDeepInfra, ) + from langchain_community.chat_models.edenai import ChatEdenAI from langchain_community.chat_models.ernie import ( ErnieBotChat, ) @@ -182,6 +183,7 @@ __all__ = [ "ChatOctoAI", "ChatDatabricks", "ChatDeepInfra", + "ChatEdenAI", "ChatEverlyAI", "ChatFireworks", "ChatFriendli", @@ -237,6 +239,7 @@ _module_lookup = { "ChatDatabricks": "langchain_community.chat_models.databricks", "ChatDeepInfra": "langchain_community.chat_models.deepinfra", "ChatEverlyAI": "langchain_community.chat_models.everlyai", + "ChatEdenAI": "langchain_community.chat_models.edenai", "ChatFireworks": "langchain_community.chat_models.fireworks", "ChatFriendli": "langchain_community.chat_models.friendli", "ChatGooglePalm": "langchain_community.chat_models.google_palm", diff --git a/libs/community/langchain_community/chat_models/edenai.py b/libs/community/langchain_community/chat_models/edenai.py index 3cf1f16eeaf..384e80d72d0 100644 --- a/libs/community/langchain_community/chat_models/edenai.py +++ b/libs/community/langchain_community/chat_models/edenai.py @@ -122,8 +122,8 @@ def _format_edenai_messages(messages: List[BaseMessage]) -> Dict[str, Any]: system = None formatted_messages = [] - human_messages = filter(lambda msg: isinstance(msg, HumanMessage), messages) - last_human_message = list(human_messages)[-1] if human_messages else "" + human_messages = list(filter(lambda msg: isinstance(msg, HumanMessage), messages)) + last_human_message = human_messages[-1] if human_messages else "" tool_results, other_messages = _extract_edenai_tool_results_from_messages(messages) for i, message in enumerate(other_messages): diff --git a/libs/community/langchain_community/tools/edenai/audio_speech_to_text.py b/libs/community/langchain_community/tools/edenai/audio_speech_to_text.py index ba2978399ad..dc772ece593 100644 --- a/libs/community/langchain_community/tools/edenai/audio_speech_to_text.py +++ b/libs/community/langchain_community/tools/edenai/audio_speech_to_text.py @@ -3,17 +3,21 @@ from __future__ import annotations import json import logging import time -from typing import List, Optional +from typing import List, Optional, Type import requests from langchain_core.callbacks import CallbackManagerForToolRun -from langchain_core.pydantic_v1 import validator +from langchain_core.pydantic_v1 import BaseModel, Field, HttpUrl, validator from langchain_community.tools.edenai.edenai_base_tool import EdenaiTool logger = logging.getLogger(__name__) +class SpeechToTextInput(BaseModel): + query: HttpUrl = Field(description="url of the audio to analyze") + + class EdenAiSpeechToTextTool(EdenaiTool): """Tool that queries the Eden AI Speech To Text API. @@ -23,7 +27,6 @@ class EdenAiSpeechToTextTool(EdenaiTool): To use, you should have the environment variable ``EDENAI_API_KEY`` set with your API token. You can find your token here: https://app.edenai.run/admin/account/settings - """ edenai_api_key: Optional[str] = None @@ -34,6 +37,7 @@ class EdenAiSpeechToTextTool(EdenaiTool): "Useful for when you have to convert audio to text." "Input should be a url to an audio file." ) + args_schema: Type[BaseModel] = SpeechToTextInput is_async: bool = True language: Optional[str] = "en" diff --git a/libs/community/langchain_community/tools/edenai/audio_text_to_speech.py b/libs/community/langchain_community/tools/edenai/audio_text_to_speech.py index 421d06f5b00..e78cf35419b 100644 --- a/libs/community/langchain_community/tools/edenai/audio_text_to_speech.py +++ b/libs/community/langchain_community/tools/edenai/audio_text_to_speech.py @@ -1,17 +1,21 @@ from __future__ import annotations import logging -from typing import Dict, List, Literal, Optional +from typing import Dict, List, Literal, Optional, Type import requests from langchain_core.callbacks import CallbackManagerForToolRun -from langchain_core.pydantic_v1 import Field, root_validator, validator +from langchain_core.pydantic_v1 import BaseModel, Field, root_validator, validator from langchain_community.tools.edenai.edenai_base_tool import EdenaiTool logger = logging.getLogger(__name__) +class TextToSpeechInput(BaseModel): + query: str = Field(description="text to generate audio from") + + class EdenAiTextToSpeechTool(EdenaiTool): """Tool that queries the Eden AI Text to speech API. for api reference check edenai documentation: @@ -30,6 +34,7 @@ class EdenAiTextToSpeechTool(EdenaiTool): """the output is a string representing the URL of the audio file, or the path to the downloaded wav file """ ) + args_schema: Type[BaseModel] = TextToSpeechInput language: Optional[str] = "en" """ diff --git a/libs/community/langchain_community/tools/edenai/image_explicitcontent.py b/libs/community/langchain_community/tools/edenai/image_explicitcontent.py index 3dac6622aa1..8ca1d7739cb 100644 --- a/libs/community/langchain_community/tools/edenai/image_explicitcontent.py +++ b/libs/community/langchain_community/tools/edenai/image_explicitcontent.py @@ -1,15 +1,20 @@ from __future__ import annotations import logging -from typing import Optional +from typing import Optional, Type from langchain_core.callbacks import CallbackManagerForToolRun +from langchain_core.pydantic_v1 import BaseModel, Field, HttpUrl from langchain_community.tools.edenai.edenai_base_tool import EdenaiTool logger = logging.getLogger(__name__) +class ExplicitImageInput(BaseModel): + query: HttpUrl = Field(description="url of the image to analyze") + + class EdenAiExplicitImageTool(EdenaiTool): """Tool that queries the Eden AI Explicit image detection. @@ -33,6 +38,7 @@ class EdenAiExplicitImageTool(EdenaiTool): pornography, violence, gore content, etc.""" "Input should be the string url of the image ." ) + args_schema: Type[BaseModel] = ExplicitImageInput combine_available: bool = True feature: str = "image" diff --git a/libs/community/langchain_community/tools/edenai/image_objectdetection.py b/libs/community/langchain_community/tools/edenai/image_objectdetection.py index 03b9fc36e58..1098e8e37f6 100644 --- a/libs/community/langchain_community/tools/edenai/image_objectdetection.py +++ b/libs/community/langchain_community/tools/edenai/image_objectdetection.py @@ -1,15 +1,20 @@ from __future__ import annotations import logging -from typing import Optional +from typing import Optional, Type from langchain_core.callbacks import CallbackManagerForToolRun +from langchain_core.pydantic_v1 import BaseModel, Field, HttpUrl from langchain_community.tools.edenai.edenai_base_tool import EdenaiTool logger = logging.getLogger(__name__) +class ObjectDetectionInput(BaseModel): + query: HttpUrl = Field(description="url of the image to analyze") + + class EdenAiObjectDetectionTool(EdenaiTool): """Tool that queries the Eden AI Object detection API. @@ -30,6 +35,7 @@ class EdenAiObjectDetectionTool(EdenaiTool): (with bounding boxes) objects in an image """ "Input should be the string url of the image to identify." ) + args_schema: Type[BaseModel] = ObjectDetectionInput show_positions: bool = False diff --git a/libs/community/langchain_community/tools/edenai/ocr_identityparser.py b/libs/community/langchain_community/tools/edenai/ocr_identityparser.py index 75352312e58..2e208dbb543 100644 --- a/libs/community/langchain_community/tools/edenai/ocr_identityparser.py +++ b/libs/community/langchain_community/tools/edenai/ocr_identityparser.py @@ -1,15 +1,20 @@ from __future__ import annotations import logging -from typing import Optional +from typing import Optional, Type from langchain_core.callbacks import CallbackManagerForToolRun +from langchain_core.pydantic_v1 import BaseModel, Field, HttpUrl from langchain_community.tools.edenai.edenai_base_tool import EdenaiTool logger = logging.getLogger(__name__) +class IDParsingInput(BaseModel): + query: HttpUrl = Field(description="url of the document to parse") + + class EdenAiParsingIDTool(EdenaiTool): """Tool that queries the Eden AI Identity parsing API. @@ -29,6 +34,7 @@ class EdenAiParsingIDTool(EdenaiTool): "Useful for when you have to extract information from an ID Document " "Input should be the string url of the document to parse." ) + args_schema: Type[BaseModel] = IDParsingInput feature: str = "ocr" subfeature: str = "identity_parser" diff --git a/libs/community/langchain_community/tools/edenai/ocr_invoiceparser.py b/libs/community/langchain_community/tools/edenai/ocr_invoiceparser.py index 4413beedf7b..75c8425154a 100644 --- a/libs/community/langchain_community/tools/edenai/ocr_invoiceparser.py +++ b/libs/community/langchain_community/tools/edenai/ocr_invoiceparser.py @@ -1,15 +1,20 @@ from __future__ import annotations import logging -from typing import Optional +from typing import Optional, Type from langchain_core.callbacks import CallbackManagerForToolRun +from langchain_core.pydantic_v1 import BaseModel, Field, HttpUrl from langchain_community.tools.edenai.edenai_base_tool import EdenaiTool logger = logging.getLogger(__name__) +class InvoiceParsingInput(BaseModel): + query: HttpUrl = Field(description="url of the document to parse") + + class EdenAiParsingInvoiceTool(EdenaiTool): """Tool that queries the Eden AI Invoice parsing API. @@ -23,7 +28,6 @@ class EdenAiParsingInvoiceTool(EdenaiTool): """ name: str = "edenai_invoice_parsing" - description: str = ( "A wrapper around edenai Services invoice parsing. " """Useful for when you have to extract information from @@ -33,6 +37,7 @@ class EdenAiParsingInvoiceTool(EdenaiTool): in a structured format to automate the invoice processing """ "Input should be the string url of the document to parse." ) + args_schema: Type[BaseModel] = InvoiceParsingInput language: Optional[str] = None """ diff --git a/libs/community/langchain_community/tools/edenai/text_moderation.py b/libs/community/langchain_community/tools/edenai/text_moderation.py index 2486287fba1..9aed36f0b73 100644 --- a/libs/community/langchain_community/tools/edenai/text_moderation.py +++ b/libs/community/langchain_community/tools/edenai/text_moderation.py @@ -1,15 +1,20 @@ from __future__ import annotations import logging -from typing import Optional +from typing import Optional, Type from langchain_core.callbacks import CallbackManagerForToolRun +from langchain_core.pydantic_v1 import BaseModel, Field from langchain_community.tools.edenai.edenai_base_tool import EdenaiTool logger = logging.getLogger(__name__) +class TextModerationInput(BaseModel): + query: str = Field(description="Text to moderate") + + class EdenAiTextModerationTool(EdenaiTool): """Tool that queries the Eden AI Explicit text detection. @@ -23,7 +28,6 @@ class EdenAiTextModerationTool(EdenaiTool): """ name: str = "edenai_explicit_content_detection_text" - description: str = ( "A wrapper around edenai Services explicit content detection for text. " """Useful for when you have to scan text for offensive, @@ -44,6 +48,7 @@ class EdenAiTextModerationTool(EdenaiTool): """ "Input should be a string." ) + args_schema: Type[BaseModel] = TextModerationInput language: str diff --git a/libs/community/tests/unit_tests/chat_models/test_imports.py b/libs/community/tests/unit_tests/chat_models/test_imports.py index 3c9b5e22547..4c46c7203f4 100644 --- a/libs/community/tests/unit_tests/chat_models/test_imports.py +++ b/libs/community/tests/unit_tests/chat_models/test_imports.py @@ -11,6 +11,7 @@ EXPECTED_ALL = [ "ChatDatabricks", "ChatDeepInfra", "ChatEverlyAI", + "ChatEdenAI", "ChatFireworks", "ChatFriendli", "ChatGooglePalm",