add chatglm support

This commit is contained in:
csunny 2023-05-21 16:30:03 +08:00
parent 5ec1f413b6
commit a3fae0bdf2
2 changed files with 7 additions and 2 deletions

View File

@ -15,6 +15,9 @@ DB_SETTINGS = {
"port": CFG.LOCAL_DB_PORT
}
ROLE_USER = "USER"
ROLE_ASSISTANT = "Assistant"
class SeparatorStyle(Enum):
SINGLE = auto()
TWO = auto()

View File

@ -3,6 +3,8 @@
import torch
from pilot.conversation import ROLE_USER, ROLE_ASSISTANT
@torch.inference_mode()
def chatglm_generate_stream(model, tokenizer, params, device, context_len=2048, stream_interval=2):
@ -30,9 +32,9 @@ def chatglm_generate_stream(model, tokenizer, params, device, context_len=2048,
# Add history chat to hist for model.
for i in range(1, len(messages) - 2, 2):
hist.append((messages[i].split(":")[1], messages[i+1].split(":")[1]))
hist.append((messages[i].split(ROLE_USER + ":")[1], messages[i+1].split(ROLE_ASSISTANT + ":")[1]))
query = messages[-2].split(":")[1]
query = messages[-2].split(ROLE_USER + ":")[1]
print("Query Message: ", query)
output = ""
i = 0