mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-08-01 16:18:27 +00:00
add chatglm support
This commit is contained in:
parent
5ec1f413b6
commit
a3fae0bdf2
@ -15,6 +15,9 @@ DB_SETTINGS = {
|
||||
"port": CFG.LOCAL_DB_PORT
|
||||
}
|
||||
|
||||
ROLE_USER = "USER"
|
||||
ROLE_ASSISTANT = "Assistant"
|
||||
|
||||
class SeparatorStyle(Enum):
|
||||
SINGLE = auto()
|
||||
TWO = auto()
|
||||
|
@ -3,6 +3,8 @@
|
||||
|
||||
import torch
|
||||
|
||||
from pilot.conversation import ROLE_USER, ROLE_ASSISTANT
|
||||
|
||||
@torch.inference_mode()
|
||||
def chatglm_generate_stream(model, tokenizer, params, device, context_len=2048, stream_interval=2):
|
||||
|
||||
@ -30,9 +32,9 @@ def chatglm_generate_stream(model, tokenizer, params, device, context_len=2048,
|
||||
|
||||
# Add history chat to hist for model.
|
||||
for i in range(1, len(messages) - 2, 2):
|
||||
hist.append((messages[i].split(":")[1], messages[i+1].split(":")[1]))
|
||||
hist.append((messages[i].split(ROLE_USER + ":")[1], messages[i+1].split(ROLE_ASSISTANT + ":")[1]))
|
||||
|
||||
query = messages[-2].split(":")[1]
|
||||
query = messages[-2].split(ROLE_USER + ":")[1]
|
||||
print("Query Message: ", query)
|
||||
output = ""
|
||||
i = 0
|
||||
|
Loading…
Reference in New Issue
Block a user