feat(agent):Agent supports conversation context (#2230)

Co-authored-by: Fangyin Cheng <staneyffer@gmail.com>
This commit is contained in:
明天
2024-12-20 16:50:08 +08:00
committed by GitHub
parent 16c5233a6d
commit 2b4597e6a7
13 changed files with 164 additions and 18 deletions

View File

@@ -206,11 +206,44 @@ class MultiAgents(BaseComponent, ABC):
if not gpt_app:
raise ValueError(f"Not found app {gpts_name}!")
historical_dialogues: List[GptsMessage] = []
if not is_retry_chat:
# 新建gpts对话记录
# Create a new gpts conversation record
gpt_app: GptsApp = self.gpts_app.app_detail(gpts_name)
if not gpt_app:
raise ValueError(f"Not found app {gpts_name}!")
## When creating a new gpts conversation record, determine whether to include the history of previous topics according to the application definition.
## TODO BEGIN
# Temporarily use system configuration management, and subsequently use application configuration management
if CFG.MESSAGES_KEEP_START_ROUNDS and CFG.MESSAGES_KEEP_START_ROUNDS > 0:
gpt_app.keep_start_rounds = CFG.MESSAGES_KEEP_START_ROUNDS
if CFG.MESSAGES_KEEP_END_ROUNDS and CFG.MESSAGES_KEEP_END_ROUNDS > 0:
gpt_app.keep_end_rounds = CFG.MESSAGES_KEEP_END_ROUNDS
## TODO END
if gpt_app.keep_start_rounds > 0 or gpt_app.keep_end_rounds > 0:
if gpts_conversations and len(gpts_conversations) > 0:
rely_conversations = []
if gpt_app.keep_start_rounds + gpt_app.keep_end_rounds < len(
gpts_conversations
):
if gpt_app.keep_start_rounds > 0:
front = gpts_conversations[gpt_app.keep_start_rounds :]
rely_conversations.extend(front)
if gpt_app.keep_end_rounds > 0:
back = gpts_conversations[-gpt_app.keep_end_rounds :]
rely_conversations.extend(back)
else:
rely_conversations = gpts_conversations
for gpts_conversation in rely_conversations:
temps: List[GptsMessage] = await self.memory.get_messages(
gpts_conversation.conv_id
)
if temps and len(temps) > 1:
historical_dialogues.append(temps[0])
historical_dialogues.append(temps[-1])
self.gpts_conversations.add(
GptsConversationsEntity(
conv_id=agent_conv_id,
@@ -277,6 +310,8 @@ class MultiAgents(BaseComponent, ABC):
is_retry_chat,
last_speaker_name=last_speaker_name,
init_message_rounds=message_round,
enable_verbose=enable_verbose,
historical_dialogues=historical_dialogues,
**ext_info,
)
)
@@ -418,6 +453,8 @@ class MultiAgents(BaseComponent, ABC):
link_sender: ConversableAgent = None,
app_link_start: bool = False,
enable_verbose: bool = True,
historical_dialogues: Optional[List[GptsMessage]] = None,
rely_messages: Optional[List[GptsMessage]] = None,
**ext_info,
):
gpts_status = Status.COMPLETE.value
@@ -529,6 +566,10 @@ class MultiAgents(BaseComponent, ABC):
is_retry_chat=is_retry_chat,
last_speaker_name=last_speaker_name,
message_rounds=init_message_rounds,
historical_dialogues=user_proxy.convert_to_agent_message(
historical_dialogues
),
rely_messages=rely_messages,
**ext_info,
)

View File

@@ -93,6 +93,8 @@ class StartAppAssistantAgent(ConversableAgent):
is_recovery: Optional[bool] = False,
is_retry_chat: bool = False,
last_speaker_name: str = None,
historical_dialogues: Optional[List[AgentMessage]] = None,
rely_messages: Optional[List[AgentMessage]] = None,
) -> None:
await self._a_process_received_message(message, sender)
if request_reply is False or request_reply is None:

View File

@@ -135,6 +135,10 @@ class GptsApp(BaseModel):
recommend_questions: Optional[List[RecommendQuestion]] = []
admins: List[str] = Field(default_factory=list)
# By default, keep the last two rounds of conversation records as the context
keep_start_rounds: int = 0
keep_end_rounds: int = 0
def to_dict(self):
return {k: self._serialize(v) for k, v in self.__dict__.items()}
@@ -170,6 +174,8 @@ class GptsApp(BaseModel):
owner_avatar_url=d.get("owner_avatar_url", None),
recommend_questions=d.get("recommend_questions", []),
admins=d.get("admins", []),
keep_start_rounds=d.get("keep_start_rounds", 0),
keep_end_rounds=d.get("keep_end_rounds", 2),
)
@model_validator(mode="before")
@@ -547,6 +553,8 @@ class GptsAppDao(BaseDao):
"published": app_info.published,
"details": [],
"admins": [],
# "keep_start_rounds": app_info.keep_start_rounds,
# "keep_end_rounds": app_info.keep_end_rounds,
}
)
for app_info in app_entities
@@ -918,6 +926,8 @@ class GptsAppDao(BaseDao):
app_entity.icon = gpts_app.icon
app_entity.team_context = _parse_team_context(gpts_app.team_context)
app_entity.param_need = json.dumps(gpts_app.param_need)
app_entity.keep_start_rounds = gpts_app.keep_start_rounds
app_entity.keep_end_rounds = gpts_app.keep_end_rounds
session.merge(app_entity)
old_details = session.query(GptsAppDetailEntity).filter(