mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-21 22:29:51 +00:00
community: fix ChatEdenAI + EdenAI Tools (#23715)
Fixes for Eden AI Custom tools and ChatEdenAI: - add missing import in __init__ of chat_models - add `args_schema` to custom tools. otherwise '__arg1' would sometimes be passed to the `run` method - fix IndexError when no human msg is added in ChatEdenAI
This commit is contained in:
parent
871bf5a841
commit
0fdbaf4a8d
@ -51,6 +51,7 @@ if TYPE_CHECKING:
|
|||||||
from langchain_community.chat_models.deepinfra import (
|
from langchain_community.chat_models.deepinfra import (
|
||||||
ChatDeepInfra,
|
ChatDeepInfra,
|
||||||
)
|
)
|
||||||
|
from langchain_community.chat_models.edenai import ChatEdenAI
|
||||||
from langchain_community.chat_models.ernie import (
|
from langchain_community.chat_models.ernie import (
|
||||||
ErnieBotChat,
|
ErnieBotChat,
|
||||||
)
|
)
|
||||||
@ -182,6 +183,7 @@ __all__ = [
|
|||||||
"ChatOctoAI",
|
"ChatOctoAI",
|
||||||
"ChatDatabricks",
|
"ChatDatabricks",
|
||||||
"ChatDeepInfra",
|
"ChatDeepInfra",
|
||||||
|
"ChatEdenAI",
|
||||||
"ChatEverlyAI",
|
"ChatEverlyAI",
|
||||||
"ChatFireworks",
|
"ChatFireworks",
|
||||||
"ChatFriendli",
|
"ChatFriendli",
|
||||||
@ -237,6 +239,7 @@ _module_lookup = {
|
|||||||
"ChatDatabricks": "langchain_community.chat_models.databricks",
|
"ChatDatabricks": "langchain_community.chat_models.databricks",
|
||||||
"ChatDeepInfra": "langchain_community.chat_models.deepinfra",
|
"ChatDeepInfra": "langchain_community.chat_models.deepinfra",
|
||||||
"ChatEverlyAI": "langchain_community.chat_models.everlyai",
|
"ChatEverlyAI": "langchain_community.chat_models.everlyai",
|
||||||
|
"ChatEdenAI": "langchain_community.chat_models.edenai",
|
||||||
"ChatFireworks": "langchain_community.chat_models.fireworks",
|
"ChatFireworks": "langchain_community.chat_models.fireworks",
|
||||||
"ChatFriendli": "langchain_community.chat_models.friendli",
|
"ChatFriendli": "langchain_community.chat_models.friendli",
|
||||||
"ChatGooglePalm": "langchain_community.chat_models.google_palm",
|
"ChatGooglePalm": "langchain_community.chat_models.google_palm",
|
||||||
|
@ -122,8 +122,8 @@ def _format_edenai_messages(messages: List[BaseMessage]) -> Dict[str, Any]:
|
|||||||
system = None
|
system = None
|
||||||
formatted_messages = []
|
formatted_messages = []
|
||||||
|
|
||||||
human_messages = filter(lambda msg: isinstance(msg, HumanMessage), messages)
|
human_messages = list(filter(lambda msg: isinstance(msg, HumanMessage), messages))
|
||||||
last_human_message = list(human_messages)[-1] if human_messages else ""
|
last_human_message = human_messages[-1] if human_messages else ""
|
||||||
|
|
||||||
tool_results, other_messages = _extract_edenai_tool_results_from_messages(messages)
|
tool_results, other_messages = _extract_edenai_tool_results_from_messages(messages)
|
||||||
for i, message in enumerate(other_messages):
|
for i, message in enumerate(other_messages):
|
||||||
|
@ -3,17 +3,21 @@ from __future__ import annotations
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
from typing import List, Optional
|
from typing import List, Optional, Type
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
from langchain_core.callbacks import CallbackManagerForToolRun
|
from langchain_core.callbacks import CallbackManagerForToolRun
|
||||||
from langchain_core.pydantic_v1 import validator
|
from langchain_core.pydantic_v1 import BaseModel, Field, HttpUrl, validator
|
||||||
|
|
||||||
from langchain_community.tools.edenai.edenai_base_tool import EdenaiTool
|
from langchain_community.tools.edenai.edenai_base_tool import EdenaiTool
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class SpeechToTextInput(BaseModel):
|
||||||
|
query: HttpUrl = Field(description="url of the audio to analyze")
|
||||||
|
|
||||||
|
|
||||||
class EdenAiSpeechToTextTool(EdenaiTool):
|
class EdenAiSpeechToTextTool(EdenaiTool):
|
||||||
"""Tool that queries the Eden AI Speech To Text API.
|
"""Tool that queries the Eden AI Speech To Text API.
|
||||||
|
|
||||||
@ -23,7 +27,6 @@ class EdenAiSpeechToTextTool(EdenaiTool):
|
|||||||
To use, you should have
|
To use, you should have
|
||||||
the environment variable ``EDENAI_API_KEY`` set with your API token.
|
the environment variable ``EDENAI_API_KEY`` set with your API token.
|
||||||
You can find your token here: https://app.edenai.run/admin/account/settings
|
You can find your token here: https://app.edenai.run/admin/account/settings
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
edenai_api_key: Optional[str] = None
|
edenai_api_key: Optional[str] = None
|
||||||
@ -34,6 +37,7 @@ class EdenAiSpeechToTextTool(EdenaiTool):
|
|||||||
"Useful for when you have to convert audio to text."
|
"Useful for when you have to convert audio to text."
|
||||||
"Input should be a url to an audio file."
|
"Input should be a url to an audio file."
|
||||||
)
|
)
|
||||||
|
args_schema: Type[BaseModel] = SpeechToTextInput
|
||||||
is_async: bool = True
|
is_async: bool = True
|
||||||
|
|
||||||
language: Optional[str] = "en"
|
language: Optional[str] = "en"
|
||||||
|
@ -1,17 +1,21 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Dict, List, Literal, Optional
|
from typing import Dict, List, Literal, Optional, Type
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
from langchain_core.callbacks import CallbackManagerForToolRun
|
from langchain_core.callbacks import CallbackManagerForToolRun
|
||||||
from langchain_core.pydantic_v1 import Field, root_validator, validator
|
from langchain_core.pydantic_v1 import BaseModel, Field, root_validator, validator
|
||||||
|
|
||||||
from langchain_community.tools.edenai.edenai_base_tool import EdenaiTool
|
from langchain_community.tools.edenai.edenai_base_tool import EdenaiTool
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class TextToSpeechInput(BaseModel):
|
||||||
|
query: str = Field(description="text to generate audio from")
|
||||||
|
|
||||||
|
|
||||||
class EdenAiTextToSpeechTool(EdenaiTool):
|
class EdenAiTextToSpeechTool(EdenaiTool):
|
||||||
"""Tool that queries the Eden AI Text to speech API.
|
"""Tool that queries the Eden AI Text to speech API.
|
||||||
for api reference check edenai documentation:
|
for api reference check edenai documentation:
|
||||||
@ -30,6 +34,7 @@ class EdenAiTextToSpeechTool(EdenaiTool):
|
|||||||
"""the output is a string representing the URL of the audio file,
|
"""the output is a string representing the URL of the audio file,
|
||||||
or the path to the downloaded wav file """
|
or the path to the downloaded wav file """
|
||||||
)
|
)
|
||||||
|
args_schema: Type[BaseModel] = TextToSpeechInput
|
||||||
|
|
||||||
language: Optional[str] = "en"
|
language: Optional[str] = "en"
|
||||||
"""
|
"""
|
||||||
|
@ -1,15 +1,20 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Optional
|
from typing import Optional, Type
|
||||||
|
|
||||||
from langchain_core.callbacks import CallbackManagerForToolRun
|
from langchain_core.callbacks import CallbackManagerForToolRun
|
||||||
|
from langchain_core.pydantic_v1 import BaseModel, Field, HttpUrl
|
||||||
|
|
||||||
from langchain_community.tools.edenai.edenai_base_tool import EdenaiTool
|
from langchain_community.tools.edenai.edenai_base_tool import EdenaiTool
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ExplicitImageInput(BaseModel):
|
||||||
|
query: HttpUrl = Field(description="url of the image to analyze")
|
||||||
|
|
||||||
|
|
||||||
class EdenAiExplicitImageTool(EdenaiTool):
|
class EdenAiExplicitImageTool(EdenaiTool):
|
||||||
"""Tool that queries the Eden AI Explicit image detection.
|
"""Tool that queries the Eden AI Explicit image detection.
|
||||||
|
|
||||||
@ -33,6 +38,7 @@ class EdenAiExplicitImageTool(EdenaiTool):
|
|||||||
pornography, violence, gore content, etc."""
|
pornography, violence, gore content, etc."""
|
||||||
"Input should be the string url of the image ."
|
"Input should be the string url of the image ."
|
||||||
)
|
)
|
||||||
|
args_schema: Type[BaseModel] = ExplicitImageInput
|
||||||
|
|
||||||
combine_available: bool = True
|
combine_available: bool = True
|
||||||
feature: str = "image"
|
feature: str = "image"
|
||||||
|
@ -1,15 +1,20 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Optional
|
from typing import Optional, Type
|
||||||
|
|
||||||
from langchain_core.callbacks import CallbackManagerForToolRun
|
from langchain_core.callbacks import CallbackManagerForToolRun
|
||||||
|
from langchain_core.pydantic_v1 import BaseModel, Field, HttpUrl
|
||||||
|
|
||||||
from langchain_community.tools.edenai.edenai_base_tool import EdenaiTool
|
from langchain_community.tools.edenai.edenai_base_tool import EdenaiTool
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ObjectDetectionInput(BaseModel):
|
||||||
|
query: HttpUrl = Field(description="url of the image to analyze")
|
||||||
|
|
||||||
|
|
||||||
class EdenAiObjectDetectionTool(EdenaiTool):
|
class EdenAiObjectDetectionTool(EdenaiTool):
|
||||||
"""Tool that queries the Eden AI Object detection API.
|
"""Tool that queries the Eden AI Object detection API.
|
||||||
|
|
||||||
@ -30,6 +35,7 @@ class EdenAiObjectDetectionTool(EdenaiTool):
|
|||||||
(with bounding boxes) objects in an image """
|
(with bounding boxes) objects in an image """
|
||||||
"Input should be the string url of the image to identify."
|
"Input should be the string url of the image to identify."
|
||||||
)
|
)
|
||||||
|
args_schema: Type[BaseModel] = ObjectDetectionInput
|
||||||
|
|
||||||
show_positions: bool = False
|
show_positions: bool = False
|
||||||
|
|
||||||
|
@ -1,15 +1,20 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Optional
|
from typing import Optional, Type
|
||||||
|
|
||||||
from langchain_core.callbacks import CallbackManagerForToolRun
|
from langchain_core.callbacks import CallbackManagerForToolRun
|
||||||
|
from langchain_core.pydantic_v1 import BaseModel, Field, HttpUrl
|
||||||
|
|
||||||
from langchain_community.tools.edenai.edenai_base_tool import EdenaiTool
|
from langchain_community.tools.edenai.edenai_base_tool import EdenaiTool
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class IDParsingInput(BaseModel):
|
||||||
|
query: HttpUrl = Field(description="url of the document to parse")
|
||||||
|
|
||||||
|
|
||||||
class EdenAiParsingIDTool(EdenaiTool):
|
class EdenAiParsingIDTool(EdenaiTool):
|
||||||
"""Tool that queries the Eden AI Identity parsing API.
|
"""Tool that queries the Eden AI Identity parsing API.
|
||||||
|
|
||||||
@ -29,6 +34,7 @@ class EdenAiParsingIDTool(EdenaiTool):
|
|||||||
"Useful for when you have to extract information from an ID Document "
|
"Useful for when you have to extract information from an ID Document "
|
||||||
"Input should be the string url of the document to parse."
|
"Input should be the string url of the document to parse."
|
||||||
)
|
)
|
||||||
|
args_schema: Type[BaseModel] = IDParsingInput
|
||||||
|
|
||||||
feature: str = "ocr"
|
feature: str = "ocr"
|
||||||
subfeature: str = "identity_parser"
|
subfeature: str = "identity_parser"
|
||||||
|
@ -1,15 +1,20 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Optional
|
from typing import Optional, Type
|
||||||
|
|
||||||
from langchain_core.callbacks import CallbackManagerForToolRun
|
from langchain_core.callbacks import CallbackManagerForToolRun
|
||||||
|
from langchain_core.pydantic_v1 import BaseModel, Field, HttpUrl
|
||||||
|
|
||||||
from langchain_community.tools.edenai.edenai_base_tool import EdenaiTool
|
from langchain_community.tools.edenai.edenai_base_tool import EdenaiTool
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class InvoiceParsingInput(BaseModel):
|
||||||
|
query: HttpUrl = Field(description="url of the document to parse")
|
||||||
|
|
||||||
|
|
||||||
class EdenAiParsingInvoiceTool(EdenaiTool):
|
class EdenAiParsingInvoiceTool(EdenaiTool):
|
||||||
"""Tool that queries the Eden AI Invoice parsing API.
|
"""Tool that queries the Eden AI Invoice parsing API.
|
||||||
|
|
||||||
@ -23,7 +28,6 @@ class EdenAiParsingInvoiceTool(EdenaiTool):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
name: str = "edenai_invoice_parsing"
|
name: str = "edenai_invoice_parsing"
|
||||||
|
|
||||||
description: str = (
|
description: str = (
|
||||||
"A wrapper around edenai Services invoice parsing. "
|
"A wrapper around edenai Services invoice parsing. "
|
||||||
"""Useful for when you have to extract information from
|
"""Useful for when you have to extract information from
|
||||||
@ -33,6 +37,7 @@ class EdenAiParsingInvoiceTool(EdenaiTool):
|
|||||||
in a structured format to automate the invoice processing """
|
in a structured format to automate the invoice processing """
|
||||||
"Input should be the string url of the document to parse."
|
"Input should be the string url of the document to parse."
|
||||||
)
|
)
|
||||||
|
args_schema: Type[BaseModel] = InvoiceParsingInput
|
||||||
|
|
||||||
language: Optional[str] = None
|
language: Optional[str] = None
|
||||||
"""
|
"""
|
||||||
|
@ -1,15 +1,20 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Optional
|
from typing import Optional, Type
|
||||||
|
|
||||||
from langchain_core.callbacks import CallbackManagerForToolRun
|
from langchain_core.callbacks import CallbackManagerForToolRun
|
||||||
|
from langchain_core.pydantic_v1 import BaseModel, Field
|
||||||
|
|
||||||
from langchain_community.tools.edenai.edenai_base_tool import EdenaiTool
|
from langchain_community.tools.edenai.edenai_base_tool import EdenaiTool
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class TextModerationInput(BaseModel):
|
||||||
|
query: str = Field(description="Text to moderate")
|
||||||
|
|
||||||
|
|
||||||
class EdenAiTextModerationTool(EdenaiTool):
|
class EdenAiTextModerationTool(EdenaiTool):
|
||||||
"""Tool that queries the Eden AI Explicit text detection.
|
"""Tool that queries the Eden AI Explicit text detection.
|
||||||
|
|
||||||
@ -23,7 +28,6 @@ class EdenAiTextModerationTool(EdenaiTool):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
name: str = "edenai_explicit_content_detection_text"
|
name: str = "edenai_explicit_content_detection_text"
|
||||||
|
|
||||||
description: str = (
|
description: str = (
|
||||||
"A wrapper around edenai Services explicit content detection for text. "
|
"A wrapper around edenai Services explicit content detection for text. "
|
||||||
"""Useful for when you have to scan text for offensive,
|
"""Useful for when you have to scan text for offensive,
|
||||||
@ -44,6 +48,7 @@ class EdenAiTextModerationTool(EdenaiTool):
|
|||||||
"""
|
"""
|
||||||
"Input should be a string."
|
"Input should be a string."
|
||||||
)
|
)
|
||||||
|
args_schema: Type[BaseModel] = TextModerationInput
|
||||||
|
|
||||||
language: str
|
language: str
|
||||||
|
|
||||||
|
@ -11,6 +11,7 @@ EXPECTED_ALL = [
|
|||||||
"ChatDatabricks",
|
"ChatDatabricks",
|
||||||
"ChatDeepInfra",
|
"ChatDeepInfra",
|
||||||
"ChatEverlyAI",
|
"ChatEverlyAI",
|
||||||
|
"ChatEdenAI",
|
||||||
"ChatFireworks",
|
"ChatFireworks",
|
||||||
"ChatFriendli",
|
"ChatFriendli",
|
||||||
"ChatGooglePalm",
|
"ChatGooglePalm",
|
||||||
|
Loading…
Reference in New Issue
Block a user