pass callbacks along baby ai (#7908)

This commit is contained in:
Harrison Chase 2023-07-19 22:40:33 -07:00 committed by GitHub
parent a4c5914c9a
commit df84e1bb64
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -60,7 +60,7 @@ class BabyAGI(Chain, BaseModel):
return [] return []
def get_next_task( def get_next_task(
self, result: str, task_description: str, objective: str self, result: str, task_description: str, objective: str, **kwargs: Any
) -> List[Dict]: ) -> List[Dict]:
"""Get the next task.""" """Get the next task."""
task_names = [t["task_name"] for t in self.task_list] task_names = [t["task_name"] for t in self.task_list]
@ -71,13 +71,16 @@ class BabyAGI(Chain, BaseModel):
task_description=task_description, task_description=task_description,
incomplete_tasks=incomplete_tasks, incomplete_tasks=incomplete_tasks,
objective=objective, objective=objective,
**kwargs,
) )
new_tasks = response.split("\n") new_tasks = response.split("\n")
return [ return [
{"task_name": task_name} for task_name in new_tasks if task_name.strip() {"task_name": task_name} for task_name in new_tasks if task_name.strip()
] ]
def prioritize_tasks(self, this_task_id: int, objective: str) -> List[Dict]: def prioritize_tasks(
self, this_task_id: int, objective: str, **kwargs: Any
) -> List[Dict]:
"""Prioritize tasks.""" """Prioritize tasks."""
task_names = [t["task_name"] for t in list(self.task_list)] task_names = [t["task_name"] for t in list(self.task_list)]
next_task_id = int(this_task_id) + 1 next_task_id = int(this_task_id) + 1
@ -85,6 +88,7 @@ class BabyAGI(Chain, BaseModel):
task_names=", ".join(task_names), task_names=", ".join(task_names),
next_task_id=str(next_task_id), next_task_id=str(next_task_id),
objective=objective, objective=objective,
**kwargs,
) )
new_tasks = response.split("\n") new_tasks = response.split("\n")
prioritized_task_list = [] prioritized_task_list = []
@ -107,11 +111,11 @@ class BabyAGI(Chain, BaseModel):
return [] return []
return [str(item.metadata["task"]) for item in results] return [str(item.metadata["task"]) for item in results]
def execute_task(self, objective: str, task: str, k: int = 5) -> str: def execute_task(self, objective: str, task: str, k: int = 5, **kwargs: Any) -> str:
"""Execute a task.""" """Execute a task."""
context = self._get_top_tasks(query=objective, k=k) context = self._get_top_tasks(query=objective, k=k)
return self.execution_chain.run( return self.execution_chain.run(
objective=objective, context="\n".join(context), task=task objective=objective, context="\n".join(context), task=task, **kwargs
) )
def _call( def _call(
@ -120,6 +124,7 @@ class BabyAGI(Chain, BaseModel):
run_manager: Optional[CallbackManagerForChainRun] = None, run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]: ) -> Dict[str, Any]:
"""Run the agent.""" """Run the agent."""
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
objective = inputs["objective"] objective = inputs["objective"]
first_task = inputs.get("first_task", "Make a todo list") first_task = inputs.get("first_task", "Make a todo list")
self.add_task({"task_id": 1, "task_name": first_task}) self.add_task({"task_id": 1, "task_name": first_task})
@ -133,7 +138,9 @@ class BabyAGI(Chain, BaseModel):
self.print_next_task(task) self.print_next_task(task)
# Step 2: Execute the task # Step 2: Execute the task
result = self.execute_task(objective, task["task_name"]) result = self.execute_task(
objective, task["task_name"], callbacks=_run_manager.get_child()
)
this_task_id = int(task["task_id"]) this_task_id = int(task["task_id"])
self.print_task_result(result) self.print_task_result(result)
@ -146,12 +153,21 @@ class BabyAGI(Chain, BaseModel):
) )
# Step 4: Create new tasks and reprioritize task list # Step 4: Create new tasks and reprioritize task list
new_tasks = self.get_next_task(result, task["task_name"], objective) new_tasks = self.get_next_task(
result,
task["task_name"],
objective,
callbacks=_run_manager.get_child(),
)
for new_task in new_tasks: for new_task in new_tasks:
self.task_id_counter += 1 self.task_id_counter += 1
new_task.update({"task_id": self.task_id_counter}) new_task.update({"task_id": self.task_id_counter})
self.add_task(new_task) self.add_task(new_task)
self.task_list = deque(self.prioritize_tasks(this_task_id, objective)) self.task_list = deque(
self.prioritize_tasks(
this_task_id, objective, callbacks=_run_manager.get_child()
)
)
num_iters += 1 num_iters += 1
if self.max_iterations is not None and num_iters == self.max_iterations: if self.max_iterations is not None and num_iters == self.max_iterations:
print( print(