mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-21 17:37:52 +00:00
feat(agent): More general ReAct Agent (#2556)
This commit is contained in:
@@ -15,7 +15,9 @@ from dbgpt.agent import (
|
||||
AutoPlanChatManager,
|
||||
ConversableAgent,
|
||||
DefaultAWELLayoutManager,
|
||||
EnhancedShortTermMemory,
|
||||
GptsMemory,
|
||||
HybridMemory,
|
||||
LLMConfig,
|
||||
ResourceType,
|
||||
UserProxyAgent,
|
||||
@@ -31,6 +33,7 @@ from dbgpt.core.awel.flow.flow_factory import FlowCategory
|
||||
from dbgpt.core.interface.message import StorageConversation
|
||||
from dbgpt.model.cluster import WorkerManagerFactory
|
||||
from dbgpt.model.cluster.client import DefaultLLMClient
|
||||
from dbgpt.util.executor_utils import ExecutorFactory
|
||||
from dbgpt.util.json_utils import serialize
|
||||
from dbgpt.util.tracer import TracerManager
|
||||
from dbgpt_app.dbgpt_server import system_app
|
||||
@@ -131,12 +134,27 @@ class MultiAgents(BaseComponent, ABC):
|
||||
return self.gpts_app.app_detail(app_code)
|
||||
|
||||
def get_or_build_agent_memory(self, conv_id: str, dbgpts_name: str) -> AgentMemory:
|
||||
memory_key = f"{dbgpts_name}_{conv_id}"
|
||||
if memory_key in self.agent_memory_map:
|
||||
return self.agent_memory_map[memory_key]
|
||||
from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory
|
||||
from dbgpt_serve.rag.storage_manager import StorageManager
|
||||
|
||||
executor = self.system_app.get_component(
|
||||
ComponentType.EXECUTOR_DEFAULT, ExecutorFactory
|
||||
).create()
|
||||
|
||||
storage_manager = StorageManager.get_instance(self.system_app)
|
||||
vector_store = storage_manager.create_vector_store(index_name="_agent_memory_")
|
||||
embeddings = EmbeddingFactory.get_instance(self.system_app).create()
|
||||
short_term_memory = EnhancedShortTermMemory(
|
||||
embeddings, executor=executor, buffer_size=10
|
||||
)
|
||||
memory = HybridMemory.from_vstore(
|
||||
vector_store,
|
||||
embeddings=embeddings,
|
||||
executor=executor,
|
||||
short_term_memory=short_term_memory,
|
||||
)
|
||||
agent_memory = AgentMemory(memory, gpts_memory=self.memory)
|
||||
|
||||
agent_memory = AgentMemory(gpts_memory=self.memory)
|
||||
self.agent_memory_map[memory_key] = agent_memory
|
||||
return agent_memory
|
||||
|
||||
async def agent_chat_v2(
|
||||
|
@@ -84,7 +84,6 @@ class MetaDbGptsMessageMemory(GptsMessageMemory):
|
||||
def get_by_agent(self, conv_id: str, agent: str) -> Optional[List[GptsMessage]]:
|
||||
db_results = self.gpts_message.get_by_agent(conv_id, agent)
|
||||
results = []
|
||||
db_results = sorted(db_results, key=lambda x: x.rounds)
|
||||
for item in db_results:
|
||||
results.append(GptsMessage.from_dict(item.__dict__))
|
||||
return results
|
||||
@@ -120,3 +119,6 @@ class MetaDbGptsMessageMemory(GptsMessageMemory):
|
||||
return GptsMessage.from_dict(db_result.__dict__)
|
||||
else:
|
||||
return None
|
||||
|
||||
def delete_by_conv_id(self, conv_id: str) -> None:
|
||||
self.gpts_message.delete_chat_message(conv_id)
|
||||
|
@@ -1,3 +1,4 @@
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
|
||||
@@ -14,6 +15,7 @@ from sqlalchemy import (
|
||||
or_,
|
||||
)
|
||||
|
||||
from dbgpt.agent.util.conv_utils import parse_conv_id
|
||||
from dbgpt.storage.metadata import BaseDao, Model
|
||||
|
||||
|
||||
@@ -111,19 +113,31 @@ class GptsMessagesDao(BaseDao):
|
||||
self, conv_id: str, agent: str
|
||||
) -> Optional[List[GptsMessagesEntity]]:
|
||||
session = self.get_raw_session()
|
||||
real_conv_id, _ = parse_conv_id(conv_id)
|
||||
gpts_messages = session.query(GptsMessagesEntity)
|
||||
if agent:
|
||||
gpts_messages = gpts_messages.filter(
|
||||
GptsMessagesEntity.conv_id == conv_id
|
||||
GptsMessagesEntity.conv_id.like(f"%{real_conv_id}%")
|
||||
).filter(
|
||||
or_(
|
||||
GptsMessagesEntity.sender == agent,
|
||||
GptsMessagesEntity.receiver == agent,
|
||||
)
|
||||
)
|
||||
result = gpts_messages.order_by(GptsMessagesEntity.rounds).all()
|
||||
# Extract results first to apply custom sorting
|
||||
results = gpts_messages.all()
|
||||
|
||||
# Custom sorting based on conv_id suffix and rounds
|
||||
def get_suffix_number(entity):
|
||||
suffix_match = re.search(r"_(\d+)$", entity.conv_id)
|
||||
if suffix_match:
|
||||
return int(suffix_match.group(1))
|
||||
return 0 # Default for entries without a numeric suffix
|
||||
|
||||
# Sort first by numeric suffix, then by rounds
|
||||
sorted_results = sorted(results, key=lambda x: (get_suffix_number(x), x.rounds))
|
||||
session.close()
|
||||
return result
|
||||
return sorted_results
|
||||
|
||||
def get_by_conv_id(self, conv_id: str) -> Optional[List[GptsMessagesEntity]]:
|
||||
session = self.get_raw_session()
|
||||
|
Reference in New Issue
Block a user