mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-04 12:18:24 +00:00
update OpenAI function agents' llm validation (#13538)
- **Description:** This PR modifies the LLM validation in OpenAI function agents to check whether the LLM supports OpenAI functions based on a property (`supports_oia_functions`) instead of whether the LLM passed to the agent `isinstance` of `ChatOpenAI`. This allows classes that extend `BaseChatModel` to be passed to these agents as long as they've been integrated with the OpenAI APIs and have this property set, even if they don't extend `ChatOpenAI`. - **Issue:** N/A - **Dependencies:** none
This commit is contained in:
parent
74c7b799ef
commit
5cb3393e20
@ -10,7 +10,6 @@ from langchain.agents.openai_functions_agent.agent_token_buffer_memory import (
|
|||||||
AgentTokenBufferMemory,
|
AgentTokenBufferMemory,
|
||||||
)
|
)
|
||||||
from langchain.agents.openai_functions_agent.base import OpenAIFunctionsAgent
|
from langchain.agents.openai_functions_agent.base import OpenAIFunctionsAgent
|
||||||
from langchain.chat_models.openai import ChatOpenAI
|
|
||||||
from langchain.memory.token_buffer import ConversationTokenBufferMemory
|
from langchain.memory.token_buffer import ConversationTokenBufferMemory
|
||||||
from langchain.tools.base import BaseTool
|
from langchain.tools.base import BaseTool
|
||||||
|
|
||||||
@ -57,8 +56,6 @@ def create_conversational_retrieval_agent(
|
|||||||
An agent executor initialized appropriately
|
An agent executor initialized appropriately
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if not isinstance(llm, ChatOpenAI):
|
|
||||||
raise ValueError("Only supported with ChatOpenAI models.")
|
|
||||||
if remember_intermediate_steps:
|
if remember_intermediate_steps:
|
||||||
memory: BaseMemory = AgentTokenBufferMemory(
|
memory: BaseMemory = AgentTokenBufferMemory(
|
||||||
memory_key=memory_key, llm=llm, max_token_limit=max_token_limit
|
memory_key=memory_key, llm=llm, max_token_limit=max_token_limit
|
||||||
|
@ -25,7 +25,6 @@ from langchain.agents.output_parsers.openai_functions import (
|
|||||||
)
|
)
|
||||||
from langchain.callbacks.base import BaseCallbackManager
|
from langchain.callbacks.base import BaseCallbackManager
|
||||||
from langchain.callbacks.manager import Callbacks
|
from langchain.callbacks.manager import Callbacks
|
||||||
from langchain.chat_models.openai import ChatOpenAI
|
|
||||||
from langchain.tools.base import BaseTool
|
from langchain.tools.base import BaseTool
|
||||||
from langchain.tools.render import format_tool_to_openai_function
|
from langchain.tools.render import format_tool_to_openai_function
|
||||||
|
|
||||||
@ -50,12 +49,6 @@ class OpenAIFunctionsAgent(BaseSingleActionAgent):
|
|||||||
"""Get allowed tools."""
|
"""Get allowed tools."""
|
||||||
return [t.name for t in self.tools]
|
return [t.name for t in self.tools]
|
||||||
|
|
||||||
@root_validator
|
|
||||||
def validate_llm(cls, values: dict) -> dict:
|
|
||||||
if not isinstance(values["llm"], ChatOpenAI):
|
|
||||||
raise ValueError("Only supported with ChatOpenAI models.")
|
|
||||||
return values
|
|
||||||
|
|
||||||
@root_validator
|
@root_validator
|
||||||
def validate_prompt(cls, values: dict) -> dict:
|
def validate_prompt(cls, values: dict) -> dict:
|
||||||
prompt: BasePromptTemplate = values["prompt"]
|
prompt: BasePromptTemplate = values["prompt"]
|
||||||
@ -222,8 +215,6 @@ class OpenAIFunctionsAgent(BaseSingleActionAgent):
|
|||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> BaseSingleActionAgent:
|
) -> BaseSingleActionAgent:
|
||||||
"""Construct an agent from an LLM and tools."""
|
"""Construct an agent from an LLM and tools."""
|
||||||
if not isinstance(llm, ChatOpenAI):
|
|
||||||
raise ValueError("Only supported with ChatOpenAI models.")
|
|
||||||
prompt = cls.create_prompt(
|
prompt = cls.create_prompt(
|
||||||
extra_prompt_messages=extra_prompt_messages,
|
extra_prompt_messages=extra_prompt_messages,
|
||||||
system_message=system_message,
|
system_message=system_message,
|
||||||
|
@ -26,7 +26,6 @@ from langchain.agents.format_scratchpad.openai_functions import (
|
|||||||
)
|
)
|
||||||
from langchain.callbacks.base import BaseCallbackManager
|
from langchain.callbacks.base import BaseCallbackManager
|
||||||
from langchain.callbacks.manager import Callbacks
|
from langchain.callbacks.manager import Callbacks
|
||||||
from langchain.chat_models.openai import ChatOpenAI
|
|
||||||
from langchain.tools import BaseTool
|
from langchain.tools import BaseTool
|
||||||
|
|
||||||
# For backwards compatibility
|
# For backwards compatibility
|
||||||
@ -109,12 +108,6 @@ class OpenAIMultiFunctionsAgent(BaseMultiActionAgent):
|
|||||||
"""Get allowed tools."""
|
"""Get allowed tools."""
|
||||||
return [t.name for t in self.tools]
|
return [t.name for t in self.tools]
|
||||||
|
|
||||||
@root_validator
|
|
||||||
def validate_llm(cls, values: dict) -> dict:
|
|
||||||
if not isinstance(values["llm"], ChatOpenAI):
|
|
||||||
raise ValueError("Only supported with ChatOpenAI models.")
|
|
||||||
return values
|
|
||||||
|
|
||||||
@root_validator
|
@root_validator
|
||||||
def validate_prompt(cls, values: dict) -> dict:
|
def validate_prompt(cls, values: dict) -> dict:
|
||||||
prompt: BasePromptTemplate = values["prompt"]
|
prompt: BasePromptTemplate = values["prompt"]
|
||||||
|
Loading…
Reference in New Issue
Block a user