mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-08-02 08:40:36 +00:00
Co-authored-by: 夏姜 <wenfengjiang.jwf@digital-engine.com> Co-authored-by: aries_ckt <916701291@qq.com> Co-authored-by: wb-lh513319 <wb-lh513319@alibaba-inc.com> Co-authored-by: csunny <cfqsunny@163.com>
123 lines
3.8 KiB
Python
123 lines
3.8 KiB
Python
from typing import List, Optional
|
|
|
|
from dbgpt.agent.core.memory.gpts import (
|
|
GptsMessage,
|
|
GptsMessageMemory,
|
|
GptsPlan,
|
|
GptsPlansMemory,
|
|
)
|
|
|
|
from ..db.gpts_messages_db import GptsMessagesDao
|
|
from ..db.gpts_plans_db import GptsPlansDao, GptsPlansEntity
|
|
|
|
|
|
class MetaDbGptsPlansMemory(GptsPlansMemory):
    """``GptsPlansMemory`` implementation that delegates to ``GptsPlansDao``.

    Every method forwards its arguments to the DAO held in ``self.gpts_plan``.
    Query methods additionally convert each returned ``GptsPlansEntity`` into a
    ``GptsPlan`` via ``GptsPlan.from_dict``.
    """

    def __init__(self):
        """Create the DAO instance used by all plan operations."""
        self.gpts_plan = GptsPlansDao()

    def batch_save(self, plans: List[GptsPlan]):
        """Save several plans in one DAO call.

        Args:
            plans: Plans to store; each one is passed on as ``plan.to_dict()``.
        """
        self.gpts_plan.batch_save([plan.to_dict() for plan in plans])

    def get_by_conv_id(self, conv_id: str) -> List[GptsPlan]:
        """Return the plans the DAO finds for ``conv_id``.

        Args:
            conv_id: Conversation identifier forwarded to the DAO.

        Returns:
            One ``GptsPlan`` per entity returned by the DAO, in DAO order.
        """
        entities: List[GptsPlansEntity] = self.gpts_plan.get_by_conv_id(
            conv_id=conv_id
        )
        # Each entity's instance ``__dict__`` is handed to ``from_dict`` as-is.
        return [GptsPlan.from_dict(entity.__dict__) for entity in entities]

    def get_by_conv_id_and_num(
        self, conv_id: str, task_nums: List[int]
    ) -> List[GptsPlan]:
        """Return the plans the DAO finds for ``conv_id`` and ``task_nums``.

        Args:
            conv_id: Conversation identifier forwarded to the DAO.
            task_nums: Task numbers forwarded to the DAO.

        Returns:
            One ``GptsPlan`` per entity returned by the DAO, in DAO order.
        """
        entities: List[GptsPlansEntity] = self.gpts_plan.get_by_conv_id_and_num(
            conv_id=conv_id, task_nums=task_nums
        )
        return [GptsPlan.from_dict(entity.__dict__) for entity in entities]

    def get_todo_plans(self, conv_id: str) -> List[GptsPlan]:
        """Return what the DAO's ``get_todo_plans`` yields for ``conv_id``.

        Args:
            conv_id: Conversation identifier forwarded to the DAO.

        Returns:
            One ``GptsPlan`` per entity returned by the DAO, in DAO order.
        """
        entities: List[GptsPlansEntity] = self.gpts_plan.get_todo_plans(
            conv_id=conv_id
        )
        return [GptsPlan.from_dict(entity.__dict__) for entity in entities]

    def complete_task(self, conv_id: str, task_num: int, result: str):
        """Forward a task completion to the DAO.

        Args:
            conv_id: Conversation identifier.
            task_num: Number of the task within the conversation.
            result: Result text passed to the DAO.
        """
        self.gpts_plan.complete_task(
            conv_id=conv_id,
            task_num=task_num,
            result=result,
        )

    def update_task(
        self,
        conv_id: str,
        task_num: int,
        state: str,
        retry_times: int,
        agent: str = None,
        model: str = None,
        result: str = None,
    ):
        """Forward a task update to the DAO.

        All arguments are passed through unchanged as keyword arguments.

        Args:
            conv_id: Conversation identifier.
            task_num: Number of the task within the conversation.
            state: State value passed to the DAO.
            retry_times: Retry count passed to the DAO.
            agent: Optional agent value; defaults to ``None``.
            model: Optional model value; defaults to ``None``.
            result: Optional result text; defaults to ``None``.
        """
        self.gpts_plan.update_task(
            conv_id=conv_id,
            task_num=task_num,
            state=state,
            retry_times=retry_times,
            agent=agent,
            model=model,
            result=result,
        )

    def remove_by_conv_id(self, conv_id: str):
        """Forward a removal request for ``conv_id`` to the DAO.

        Args:
            conv_id: Conversation identifier forwarded to the DAO.
        """
        self.gpts_plan.remove_by_conv_id(conv_id=conv_id)
