diff --git a/pilot/model/llm_out/guanaco_llm.py b/pilot/model/llm_out/guanaco_llm.py
index 37c4c423b..5b24e69ec 100644
--- a/pilot/model/llm_out/guanaco_llm.py
+++ b/pilot/model/llm_out/guanaco_llm.py
@@ -1,5 +1,4 @@
 import torch
-import copy
 from threading import Thread
 from transformers import TextIteratorStreamer, StoppingCriteriaList, StoppingCriteria
 from pilot.conversation import ROLE_ASSISTANT, ROLE_USER
@@ -57,3 +56,53 @@ def guanaco_generate_output(model, tokenizer, params, device, context_len=2048):
 
         out = decoded_output.split("### Response:")[-1].strip()
         yield out
+
+
+def guanaco_generate_stream(model, tokenizer, params, device, context_len=2048):
+    """Fork from: https://github.com/KohakuBlueleaf/guanaco-lora/blob/main/generate.py"""
+    tokenizer.bos_token_id = 1
+    print(params)
+    stop = params.get("stop", "###")
+    prompt = params["prompt"]
+    query = prompt
+    print("Query Message: ", query)
+
+    input_ids = tokenizer(query, return_tensors="pt").input_ids
+    input_ids = input_ids.to(model.device)
+
+    streamer = TextIteratorStreamer(
+        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
+    )
+
+    tokenizer.bos_token_id = 1
+    stop_token_ids = [0]
+
+    class StopOnTokens(StoppingCriteria):
+        def __call__(
+            self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
+        ) -> bool:
+            for stop_id in stop_token_ids:
+                if input_ids[0][-1] == stop_id:
+                    return True
+            return False
+
+    stop = StopOnTokens()
+
+    generate_kwargs = dict(
+        input_ids=input_ids,
+        max_new_tokens=512,
+        temperature=1.0,
+        do_sample=True,
+        top_k=1,
+        streamer=streamer,
+        repetition_penalty=1.7,
+        stopping_criteria=StoppingCriteriaList([stop]),
+    )
+
+
+    generator = model.generate(**generate_kwargs)
+    out = ""
+    for new_text in streamer:
+        out += new_text
+        yield new_text
+    return out
\ No newline at end of file
diff --git a/pilot/server/chat_adapter.py b/pilot/server/chat_adapter.py
index 86901dea3..63d922672 100644
--- a/pilot/server/chat_adapter.py
+++ b/pilot/server/chat_adapter.py
@@ -4,7 +4,7 @@
 from functools import cache
 from typing import List
 
-from pilot.model.llm_out.vicuna_base_llm import generate_stream
+from pilot.model.inference import generate_stream
 
 
 class BaseChatAdpter:
@@ -55,7 +55,7 @@ class ChatGLMChatAdapter(BaseChatAdpter):
         return "chatglm" in model_path
 
     def get_generate_stream_func(self):
-        from pilot.model.llm_out.chatglm_llm import chatglm_generate_stream
+        from pilot.model.chatglm_llm import chatglm_generate_stream
 
         return chatglm_generate_stream
 
@@ -85,15 +85,15 @@ class CodeGenChatAdapter(BaseChatAdpter):
 
 
 class GuanacoChatAdapter(BaseChatAdpter):
-    """Model chat adapter for Guanaco"""
-
+    """Model chat adapter for Guanaco """
+
     def match(self, model_path: str):
         return "guanaco" in model_path
 
     def get_generate_stream_func(self):
-        from pilot.model.llm_out.guanaco_llm import guanaco_generate_output
-
-        return guanaco_generate_output
+        from pilot.model.guanaco_llm import guanaco_generate_stream
+
+        return guanaco_generate_stream
 
 
 class ProxyllmChatAdapter(BaseChatAdpter):
@@ -101,7 +101,7 @@ class ProxyllmChatAdapter(BaseChatAdpter):
         return "proxyllm" in model_path
 
     def get_generate_stream_func(self):
-        from pilot.model.llm_out.proxy_llm import proxyllm_generate_stream
+        from pilot.model.proxy_llm import proxyllm_generate_stream
 
         return proxyllm_generate_stream
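
For context, the new `guanaco_generate_stream` is built around Hugging Face's `TextIteratorStreamer`. Below is a minimal, self-contained sketch of that streaming pattern, not the patch itself: the model path, prompt template, and the `stream_generate` helper are illustrative placeholders. One detail worth noting is that `model.generate` blocks until generation finishes, so the usual pattern (shown here) launches it on a background `Thread` and drains the streamer on the calling thread, whereas the diff above invokes `generate` synchronously before iterating the streamer.

```python
from threading import Thread

import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
    TextIteratorStreamer,
)

# Hypothetical checkpoint path -- substitute a real Guanaco model directory.
MODEL_PATH = "path/to/guanaco-merged"

tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH, torch_dtype=torch.float16, device_map="auto"
)


class StopOnTokens(StoppingCriteria):
    """Stop generation as soon as the last generated token is a stop token."""

    def __init__(self, stop_token_ids):
        self.stop_token_ids = stop_token_ids

    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
    ) -> bool:
        return input_ids[0][-1].item() in self.stop_token_ids


def stream_generate(prompt: str, max_new_tokens: int = 512):
    """Yield decoded text chunks as the model produces them."""
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
    streamer = TextIteratorStreamer(
        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )
    generate_kwargs = dict(
        input_ids=input_ids,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=1.0,
        top_k=1,
        repetition_penalty=1.7,
        streamer=streamer,
        stopping_criteria=StoppingCriteriaList(
            [StopOnTokens([tokenizer.eos_token_id])]
        ),
    )
    # generate() blocks until completion, so run it in a worker thread and
    # consume the streamer concurrently on this thread.
    Thread(target=model.generate, kwargs=generate_kwargs).start()

    for new_text in streamer:
        yield new_text


if __name__ == "__main__":
    # Prompt template is a placeholder; Guanaco-style prompts vary by fine-tune.
    for chunk in stream_generate("### Human: What is DB-GPT?\n### Assistant:"):
        print(chunk, end="", flush=True)
    print()
```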