Files
DB-GPT/dbgpt/agent/memory/default_gpts_memory.py
2024-02-20 14:28:58 +08:00

126 lines
4.6 KiB
Python

from dataclasses import fields
from typing import List, Optional
import pandas as pd
from dbgpt.agent.common.schema import Status
from .base import GptsMessage, GptsMessageMemory, GptsPlan, GptsPlansMemory
class DefaultGptsPlansMemory(GptsPlansMemory):
    """Plans memory backed by a pandas DataFrame held on the instance.

    The frame has one column per ``GptsPlan`` dataclass field; every plan is
    stored as one row and rebuilt with ``GptsPlan.from_dict`` when read.
    """

    def __init__(self):
        """Create an empty frame whose columns are the ``GptsPlan`` fields."""
        self.df = pd.DataFrame(columns=[field.name for field in fields(GptsPlan)])

    def _rows_to_plans(self, frame: pd.DataFrame) -> List[GptsPlan]:
        """Rebuild a ``GptsPlan`` from every row of ``frame``."""
        columns = frame.columns
        return [
            GptsPlan.from_dict(dict(zip(columns, row)))
            for row in frame.itertuples(index=False, name=None)
        ]

    def batch_save(self, plans: List[GptsPlan]):
        """Append ``plans`` to the stored rows.

        Args:
            plans: Plans to store; each one is serialized with ``to_dict``.
                An empty list leaves the memory untouched.
        """
        if not plans:
            # Nothing to add; skip concatenating a column-less empty frame.
            return
        new_rows = pd.DataFrame([item.to_dict() for item in plans])
        self.df = pd.concat([self.df, new_rows], ignore_index=True)

    def get_by_conv_id(self, conv_id: str) -> List[GptsPlan]:
        """Return every stored plan whose ``conv_id`` equals ``conv_id``."""
        # `@name` inside a query expression refers to the local variable.
        result = self.df.query("conv_id==@conv_id")
        return self._rows_to_plans(result)

    def get_by_conv_id_and_num(
        self, conv_id: str, task_nums: List[int]
    ) -> List[GptsPlan]:
        """Return the plans of a conversation with the given sub task numbers.

        Args:
            conv_id: Conversation identifier to match.
            task_nums: Sub task numbers to select; each is coerced with
                ``int`` before being compared against ``sub_task_num``.
        """
        task_nums_int = [int(num) for num in task_nums]
        result = self.df.query("conv_id==@conv_id and sub_task_num in @task_nums_int")
        return self._rows_to_plans(result)

    def get_todo_plans(self, conv_id: str) -> List[GptsPlan]:
        """Return the plans of a conversation in ``TODO`` or ``RETRYING`` state."""
        todo_states = [Status.TODO.value, Status.RETRYING.value]
        result = self.df.query("conv_id==@conv_id and state in @todo_states")
        return self._rows_to_plans(result)

    def complete_task(self, conv_id: str, task_num: int, result: str):
        """Mark a sub task as complete and store its result.

        Args:
            conv_id: Conversation identifier of the plan.
            task_num: Value of ``sub_task_num`` identifying the plan.
            result: Result text written to the ``result`` column.
        """
        condition = (self.df["conv_id"] == conv_id) & (
            self.df["sub_task_num"] == task_num
        )
        self.df.loc[condition, "state"] = Status.COMPLETE.value
        self.df.loc[condition, "result"] = result

    def update_task(
        self,
        conv_id: str,
        task_num: int,
        state: str,
        retry_times: int,
        agent: Optional[str] = None,
        model=None,
        result: Optional[str] = None,
    ):
        """Update the state of the plan rows matching a sub task.

        Args:
            conv_id: Conversation identifier of the plan.
            task_num: Value of ``sub_task_num`` identifying the plan.
            state: New value for the ``state`` column.
            retry_times: New value for the ``retry_times`` column.
            agent: New ``sub_task_agent``; written only when truthy.
            model: New ``agent_model``; written only when truthy.
            result: New value for the ``result`` column. It is always
                written, so leaving it as ``None`` clears any stored result.
        """
        condition = (self.df["conv_id"] == conv_id) & (
            self.df["sub_task_num"] == task_num
        )
        self.df.loc[condition, "state"] = state
        self.df.loc[condition, "retry_times"] = retry_times
        self.df.loc[condition, "result"] = result
        if agent:
            self.df.loc[condition, "sub_task_agent"] = agent
        if model:
            self.df.loc[condition, "agent_model"] = model

    def remove_by_conv_id(self, conv_id: str):
        """Delete every stored plan whose ``conv_id`` equals ``conv_id``."""
        self.df.drop(self.df[self.df["conv_id"] == conv_id].index, inplace=True)
class DefaultGptsMessageMemory(GptsMessageMemory):
    """Message memory backed by a pandas DataFrame held on the instance.

    The frame has one column per ``GptsMessage`` dataclass field; every
    message is stored as one row and rebuilt with ``GptsMessage.from_dict``
    when read.
    """

    def __init__(self):
        """Create an empty frame whose columns are the ``GptsMessage`` fields."""
        self.df = pd.DataFrame(columns=[field.name for field in fields(GptsMessage)])

    def _rows_to_messages(self, frame: pd.DataFrame) -> List[GptsMessage]:
        """Rebuild a ``GptsMessage`` from every row of ``frame``."""
        columns = frame.columns
        return [
            GptsMessage.from_dict(dict(zip(columns, row)))
            for row in frame.itertuples(index=False, name=None)
        ]

    def append(self, message: GptsMessage):
        """Store ``message`` as a new row labelled with the current row count."""
        self.df.loc[len(self.df)] = message.to_dict()

    def get_by_agent(self, conv_id: str, agent: str) -> Optional[List[GptsMessage]]:
        """Return the messages of a conversation sent or received by ``agent``.

        Args:
            conv_id: Conversation identifier to match.
            agent: Name compared against both ``sender`` and ``receiver``.
        """
        # `@name` inside a query expression refers to the local variable.
        result = self.df.query(
            "conv_id==@conv_id and (sender==@agent or receiver==@agent)"
        )
        return self._rows_to_messages(result)

    def get_between_agents(
        self,
        conv_id: str,
        agent1: str,
        agent2: str,
        current_goal: Optional[str] = None,
    ) -> Optional[List[GptsMessage]]:
        """Return the messages exchanged between two agents, in either direction.

        Args:
            conv_id: Conversation identifier to match.
            agent1: One side of the exchange.
            agent2: The other side of the exchange.
            current_goal: When truthy, additionally require the row's
                ``current_goal`` to equal this value.
        """
        expression = (
            "conv_id==@conv_id and ((sender==@agent1 and receiver==@agent2)"
            " or (sender==@agent2 and receiver==@agent1))"
        )
        if current_goal:
            expression += " and current_goal==@current_goal"
        result = self.df.query(expression)
        return self._rows_to_messages(result)

    def get_by_conv_id(self, conv_id: str) -> Optional[List[GptsMessage]]:
        """Return every stored message whose ``conv_id`` equals ``conv_id``."""
        result = self.df.query("conv_id==@conv_id")
        return self._rows_to_messages(result)