diff --git a/pilot/configs/model_config.py b/pilot/configs/model_config.py
index 7bbbac361..6a9eaca7f 100644
--- a/pilot/configs/model_config.py
+++ b/pilot/configs/model_config.py
@@ -34,7 +34,7 @@ LLM_MODEL_CONFIG = {
     "chatglm-6b-int4": os.path.join(MODEL_PATH, "chatglm-6b-int4"),
     "chatglm-6b": os.path.join(MODEL_PATH, "chatglm-6b"),
     "text2vec-base": os.path.join(MODEL_PATH, "text2vec-base-chinese"),
-    "guanaco-33b-merged": os.path.join(MODEL_PATH, "guanaco-33b-merged"
+    "guanaco-33b-merged": os.path.join(MODEL_PATH, "guanaco-33b-merged"),
     "sentence-transforms": os.path.join(MODEL_PATH, "all-MiniLM-L6-v2"),
     "guanaco-33b-merged": os.path.join(MODEL_PATH, "guanaco-33b-merged"),
     "proxyllm": "proxyllm",
diff --git a/pilot/model/guanaco_stream_llm.py b/pilot/model/guanaco_stream_llm.py
deleted file mode 100644
index 8f72699d1..000000000
--- a/pilot/model/guanaco_stream_llm.py
+++ /dev/null
@@ -1,55 +0,0 @@
-import torch
-from threading import Thread
-from transformers import TextIteratorStreamer, StoppingCriteriaList, StoppingCriteria
-
-
-
-def guanaco_stream_generate_output(model, tokenizer, params, device, context_len=2048):
-    """Fork from: https://github.com/KohakuBlueleaf/guanaco-lora/blob/main/generate.py"""
-    tokenizer.bos_token_id = 1
-    print(params)
-    stop = params.get("stop", "###")
-    prompt = params["prompt"]
-    query = prompt
-    print("Query Message: ", query)
-
-    input_ids = tokenizer(query, return_tensors="pt").input_ids
-    input_ids = input_ids.to(model.device)
-
-    streamer = TextIteratorStreamer(
-        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
-    )
-
-    tokenizer.bos_token_id = 1
-    stop_token_ids = [0]
-
-    class StopOnTokens(StoppingCriteria):
-        def __call__(
-            self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
-        ) -> bool:
-            for stop_id in stop_token_ids:
-                if input_ids[0][-1] == stop_id:
-                    return True
-            return False
-
-    stop = StopOnTokens()
-
-    generate_kwargs = dict(
-        input_ids=input_ids,
-        max_new_tokens=512,
-        temperature=1.0,
-        do_sample=True,
-        top_k=1,
-        streamer=streamer,
-        repetition_penalty=1.7,
-        stopping_criteria=StoppingCriteriaList([stop]),
-    )
-
-    t = Thread(target=model.generate, kwargs=generate_kwargs)
-    t.start()
-
-    out = ""
-    for new_text in streamer:
-        out += new_text
-        yield new_text
-    return out
\ No newline at end of file
diff --git a/pilot/server/chat_adapter.py b/pilot/server/chat_adapter.py
index 4743c4159..6a1c5ce7b 100644
--- a/pilot/server/chat_adapter.py
+++ b/pilot/server/chat_adapter.py
@@ -59,17 +59,6 @@ class ChatGLMChatAdapter(BaseChatAdpter):
         return chatglm_generate_stream
 
 
-class GuanacoChatAdapter(BaseChatAdpter):
-    """Model chat adapter for Guanaco"""
-
-    def match(self, model_path: str):
-        return "guanaco" in model_path
-
-    def get_generate_stream_func(self):
-        from pilot.model.llm_out.guanaco_stream_llm import guanaco_stream_generate_output
-
-        return guanaco_generate_output
-
 class CodeT5ChatAdapter(BaseChatAdpter):
     """Model chat adapter for CodeT5"""
 