diff --git a/pilot/model/llm_out/guanaco_llm.py b/pilot/model/llm_out/guanaco_llm.py
index 37c4c423b..1a2d1ae8b 100644
--- a/pilot/model/llm_out/guanaco_llm.py
+++ b/pilot/model/llm_out/guanaco_llm.py
@@ -1,5 +1,4 @@
 import torch
-import copy
 from threading import Thread
 from transformers import TextIteratorStreamer, StoppingCriteriaList, StoppingCriteria
 from pilot.conversation import ROLE_ASSISTANT, ROLE_USER
@@ -57,3 +56,55 @@ def guanaco_generate_output(model, tokenizer, params, device, context_len=2048):
 
     out = decoded_output.split("### Response:")[-1].strip()
     yield out
+
+
+def guanaco_generate_stream(model, tokenizer, params, device, context_len=2048):
+    """Fork from: https://github.com/KohakuBlueleaf/guanaco-lora/blob/main/generate.py"""
+    tokenizer.bos_token_id = 1
+    print(params)
+    stop = params.get("stop", "###")
+    prompt = params["prompt"]
+    max_new_tokens = params.get("max_new_tokens", 512)
+    temperature = params.get("temperature", 1.0)
+
+    query = prompt
+    print("Query Message: ", query)
+
+    input_ids = tokenizer(query, return_tensors="pt").input_ids
+    input_ids = input_ids.to(model.device)
+
+    streamer = TextIteratorStreamer(
+        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
+    )
+
+    tokenizer.bos_token_id = 1
+    stop_token_ids = [0]
+
+    class StopOnTokens(StoppingCriteria):
+        def __call__(
+            self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
+        ) -> bool:
+            for stop_id in stop_token_ids:
+                if input_ids[-1][-1] == stop_id:
+                    return True
+            return False
+
+    stop = StopOnTokens()
+
+    generate_kwargs = dict(
+        input_ids=input_ids,
+        max_new_tokens=max_new_tokens,
+        temperature=temperature,
+        do_sample=True,
+        top_k=1,
+        streamer=streamer,
+        repetition_penalty=1.7,
+        stopping_criteria=StoppingCriteriaList([stop]),
+    )
+
+    model.generate(**generate_kwargs)
+
+    out = ""
+    for new_text in streamer:
+        out += new_text
+        yield out
diff --git a/pilot/model/llm_out/proxy_llm.py b/pilot/model/llm_out/proxy_llm.py
index 92887cfc6..68512ec3c 100644
--- a/pilot/model/llm_out/proxy_llm.py
+++ b/pilot/model/llm_out/proxy_llm.py
@@ -68,15 +68,11 @@ def proxyllm_generate_stream(model, tokenizer, params, device, context_len=2048)
         "max_tokens": params.get("max_new_tokens"),
     }
 
-    print(payloads)
-    print(headers)
     res = requests.post(
         CFG.proxy_server_url, headers=headers, json=payloads, stream=True
    )
 
     text = ""
-    print("====================================res================")
-    print(res)
     for line in res.iter_lines():
         if line:
             decoded_line = line.decode("utf-8")
diff --git a/pilot/model/loader.py b/pilot/model/loader.py
index 9fe6207c1..6fd6143ff 100644
--- a/pilot/model/loader.py
+++ b/pilot/model/loader.py
@@ -118,6 +118,8 @@ class ModelLoader(metaclass=Singleton):
                 model.to(self.device)
             except ValueError:
                 pass
+            except AttributeError:
+                pass
 
         if debug:
             print(model)
diff --git a/pilot/out_parser/base.py b/pilot/out_parser/base.py
index 0538aa54c..909023f07 100644
--- a/pilot/out_parser/base.py
+++ b/pilot/out_parser/base.py
@@ -56,8 +56,11 @@ class BaseOutputParser(ABC):
                 # output = data["text"][skip_echo_len + 11:].strip()
                 output = data["text"][skip_echo_len:].strip()
             elif "guanaco" in CFG.LLM_MODEL:
-                # output = data["text"][skip_echo_len + 14:].replace("<s>", "").strip()
-                output = data["text"][skip_echo_len:].replace("<s>", "").strip()
+                # NO stream output
+                # output = data["text"][skip_echo_len + 2:].replace("<s>", "").strip()
+
+                # stream out output
+                output = data["text"][11:].replace("<s>", "").strip()
             else:
                 output = data["text"].strip()
 
diff --git a/pilot/server/chat_adapter.py b/pilot/server/chat_adapter.py
index cc2a0fc92..4743c4159 100644
--- a/pilot/server/chat_adapter.py
+++ b/pilot/server/chat_adapter.py
@@ -101,9 +101,9 @@ class GuanacoChatAdapter(BaseChatAdpter):
         return "guanaco" in model_path
 
     def get_generate_stream_func(self):
-        from pilot.model.llm_out.guanaco_llm import guanaco_generate_output
+        from pilot.model.llm_out.guanaco_llm import guanaco_generate_stream
 
-        return guanaco_generate_output
+        return guanaco_generate_stream
 
 
 class ProxyllmChatAdapter(BaseChatAdpter):
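
Note on the new streaming path: guanaco_generate_stream calls model.generate on the calling thread and only afterwards drains the TextIteratorStreamer, so the generator begins yielding once generation has finished rather than token by token. The usual Hugging Face pattern runs generate on a background Thread (already imported at the top of guanaco_llm.py) so the streamer can be consumed while tokens are still being produced. Below is a minimal sketch of that variant, reusing the model, streamer, and generate_kwargs names from the diff; the helper name is illustrative and not part of this change.

    from threading import Thread

    def stream_with_background_generate(model, streamer, generate_kwargs):
        # Run generation on a worker thread so the TextIteratorStreamer queue
        # fills while this generator is already yielding partial text.
        worker = Thread(target=model.generate, kwargs=generate_kwargs)
        worker.start()

        out = ""
        for new_text in streamer:  # yields chunks as they are decoded
            out += new_text
            yield out

        worker.join()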