Mirror of https://github.com/hwchase17/langchain.git, synced 2025-07-04 04:07:54 +00:00.
update OpenAI function agents' llm validation (#13538)
- **Description:** This PR modifies the LLM validation in OpenAI function agents to check whether the LLM supports OpenAI functions based on a property (`supports_oia_functions`) instead of whether the LLM passed to the agent `isinstance` of `ChatOpenAI`. This allows classes that extend `BaseChatModel` to be passed to these agents as long as they've been integrated with the OpenAI APIs and have this property set, even if they don't extend `ChatOpenAI`. - **Issue:** N/A - **Dependencies:** none
This commit is contained in:
parent
74c7b799ef
commit
5cb3393e20
@@ -10,7 +10,6 @@ from langchain.agents.openai_functions_agent.agent_token_buffer_memory import (
     AgentTokenBufferMemory,
 )
 from langchain.agents.openai_functions_agent.base import OpenAIFunctionsAgent
-from langchain.chat_models.openai import ChatOpenAI
 from langchain.memory.token_buffer import ConversationTokenBufferMemory
 from langchain.tools.base import BaseTool
 
@@ -57,8 +56,6 @@ def create_conversational_retrieval_agent(
         An agent executor initialized appropriately
     """
 
-    if not isinstance(llm, ChatOpenAI):
-        raise ValueError("Only supported with ChatOpenAI models.")
     if remember_intermediate_steps:
         memory: BaseMemory = AgentTokenBufferMemory(
             memory_key=memory_key, llm=llm, max_token_limit=max_token_limit
@@ -25,7 +25,6 @@ from langchain.agents.output_parsers.openai_functions import (
 )
 from langchain.callbacks.base import BaseCallbackManager
 from langchain.callbacks.manager import Callbacks
-from langchain.chat_models.openai import ChatOpenAI
 from langchain.tools.base import BaseTool
 from langchain.tools.render import format_tool_to_openai_function
 
@@ -50,12 +49,6 @@ class OpenAIFunctionsAgent(BaseSingleActionAgent):
         """Get allowed tools."""
         return [t.name for t in self.tools]
 
-    @root_validator
-    def validate_llm(cls, values: dict) -> dict:
-        if not isinstance(values["llm"], ChatOpenAI):
-            raise ValueError("Only supported with ChatOpenAI models.")
-        return values
-
     @root_validator
     def validate_prompt(cls, values: dict) -> dict:
         prompt: BasePromptTemplate = values["prompt"]
 
@@ -222,8 +215,6 @@ class OpenAIFunctionsAgent(BaseSingleActionAgent):
         **kwargs: Any,
     ) -> BaseSingleActionAgent:
         """Construct an agent from an LLM and tools."""
-        if not isinstance(llm, ChatOpenAI):
-            raise ValueError("Only supported with ChatOpenAI models.")
         prompt = cls.create_prompt(
             extra_prompt_messages=extra_prompt_messages,
             system_message=system_message,
@@ -26,7 +26,6 @@ from langchain.agents.format_scratchpad.openai_functions import (
 )
 from langchain.callbacks.base import BaseCallbackManager
 from langchain.callbacks.manager import Callbacks
-from langchain.chat_models.openai import ChatOpenAI
 from langchain.tools import BaseTool
 
 # For backwards compatibility
@@ -109,12 +108,6 @@ class OpenAIMultiFunctionsAgent(BaseMultiActionAgent):
         """Get allowed tools."""
         return [t.name for t in self.tools]
 
-    @root_validator
-    def validate_llm(cls, values: dict) -> dict:
-        if not isinstance(values["llm"], ChatOpenAI):
-            raise ValueError("Only supported with ChatOpenAI models.")
-        return values
-
     @root_validator
     def validate_prompt(cls, values: dict) -> dict:
         prompt: BasePromptTemplate = values["prompt"]
Loading…
Reference in New Issue
Block a user