Compare commits

...

1 Commits

Author SHA1 Message Date
Bagatur
d9d8cee8b6 rfc: agent initial tool choices 2024-01-17 19:52:27 -08:00
2 changed files with 28 additions and 4 deletions

View File

@@ -925,6 +925,7 @@ class AgentExecutor(Chain):
trim_intermediate_steps: Union[
int, Callable[[List[Tuple[AgentAction, str]]], List[Tuple[AgentAction, str]]]
] = -1
initial_tool_choices: Sequence[Union[str, BaseTool]] = ()
@classmethod
def from_agent_and_tools(
@@ -954,6 +955,10 @@ class AgentExecutor(Chain):
f"Allowed tools ({allowed_tools}) different than "
f"provided tools ({[tool.name for tool in tools]})"
)
if values["max_iterations"] is not None:
values["max_iterations"] = max(
values["max_iterations"], len(values["initial_tool_choices"])
)
return values
@root_validator()
@@ -1373,10 +1378,19 @@ class AgentExecutor(Chain):
start_time = time.time()
# We now enter the agent loop (until it returns something).
while self._should_continue(iterations, time_elapsed):
curr_inputs = inputs.copy()
if len(self.initial_tool_choices) > iterations:
tool_choice = self.initial_tool_choices[iterations]
tool_choice = (
tool_choice.name
if isinstance(tool_choice, BaseTool)
else tool_choice
)
curr_inputs["tool_choice"] = tool_choice
next_step_output = self._take_next_step(
name_to_tool_map,
color_mapping,
inputs,
curr_inputs,
intermediate_steps,
run_manager=run_manager,
)

View File

@@ -18,7 +18,7 @@ from langchain_core.prompts.chat import (
MessagesPlaceholder,
)
from langchain_core.pydantic_v1 import root_validator
from langchain_core.runnables import Runnable, RunnablePassthrough
from langchain_core.runnables import Runnable, RunnableLambda, RunnablePassthrough
from langchain_core.tools import BaseTool
from langchain.agents import BaseSingleActionAgent
@@ -229,6 +229,15 @@ class OpenAIFunctionsAgent(BaseSingleActionAgent):
)
def _configure_prompt_llm(
    inputs: dict, prompt: BasePromptTemplate, llm_with_tools: Runnable
) -> Runnable:
    """Assemble the ``prompt | llm`` pipeline, forcing a tool when requested.

    If ``inputs`` carries a truthy ``tool_choice``, bind it onto the LLM as an
    OpenAI ``function_call`` directive so the model is required to invoke that
    function. The key is popped either way so it never leaks into prompt
    formatting.

    Args:
        inputs: The runtime input dict; ``tool_choice`` (if present) is removed.
        prompt: The prompt template to place in front of the LLM.
        llm_with_tools: The LLM runnable with tool/function definitions bound.

    Returns:
        A runnable equivalent to ``prompt | llm_with_tools``, with the forced
        ``function_call`` bound when a tool choice was supplied.
    """
    # Default to None: the executor only injects "tool_choice" for the first
    # len(initial_tool_choices) iterations, so a bare pop() would KeyError on
    # every later step.
    tool_choice = inputs.pop("tool_choice", None)
    if tool_choice:
        # Runnable.bind returns a NEW runnable — the original discarded the
        # result, so the forced choice was never applied. Also, the OpenAI
        # parameter is "function_call", not "functional_call".
        llm_with_tools = llm_with_tools.bind(function_call={"name": tool_choice})
    return prompt | llm_with_tools
def create_openai_functions_agent(
llm: BaseLanguageModel, tools: Sequence[BaseTool], prompt: ChatPromptTemplate
) -> Runnable:
@@ -307,8 +316,9 @@ def create_openai_functions_agent(
x["intermediate_steps"]
)
)
| prompt
| llm_with_tools
| RunnableLambda(_configure_prompt_llm).bind(
prompt=prompt, llm_with_tools=llm_with_tools
)
| OpenAIFunctionsAgentOutputParser()
)
return agent