refactor(agent): Agent modular refactoring (#1487)

This commit is contained in:
Fangyin Cheng
2024-05-07 09:45:26 +08:00
committed by GitHub
parent 2a418f91e8
commit 863b5404dd
86 changed files with 4513 additions and 967 deletions

View File

@@ -12,13 +12,13 @@ from dbgpt._private.config import Config
from dbgpt.agent.core.agent import Agent, AgentContext
from dbgpt.agent.core.agent_manage import get_agent_manager
from dbgpt.agent.core.base_agent import ConversableAgent
from dbgpt.agent.core.llm.llm import LLMConfig, LLMStrategyType
from dbgpt.agent.core.memory.agent_memory import AgentMemory
from dbgpt.agent.core.memory.gpts.gpts_memory import GptsMemory
from dbgpt.agent.core.plan import AutoPlanChatManager, DefaultAWELLayoutManager
from dbgpt.agent.core.schema import Status
from dbgpt.agent.core.user_proxy_agent import UserProxyAgent
from dbgpt.agent.memory.gpts_memory import GptsMemory
from dbgpt.agent.plan.awel.team_awel_layout import DefaultAWELLayoutManager
from dbgpt.agent.plan.team_auto_plan import AutoPlanChatManager
from dbgpt.agent.resource.resource_loader import ResourceLoader
from dbgpt.agent.util.llm.llm import LLMConfig, LLMStrategyType
from dbgpt.app.openapi.api_view_model import Result
from dbgpt.app.scene.base import ChatScene
from dbgpt.component import BaseComponent, ComponentType, SystemApp
@@ -82,6 +82,39 @@ class MultiAgents(BaseComponent, ABC):
plans_memory=MetaDbGptsPlansMemory(),
message_memory=MetaDbGptsMessageMemory(),
)
self.agent_memory_map = {}
super().__init__()
def get_or_build_agent_memory(self, conv_id: str, dbgpts_name: str) -> AgentMemory:
from dbgpt.agent.core.memory.agent_memory import (
AgentMemory,
AgentMemoryFragment,
)
from dbgpt.agent.core.memory.hybrid import HybridMemory
from dbgpt.configs.model_config import EMBEDDING_MODEL_CONFIG
from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory
from dbgpt.storage.vector_store.base import VectorStoreConfig
from dbgpt.storage.vector_store.connector import VectorStoreConnector
memory_key = f"{dbgpts_name}_{conv_id}"
if memory_key in self.agent_memory_map:
return self.agent_memory_map[memory_key]
embedding_factory = EmbeddingFactory.get_instance(CFG.SYSTEM_APP)
embedding_fn = embedding_factory.create(
model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL]
)
vstore_name = f"_chroma_agent_memory_{dbgpts_name}_{conv_id}"
vector_store_connector = VectorStoreConnector(
vector_store_type=CFG.VECTOR_STORE_TYPE,
vector_store_config=VectorStoreConfig(
name=vstore_name, embedding_fn=embedding_fn
),
)
memory = HybridMemory[AgentMemoryFragment].from_vstore(vector_store_connector)
agent_memory = AgentMemory(memory, gpts_memory=self.memory)
self.agent_memory_map[memory_key] = agent_memory
return agent_memory
def gpts_create(self, entity: GptsInstanceEntity):
self.gpts_intance.add(entity)
@@ -101,6 +134,7 @@ class MultiAgents(BaseComponent, ABC):
user_query: str,
user_code: str = None,
sys_code: str = None,
agent_memory: Optional[AgentMemory] = None,
):
gpt_app: GptsApp = self.gpts_app.app_detail(gpts_name)
@@ -124,7 +158,7 @@ class MultiAgents(BaseComponent, ABC):
task = asyncio.create_task(
multi_agents.agent_team_chat_new(
user_query, agent_conv_id, gpt_app, is_retry_chat
user_query, agent_conv_id, gpt_app, is_retry_chat, agent_memory
)
)
@@ -170,9 +204,15 @@ class MultiAgents(BaseComponent, ABC):
agent_conv_id = conv_uid + "_" + str(current_message.chat_order)
agent_task = None
try:
agent_memory = self.get_or_build_agent_memory(conv_uid, gpts_name)
agent_conv_id = conv_uid + "_" + str(current_message.chat_order)
async for task, chunk in multi_agents.agent_chat(
agent_conv_id, gpts_name, user_query, user_code, sys_code
agent_conv_id,
gpts_name,
user_query,
user_code,
sys_code,
agent_memory,
):
agent_task = task
yield chunk
@@ -200,6 +240,7 @@ class MultiAgents(BaseComponent, ABC):
conv_uid: str,
gpts_app: GptsApp,
is_retry_chat: bool = False,
agent_memory: Optional[AgentMemory] = None,
):
employees: List[Agent] = []
# Prepare resource loader
@@ -235,10 +276,10 @@ class MultiAgents(BaseComponent, ABC):
agent = (
await cls()
.bind(context)
.bind(self.memory)
.bind(llm_config)
.bind(record.resources)
.bind(resource_loader)
.bind(agent_memory)
.build()
)
employees.append(agent)
@@ -256,9 +297,9 @@ class MultiAgents(BaseComponent, ABC):
raise ValueError(f"Unknown Agent Team Mode!{team_mode}")
manager = (
await manager.bind(context)
.bind(self.memory)
.bind(llm_config)
.bind(resource_loader)
.bind(agent_memory)
.build()
)
manager.hire(employees)
@@ -267,8 +308,8 @@ class MultiAgents(BaseComponent, ABC):
user_proxy: UserProxyAgent = (
await UserProxyAgent()
.bind(context)
.bind(self.memory)
.bind(resource_loader)
.bind(agent_memory)
.build()
)
if is_retry_chat:

View File

@@ -1,7 +1,7 @@
from typing import List, Optional
from dbgpt.agent.memory.base import GptsPlan
from dbgpt.agent.memory.gpts_memory import (
from dbgpt.agent.core.memory.gpts.base import GptsPlan
from dbgpt.agent.core.memory.gpts.gpts_memory import (
GptsMessage,
GptsMessageMemory,
GptsPlansMemory,