From df84e1bb64d96377f909651f696f310c43c2f2c5 Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Wed, 19 Jul 2023 22:40:33 -0700 Subject: [PATCH] pass callbacks along baby ai (#7908) --- .../autonomous_agents/baby_agi/baby_agi.py | 30 ++++++++++++++----- 1 file changed, 23 insertions(+), 7 deletions(-) diff --git a/langchain/experimental/autonomous_agents/baby_agi/baby_agi.py b/langchain/experimental/autonomous_agents/baby_agi/baby_agi.py index 446a676cac8..ab11b2bd85a 100644 --- a/langchain/experimental/autonomous_agents/baby_agi/baby_agi.py +++ b/langchain/experimental/autonomous_agents/baby_agi/baby_agi.py @@ -60,7 +60,7 @@ class BabyAGI(Chain, BaseModel): return [] def get_next_task( - self, result: str, task_description: str, objective: str + self, result: str, task_description: str, objective: str, **kwargs: Any ) -> List[Dict]: """Get the next task.""" task_names = [t["task_name"] for t in self.task_list] @@ -71,13 +71,16 @@ class BabyAGI(Chain, BaseModel): task_description=task_description, incomplete_tasks=incomplete_tasks, objective=objective, + **kwargs, ) new_tasks = response.split("\n") return [ {"task_name": task_name} for task_name in new_tasks if task_name.strip() ] - def prioritize_tasks(self, this_task_id: int, objective: str) -> List[Dict]: + def prioritize_tasks( + self, this_task_id: int, objective: str, **kwargs: Any + ) -> List[Dict]: """Prioritize tasks.""" task_names = [t["task_name"] for t in list(self.task_list)] next_task_id = int(this_task_id) + 1 @@ -85,6 +88,7 @@ class BabyAGI(Chain, BaseModel): task_names=", ".join(task_names), next_task_id=str(next_task_id), objective=objective, + **kwargs, ) new_tasks = response.split("\n") prioritized_task_list = [] @@ -107,11 +111,11 @@ class BabyAGI(Chain, BaseModel): return [] return [str(item.metadata["task"]) for item in results] - def execute_task(self, objective: str, task: str, k: int = 5) -> str: + def execute_task(self, objective: str, task: str, k: int = 5, **kwargs: Any) -> str: """Execute a task.""" context = self._get_top_tasks(query=objective, k=k) return self.execution_chain.run( - objective=objective, context="\n".join(context), task=task + objective=objective, context="\n".join(context), task=task, **kwargs ) def _call( @@ -120,6 +124,7 @@ class BabyAGI(Chain, BaseModel): run_manager: Optional[CallbackManagerForChainRun] = None, ) -> Dict[str, Any]: """Run the agent.""" + _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() objective = inputs["objective"] first_task = inputs.get("first_task", "Make a todo list") self.add_task({"task_id": 1, "task_name": first_task}) @@ -133,7 +138,9 @@ class BabyAGI(Chain, BaseModel): self.print_next_task(task) # Step 2: Execute the task - result = self.execute_task(objective, task["task_name"]) + result = self.execute_task( + objective, task["task_name"], callbacks=_run_manager.get_child() + ) this_task_id = int(task["task_id"]) self.print_task_result(result) @@ -146,12 +153,21 @@ class BabyAGI(Chain, BaseModel): ) # Step 4: Create new tasks and reprioritize task list - new_tasks = self.get_next_task(result, task["task_name"], objective) + new_tasks = self.get_next_task( + result, + task["task_name"], + objective, + callbacks=_run_manager.get_child(), + ) for new_task in new_tasks: self.task_id_counter += 1 new_task.update({"task_id": self.task_id_counter}) self.add_task(new_task) - self.task_list = deque(self.prioritize_tasks(this_task_id, objective)) + self.task_list = deque( + self.prioritize_tasks( + this_task_id, objective, callbacks=_run_manager.get_child() + ) + ) num_iters += 1 if self.max_iterations is not None and num_iters == self.max_iterations: print(