|
|
|
|
|
|
class MetaDbGptsMessageMemory(GptsMessageMemory):
    """``GptsMessageMemory`` implementation that delegates to ``GptsMessagesDao``.

    Every method forwards its arguments to the DAO held in
    ``self.gpts_message``. List queries sort the returned rows by their
    ``rounds`` attribute (ascending) before converting each one into a
    ``GptsMessage`` via ``GptsMessage.from_dict``.
    """

    def __init__(self):
        """Create the DAO instance used by all message operations."""
        self.gpts_message = GptsMessagesDao()

    def append(self, message: GptsMessage):
        """Pass one message to the DAO's ``append``.

        Args:
            message: Message to store; passed on as ``message.to_dict()``.
        """
        self.gpts_message.append(message.to_dict())

    def get_by_agent(self, conv_id: str, agent: str) -> Optional[List[GptsMessage]]:
        """Return the DAO's messages for ``conv_id`` and ``agent``.

        Args:
            conv_id: Conversation identifier forwarded to the DAO.
            agent: Agent value forwarded to the DAO.

        Returns:
            Messages sorted by ``rounds``; an empty list if the DAO returns
            no rows.
        """
        rows = self.gpts_message.get_by_agent(conv_id, agent)
        ordered = sorted(rows, key=lambda row: row.rounds)
        return [GptsMessage.from_dict(row.__dict__) for row in ordered]

    def get_between_agents(
        self,
        conv_id: str,
        agent1: str,
        agent2: str,
        current_goal: Optional[str] = None,
    ) -> Optional[List[GptsMessage]]:
        """Return the DAO's messages for a pair of agents.

        Args:
            conv_id: Conversation identifier forwarded to the DAO.
            agent1: First agent value forwarded to the DAO.
            agent2: Second agent value forwarded to the DAO.
            current_goal: Optional goal value forwarded to the DAO.

        Returns:
            Messages sorted by ``rounds``; an empty list if the DAO returns
            no rows.
        """
        rows = self.gpts_message.get_between_agents(
            conv_id, agent1, agent2, current_goal
        )
        ordered = sorted(rows, key=lambda row: row.rounds)
        return [GptsMessage.from_dict(row.__dict__) for row in ordered]

    def get_by_conv_id(self, conv_id: str) -> Optional[List[GptsMessage]]:
        """Return the DAO's messages for ``conv_id``.

        Args:
            conv_id: Conversation identifier forwarded to the DAO.

        Returns:
            Messages sorted by ``rounds``; an empty list if the DAO returns
            no rows.
        """
        rows = self.gpts_message.get_by_conv_id(conv_id)
        ordered = sorted(rows, key=lambda row: row.rounds)
        return [GptsMessage.from_dict(row.__dict__) for row in ordered]

    def get_last_message(self, conv_id: str) -> Optional[GptsMessage]:
        """Return the DAO's last message for ``conv_id``.

        Args:
            conv_id: Conversation identifier forwarded to the DAO.

        Returns:
            The converted message, or ``None`` when the DAO result is falsy.
        """
        row = self.gpts_message.get_last_message(conv_id)
        if not row:
            return None
        return GptsMessage.from_dict(row.__dict__)
|