From 6ea9c17aae387092cc634c0a26b66bd0d0373236 Mon Sep 17 00:00:00 2001 From: aries_ckt <916701291@qq.com> Date: Mon, 13 Nov 2023 23:43:52 +0800 Subject: [PATCH] fix:reference bug --- pilot/scene/base_chat.py | 4 ++++ pilot/scene/chat_knowledge/v1/chat.py | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/pilot/scene/base_chat.py b/pilot/scene/base_chat.py index 879213dcf..28729a689 100644 --- a/pilot/scene/base_chat.py +++ b/pilot/scene/base_chat.py @@ -174,6 +174,9 @@ class BaseChat(ABC): def stream_plugin_call(self, text): return text + def stream_call_reinforce_fn(self, text): + return text + async def check_iterator_end(iterator): try: await asyncio.anext(iterator) @@ -215,6 +218,7 @@ class BaseChat(ABC): view_msg = view_msg.replace("\n", "\\n") yield view_msg self.current_message.add_ai_message(msg) + view_msg = self.stream_call_reinforce_fn(view_msg) self.current_message.add_view_message(view_msg) span.end() except Exception as e: diff --git a/pilot/scene/chat_knowledge/v1/chat.py b/pilot/scene/chat_knowledge/v1/chat.py index e4e373cfe..62301ad12 100644 --- a/pilot/scene/chat_knowledge/v1/chat.py +++ b/pilot/scene/chat_knowledge/v1/chat.py @@ -89,7 +89,7 @@ class ChatKnowledge(BaseChat): last_output = last_output + reference yield last_output - def stream_plugin_call(self, text): + def stream_call_reinforce_fn(self, text): """return reference""" return text + f"\n\n{self.parse_source_view(self.sources)}"