Update OpenAI function agents' LLM validation (#13538)

- **Description:** This PR modifies the LLM validation in OpenAI
function agents to check whether the LLM supports OpenAI functions based
on a property (`supports_oia_functions`) instead of whether the LLM
passed to the agent `isinstance` of `ChatOpenAI`. This allows classes
that extend `BaseChatModel` to be passed to these agents as long as
they've been integrated with the OpenAI APIs and have this property set,
even if they don't extend `ChatOpenAI`.
  - **Issue:** N/A
  - **Dependencies:** none
This commit is contained in:
price-deshaw 2023-12-04 23:28:13 -05:00 committed by GitHub
parent 74c7b799ef
commit 5cb3393e20
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 0 additions and 19 deletions

View File

@@ -10,7 +10,6 @@ from langchain.agents.openai_functions_agent.agent_token_buffer_memory import (
     AgentTokenBufferMemory,
 )
 from langchain.agents.openai_functions_agent.base import OpenAIFunctionsAgent
-from langchain.chat_models.openai import ChatOpenAI
 from langchain.memory.token_buffer import ConversationTokenBufferMemory
 from langchain.tools.base import BaseTool
@@ -57,8 +56,6 @@ def create_conversational_retrieval_agent(
         An agent executor initialized appropriately
     """
-    if not isinstance(llm, ChatOpenAI):
-        raise ValueError("Only supported with ChatOpenAI models.")
     if remember_intermediate_steps:
         memory: BaseMemory = AgentTokenBufferMemory(
             memory_key=memory_key, llm=llm, max_token_limit=max_token_limit

View File

@@ -25,7 +25,6 @@ from langchain.agents.output_parsers.openai_functions import (
 )
 from langchain.callbacks.base import BaseCallbackManager
 from langchain.callbacks.manager import Callbacks
-from langchain.chat_models.openai import ChatOpenAI
 from langchain.tools.base import BaseTool
 from langchain.tools.render import format_tool_to_openai_function
@@ -50,12 +49,6 @@ class OpenAIFunctionsAgent(BaseSingleActionAgent):
         """Get allowed tools."""
         return [t.name for t in self.tools]

-    @root_validator
-    def validate_llm(cls, values: dict) -> dict:
-        if not isinstance(values["llm"], ChatOpenAI):
-            raise ValueError("Only supported with ChatOpenAI models.")
-        return values
-
     @root_validator
     def validate_prompt(cls, values: dict) -> dict:
         prompt: BasePromptTemplate = values["prompt"]
@@ -222,8 +215,6 @@ class OpenAIFunctionsAgent(BaseSingleActionAgent):
         **kwargs: Any,
     ) -> BaseSingleActionAgent:
         """Construct an agent from an LLM and tools."""
-        if not isinstance(llm, ChatOpenAI):
-            raise ValueError("Only supported with ChatOpenAI models.")
         prompt = cls.create_prompt(
             extra_prompt_messages=extra_prompt_messages,
             system_message=system_message,

View File

@@ -26,7 +26,6 @@ from langchain.agents.format_scratchpad.openai_functions import (
 )
 from langchain.callbacks.base import BaseCallbackManager
 from langchain.callbacks.manager import Callbacks
-from langchain.chat_models.openai import ChatOpenAI
 from langchain.tools import BaseTool

 # For backwards compatibility
@@ -109,12 +108,6 @@ class OpenAIMultiFunctionsAgent(BaseMultiActionAgent):
         """Get allowed tools."""
         return [t.name for t in self.tools]

-    @root_validator
-    def validate_llm(cls, values: dict) -> dict:
-        if not isinstance(values["llm"], ChatOpenAI):
-            raise ValueError("Only supported with ChatOpenAI models.")
-        return values
-
     @root_validator
     def validate_prompt(cls, values: dict) -> dict:
         prompt: BasePromptTemplate = values["prompt"]