diff --git a/pilot/scene/base_chat.py b/pilot/scene/base_chat.py index 879213dcf..28729a689 100644 --- a/pilot/scene/base_chat.py +++ b/pilot/scene/base_chat.py @@ -174,6 +174,9 @@ class BaseChat(ABC): def stream_plugin_call(self, text): return text + def stream_call_reinforce_fn(self, text): + return text + async def check_iterator_end(iterator): try: await asyncio.anext(iterator) @@ -215,6 +218,7 @@ class BaseChat(ABC): view_msg = view_msg.replace("\n", "\\n") yield view_msg self.current_message.add_ai_message(msg) + view_msg = self.stream_call_reinforce_fn(view_msg) self.current_message.add_view_message(view_msg) span.end() except Exception as e: diff --git a/pilot/scene/chat_knowledge/v1/chat.py b/pilot/scene/chat_knowledge/v1/chat.py index e4e373cfe..62301ad12 100644 --- a/pilot/scene/chat_knowledge/v1/chat.py +++ b/pilot/scene/chat_knowledge/v1/chat.py @@ -89,7 +89,7 @@ class ChatKnowledge(BaseChat): last_output = last_output + reference yield last_output - def stream_plugin_call(self, text): + def stream_call_reinforce_fn(self, text): """return reference""" return text + f"\n\n{self.parse_source_view(self.sources)}"