diff --git a/pilot/model/adapter.py b/pilot/model/adapter.py
index 83d8a3717..84cd699ac 100644
--- a/pilot/model/adapter.py
+++ b/pilot/model/adapter.py
@@ -69,10 +69,6 @@ class ChatGLMAdapater(BaseLLMAdaper):
         ).half().cuda()
         return model, tokenizer
 
-class ZiYaLLaMaAdapter(BaseLLMAdaper):
-    # TODO
-    pass
-
 class CodeGenAdapter(BaseLLMAdaper):
     pass
 
diff --git a/pilot/model/chat.py b/pilot/model/chat.py
deleted file mode 100644
index 97206f2d5..000000000
--- a/pilot/model/chat.py
+++ /dev/null
@@ -1,3 +0,0 @@
-#!/usr/bin/env python3
-# -*- coding:utf-8 -*-
-
diff --git a/pilot/model/chatglm_llm.py b/pilot/model/chatglm_llm.py
new file mode 100644
index 000000000..ef54e92d7
--- /dev/null
+++ b/pilot/model/chatglm_llm.py
@@ -0,0 +1,41 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+
+import torch
+
+@torch.inference_mode()
+def chatglm_generate_stream(model, tokenizer, params, device, context_len=2048, stream_interval=2):
+    """Generate text using the chatglm model's stream_chat API."""
+
+    messages = params["prompt"]
+    max_new_tokens = int(params.get("max_new_tokens", 256))
+    temperature = float(params.get("temperature", 1.0))
+    top_p = float(params.get("top_p", 1.0))
+    echo = params.get("echo", True)
+
+    generate_kwargs = {
+        "max_new_tokens": max_new_tokens,
+        "do_sample": True if temperature > 1e-5 else False,
+        "top_p": top_p,
+        "logits_processor": None,
+    }
+
+    if temperature > 1e-5:
+        generate_kwargs["temperature"] = temperature
+
+    # Rebuild (user, assistant) history pairs from all but the last two messages.
+    hist = []
+    for i in range(0, len(messages) - 2, 2):
+        hist.append((messages[i][1], messages[i + 1][1]))
+
+    query = messages[-2][1]
+    output = ""
+    for response, _ in model.stream_chat(tokenizer, query, hist, **generate_kwargs):
+        if echo:
+            output = query + " " + response
+        else:
+            output = response
+
+        yield output
+
+    yield output
\ No newline at end of file
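
For reviewers, a minimal usage sketch of the new chatglm_generate_stream helper follows. The checkpoint name, the transformers loading code, and the shape of params["prompt"] (a list of (role, text) pairs, as implied by the messages[i][1] indexing above) are assumptions for illustration, not something this diff defines.

# Usage sketch under the assumptions stated above: load ChatGLM via transformers, then stream.
from transformers import AutoModel, AutoTokenizer

from pilot.model.chatglm_llm import chatglm_generate_stream

model_path = "THUDM/chatglm-6b"  # assumed checkpoint, not fixed by this diff
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModel.from_pretrained(model_path, trust_remote_code=True).half().cuda()

params = {
    # Assumed prompt layout: (role, text) pairs; the query is read from messages[-2][1].
    "prompt": [("user", "What is DB-GPT?"), ("assistant", "")],
    "temperature": 0.7,
    "max_new_tokens": 128,
    "echo": False,
}

for partial_output in chatglm_generate_stream(model, tokenizer, params, device="cuda"):
    print(partial_output)  # each yield is the full response text generated so far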