From 343d64652cea28f664139dbe7134f3f6e28edcfe Mon Sep 17 00:00:00 2001 From: yhjun1026 <460342015@qq.com> Date: Thu, 16 Nov 2023 13:38:52 +0800 Subject: [PATCH] feat(ChatDB): ChatDB Use fintune model 1.Compatible with community pure sql output model --- pilot/model/model_adapter.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/pilot/model/model_adapter.py b/pilot/model/model_adapter.py index 00ec7e736..dba9c5673 100644 --- a/pilot/model/model_adapter.py +++ b/pilot/model/model_adapter.py @@ -132,6 +132,9 @@ class LLMModelAdaper: conv = conv.copy() system_messages = [] + user_messages = [] + ai_messages = [] + for message in messages: role, content = None, None if isinstance(message, ModelMessage): @@ -147,20 +150,25 @@ class LLMModelAdaper: # Support for multiple system messages system_messages.append(content) elif role == ModelMessageRoleType.HUMAN: - conv.append_message(conv.roles[0], content) + # conv.append_message(conv.roles[0], content) + user_messages.append(content) elif role == ModelMessageRoleType.AI: - conv.append_message(conv.roles[1], content) + # conv.append_message(conv.roles[1], content) + ai_messages.append(content) else: raise ValueError(f"Unknown role: {role}") can_use_system = "" if system_messages: # TODO vicuna 兼容 测试完放弃 + user_messages[-1] = system_messages[-1] if len(system_messages) > 1: can_use_system = system_messages[0] - conv[-1][0][-1] =system_messages[-1] - elif len(system_messages) == 1: - conv[-1][0][-1] = system_messages[-1] + + for i in range(len(user_messages)): + conv.append_message(conv.roles[0], user_messages[i]) + if i < len(ai_messages): + conv.append_message(conv.roles[1], ai_messages[i]) if isinstance(conv, Conversation): conv.set_system_message(can_use_system)