mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-17 10:13:29 +00:00
pass callbacks along baby ai (#7908)
This commit is contained in:
parent
a4c5914c9a
commit
df84e1bb64
@ -60,7 +60,7 @@ class BabyAGI(Chain, BaseModel):
|
|||||||
return []
|
return []
|
||||||
|
|
||||||
def get_next_task(
|
def get_next_task(
|
||||||
self, result: str, task_description: str, objective: str
|
self, result: str, task_description: str, objective: str, **kwargs: Any
|
||||||
) -> List[Dict]:
|
) -> List[Dict]:
|
||||||
"""Get the next task."""
|
"""Get the next task."""
|
||||||
task_names = [t["task_name"] for t in self.task_list]
|
task_names = [t["task_name"] for t in self.task_list]
|
||||||
@ -71,13 +71,16 @@ class BabyAGI(Chain, BaseModel):
|
|||||||
task_description=task_description,
|
task_description=task_description,
|
||||||
incomplete_tasks=incomplete_tasks,
|
incomplete_tasks=incomplete_tasks,
|
||||||
objective=objective,
|
objective=objective,
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
new_tasks = response.split("\n")
|
new_tasks = response.split("\n")
|
||||||
return [
|
return [
|
||||||
{"task_name": task_name} for task_name in new_tasks if task_name.strip()
|
{"task_name": task_name} for task_name in new_tasks if task_name.strip()
|
||||||
]
|
]
|
||||||
|
|
||||||
def prioritize_tasks(self, this_task_id: int, objective: str) -> List[Dict]:
|
def prioritize_tasks(
|
||||||
|
self, this_task_id: int, objective: str, **kwargs: Any
|
||||||
|
) -> List[Dict]:
|
||||||
"""Prioritize tasks."""
|
"""Prioritize tasks."""
|
||||||
task_names = [t["task_name"] for t in list(self.task_list)]
|
task_names = [t["task_name"] for t in list(self.task_list)]
|
||||||
next_task_id = int(this_task_id) + 1
|
next_task_id = int(this_task_id) + 1
|
||||||
@ -85,6 +88,7 @@ class BabyAGI(Chain, BaseModel):
|
|||||||
task_names=", ".join(task_names),
|
task_names=", ".join(task_names),
|
||||||
next_task_id=str(next_task_id),
|
next_task_id=str(next_task_id),
|
||||||
objective=objective,
|
objective=objective,
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
new_tasks = response.split("\n")
|
new_tasks = response.split("\n")
|
||||||
prioritized_task_list = []
|
prioritized_task_list = []
|
||||||
@ -107,11 +111,11 @@ class BabyAGI(Chain, BaseModel):
|
|||||||
return []
|
return []
|
||||||
return [str(item.metadata["task"]) for item in results]
|
return [str(item.metadata["task"]) for item in results]
|
||||||
|
|
||||||
def execute_task(self, objective: str, task: str, k: int = 5) -> str:
|
def execute_task(self, objective: str, task: str, k: int = 5, **kwargs: Any) -> str:
|
||||||
"""Execute a task."""
|
"""Execute a task."""
|
||||||
context = self._get_top_tasks(query=objective, k=k)
|
context = self._get_top_tasks(query=objective, k=k)
|
||||||
return self.execution_chain.run(
|
return self.execution_chain.run(
|
||||||
objective=objective, context="\n".join(context), task=task
|
objective=objective, context="\n".join(context), task=task, **kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
def _call(
|
def _call(
|
||||||
@ -120,6 +124,7 @@ class BabyAGI(Chain, BaseModel):
|
|||||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""Run the agent."""
|
"""Run the agent."""
|
||||||
|
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
|
||||||
objective = inputs["objective"]
|
objective = inputs["objective"]
|
||||||
first_task = inputs.get("first_task", "Make a todo list")
|
first_task = inputs.get("first_task", "Make a todo list")
|
||||||
self.add_task({"task_id": 1, "task_name": first_task})
|
self.add_task({"task_id": 1, "task_name": first_task})
|
||||||
@ -133,7 +138,9 @@ class BabyAGI(Chain, BaseModel):
|
|||||||
self.print_next_task(task)
|
self.print_next_task(task)
|
||||||
|
|
||||||
# Step 2: Execute the task
|
# Step 2: Execute the task
|
||||||
result = self.execute_task(objective, task["task_name"])
|
result = self.execute_task(
|
||||||
|
objective, task["task_name"], callbacks=_run_manager.get_child()
|
||||||
|
)
|
||||||
this_task_id = int(task["task_id"])
|
this_task_id = int(task["task_id"])
|
||||||
self.print_task_result(result)
|
self.print_task_result(result)
|
||||||
|
|
||||||
@ -146,12 +153,21 @@ class BabyAGI(Chain, BaseModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Step 4: Create new tasks and reprioritize task list
|
# Step 4: Create new tasks and reprioritize task list
|
||||||
new_tasks = self.get_next_task(result, task["task_name"], objective)
|
new_tasks = self.get_next_task(
|
||||||
|
result,
|
||||||
|
task["task_name"],
|
||||||
|
objective,
|
||||||
|
callbacks=_run_manager.get_child(),
|
||||||
|
)
|
||||||
for new_task in new_tasks:
|
for new_task in new_tasks:
|
||||||
self.task_id_counter += 1
|
self.task_id_counter += 1
|
||||||
new_task.update({"task_id": self.task_id_counter})
|
new_task.update({"task_id": self.task_id_counter})
|
||||||
self.add_task(new_task)
|
self.add_task(new_task)
|
||||||
self.task_list = deque(self.prioritize_tasks(this_task_id, objective))
|
self.task_list = deque(
|
||||||
|
self.prioritize_tasks(
|
||||||
|
this_task_id, objective, callbacks=_run_manager.get_child()
|
||||||
|
)
|
||||||
|
)
|
||||||
num_iters += 1
|
num_iters += 1
|
||||||
if self.max_iterations is not None and num_iters == self.max_iterations:
|
if self.max_iterations is not None and num_iters == self.max_iterations:
|
||||||
print(
|
print(
|
||||||
|
Loading…
Reference in New Issue
Block a user