From ec4dab04498ed25f89a33d3710a3e42793caeb62 Mon Sep 17 00:00:00 2001 From: Massimiliano Pronesti Date: Fri, 12 Jan 2024 06:32:03 +0100 Subject: [PATCH] feat(community): make Amadeus toolkit LLM-agnostic (#15879) - **Description:** `AmadeusToolkit` and `AmadeusClosestAirport` contained a hardcoded call to `ChatOpenAI`. This PR makes it LLM-independent, while guaranteeing backward compatibility. - **Issue:** #15847 - **Dependencies:** None @baskaryan --- .../agent_toolkits/amadeus/toolkit.py | 6 ++++-- .../tools/amadeus/closest_airport.py | 17 ++++++++++++++--- 2 files changed, 18 insertions(+), 5 deletions(-) diff --git a/libs/community/langchain_community/agent_toolkits/amadeus/toolkit.py b/libs/community/langchain_community/agent_toolkits/amadeus/toolkit.py index 90bc5da6476..b4d59a96e92 100644 --- a/libs/community/langchain_community/agent_toolkits/amadeus/toolkit.py +++ b/libs/community/langchain_community/agent_toolkits/amadeus/toolkit.py @@ -1,7 +1,8 @@ from __future__ import annotations -from typing import TYPE_CHECKING, List +from typing import TYPE_CHECKING, List, Optional +from langchain_core.language_models import BaseLanguageModel from langchain_core.pydantic_v1 import Field from langchain_community.agent_toolkits.base import BaseToolkit @@ -18,6 +19,7 @@ class AmadeusToolkit(BaseToolkit): """Toolkit for interacting with Amadeus which offers APIs for travel.""" client: Client = Field(default_factory=authenticate) + llm: Optional[BaseLanguageModel] = Field(default=None) class Config: """Pydantic config.""" @@ -27,6 +29,6 @@ class AmadeusToolkit(BaseToolkit): def get_tools(self) -> List[BaseTool]: """Get the tools in the toolkit.""" return [ - AmadeusClosestAirport(), + AmadeusClosestAirport(llm=self.llm), AmadeusFlightSearch(), ] diff --git a/libs/community/langchain_community/tools/amadeus/closest_airport.py b/libs/community/langchain_community/tools/amadeus/closest_airport.py index 4e8b90a1b2a..4107cc500ba 100644 --- a/libs/community/langchain_community/tools/amadeus/closest_airport.py +++ b/libs/community/langchain_community/tools/amadeus/closest_airport.py @@ -1,7 +1,8 @@ -from typing import Optional, Type +from typing import Any, Dict, Optional, Type from langchain_core.callbacks import CallbackManagerForToolRun -from langchain_core.pydantic_v1 import BaseModel, Field +from langchain_core.language_models import BaseLanguageModel +from langchain_core.pydantic_v1 import BaseModel, Field, root_validator from langchain_community.chat_models import ChatOpenAI from langchain_community.tools.amadeus.base import AmadeusBaseTool @@ -35,6 +36,16 @@ class AmadeusClosestAirport(AmadeusBaseTool): ) args_schema: Type[ClosestAirportSchema] = ClosestAirportSchema + llm: Optional[BaseLanguageModel] = Field(default=None) + """Tool's llm used for calculating the closest airport. Defaults to `ChatOpenAI`.""" + + @root_validator(pre=True) + def set_llm(cls, values: Dict[str, Any]) -> Dict[str, Any]: + if not values.get("llm"): + # For backward-compatibility + values["llm"] = ChatOpenAI(temperature=0) + return values + def _run( self, location: str, @@ -47,4 +58,4 @@ class AmadeusClosestAirport(AmadeusBaseTool): ' Location Identifier" ' ) - return ChatOpenAI(temperature=0).predict(content) + return self.llm.predict(content)