mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-08-20 09:14:44 +00:00
fix: chatglm not working in doc qa, meta qa and plugin
This commit is contained in:
parent
e4681c9a9d
commit
189ac995ec
@ -1,5 +1,8 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
# -*- coding:utf-8 -*-
|
# -*- coding:utf-8 -*-
|
||||||
|
|
||||||
|
from typing import List
|
||||||
|
import re
|
||||||
import copy
|
import copy
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@ -33,34 +36,36 @@ def chatglm_generate_stream(
|
|||||||
messages = prompt.split(stop)
|
messages = prompt.split(stop)
|
||||||
#
|
#
|
||||||
# # Add history conversation
|
# # Add history conversation
|
||||||
hist = []
|
hist = [HistoryEntry()]
|
||||||
once_conversation = []
|
system_messages = []
|
||||||
for message in messages[:-2]:
|
for message in messages[:-2]:
|
||||||
if len(message) <= 0:
|
if len(message) <= 0:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if "human:" in message:
|
if "human:" in message:
|
||||||
once_conversation.append(message.split("human:")[1])
|
hist[-1].add_question(message.split("human:")[1])
|
||||||
# elif "system:" in message:
|
elif "system:" in message:
|
||||||
# once_conversation.append(f"""###system:{message.split("system:")[1]} """)
|
msg = message.split("system:")[1]
|
||||||
|
hist[-1].add_question(msg)
|
||||||
|
system_messages.append(msg)
|
||||||
elif "ai:" in message:
|
elif "ai:" in message:
|
||||||
once_conversation.append(message.split("ai:")[1])
|
hist[-1].add_answer(message.split("ai:")[1])
|
||||||
last_conversation = copy.deepcopy(once_conversation)
|
hist.append(HistoryEntry())
|
||||||
hist.append(last_conversation)
|
else:
|
||||||
once_conversation = []
|
# TODO
|
||||||
# else:
|
# hist[-1].add_question(message.split("system:")[1])
|
||||||
# once_conversation.append(f"""###system:{message} """)
|
# once_conversation.append(f"""###system:{message} """)
|
||||||
|
pass
|
||||||
|
|
||||||
try:
|
try:
|
||||||
query = messages[-2].split("human:")[1]
|
query = messages[-2].split("human:")[1]
|
||||||
except IndexError:
|
except IndexError:
|
||||||
# fix doc qa: https://github.com/csunny/DB-GPT/issues/274
|
query = messages[-3].split("human:")[1]
|
||||||
doc_qa_message = messages[-2]
|
hist = build_history(hist)
|
||||||
if "system:" in doc_qa_message:
|
if not hist:
|
||||||
query = doc_qa_message.split("system:")[1]
|
# No history conversation, but has system messages, merge to user`s query
|
||||||
else:
|
query = prompt_adaptation(system_messages, query)
|
||||||
query = messages[-3].split("human:")[1]
|
|
||||||
print("Query Message: ", query)
|
print("Query Message: ", query)
|
||||||
|
print("hist: ", hist)
|
||||||
# output = ""
|
# output = ""
|
||||||
# i = 0
|
# i = 0
|
||||||
|
|
||||||
@ -75,3 +80,43 @@ def chatglm_generate_stream(
|
|||||||
yield output
|
yield output
|
||||||
|
|
||||||
yield output
|
yield output
|
||||||
|
|
||||||
|
|
||||||
|
class HistoryEntry:
|
||||||
|
def __init__(self, question: str = "", answer: str = ""):
|
||||||
|
self.question = question
|
||||||
|
self.answer = answer
|
||||||
|
|
||||||
|
def add_question(self, question: str):
|
||||||
|
self.question += question
|
||||||
|
|
||||||
|
def add_answer(self, answer: str):
|
||||||
|
self.answer += answer
|
||||||
|
|
||||||
|
def to_list(self):
|
||||||
|
if self.question == "" or self.answer == "":
|
||||||
|
return None
|
||||||
|
return [self.question, self.answer]
|
||||||
|
|
||||||
|
|
||||||
|
def build_history(hist: List[HistoryEntry]) -> List[List[str]]:
|
||||||
|
return list(filter(lambda hl: hl is not None, map(lambda h: h.to_list(), hist)))
|
||||||
|
|
||||||
|
|
||||||
|
def prompt_adaptation(system_messages: List[str], human_message: str) -> str:
|
||||||
|
if not system_messages:
|
||||||
|
return human_message
|
||||||
|
system_messages_str = " ".join(system_messages)
|
||||||
|
adaptation_rules = [
|
||||||
|
r"Question:\s*{}\s*", # chat_db scene
|
||||||
|
r"Goals:\s*{}\s*", # chat_execution
|
||||||
|
r"问题:\s*{}\s*", # chat_knowledge zh
|
||||||
|
r"question:\s*{}\s*", # chat_knowledge en
|
||||||
|
]
|
||||||
|
# system message has include human question
|
||||||
|
for rule in adaptation_rules:
|
||||||
|
pattern = re.compile(rule.format(re.escape(human_message)))
|
||||||
|
if re.search(pattern, system_messages_str):
|
||||||
|
return system_messages_str
|
||||||
|
# https://huggingface.co/THUDM/chatglm2-6b/blob/e186c891cf64310ac66ef10a87e6635fa6c2a579/modeling_chatglm.py#L926
|
||||||
|
return f"{system_messages_str}\n\n问:{human_message}\n\n答:"
|
||||||
|
Loading…
Reference in New Issue
Block a user