From 4170074f324f38f1598c04336420b6711ed821c2 Mon Sep 17 00:00:00 2001
From: csunny
Date: Sun, 4 Jun 2023 19:13:23 +0800
Subject: [PATCH 1/5] stream output for guanaco

---
 pilot/model/llm_out/guanaco_llm.py | 51 +++++++++++++++++++++++++++++-
 pilot/server/chat_adapter.py       | 16 +++++-----
 2 files changed, 58 insertions(+), 9 deletions(-)

diff --git a/pilot/model/llm_out/guanaco_llm.py b/pilot/model/llm_out/guanaco_llm.py
index 37c4c423b..5b24e69ec 100644
--- a/pilot/model/llm_out/guanaco_llm.py
+++ b/pilot/model/llm_out/guanaco_llm.py
@@ -1,5 +1,4 @@
 import torch
-import copy
 from threading import Thread
 from transformers import TextIteratorStreamer, StoppingCriteriaList, StoppingCriteria
 from pilot.conversation import ROLE_ASSISTANT, ROLE_USER
@@ -57,3 +56,53 @@ def guanaco_generate_output(model, tokenizer, params, device, context_len=2048):
     out = decoded_output.split("### Response:")[-1].strip()
 
     yield out
+
+
+def guanaco_generate_stream(model, tokenizer, params, device, context_len=2048):
+    """Fork from: https://github.com/KohakuBlueleaf/guanaco-lora/blob/main/generate.py"""
+    tokenizer.bos_token_id = 1
+    print(params)
+    stop = params.get("stop", "###")
+    prompt = params["prompt"]
+    query = prompt
+    print("Query Message: ", query)
+
+    input_ids = tokenizer(query, return_tensors="pt").input_ids
+    input_ids = input_ids.to(model.device)
+
+    streamer = TextIteratorStreamer(
+        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
+    )
+
+    tokenizer.bos_token_id = 1
+    stop_token_ids = [0]
+
+    class StopOnTokens(StoppingCriteria):
+        def __call__(
+            self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
+        ) -> bool:
+            for stop_id in stop_token_ids:
+                if input_ids[0][-1] == stop_id:
+                    return True
+            return False
+
+    stop = StopOnTokens()
+
+    generate_kwargs = dict(
+        input_ids=input_ids,
+        max_new_tokens=512,
+        temperature=1.0,
+        do_sample=True,
+        top_k=1,
+        streamer=streamer,
+        repetition_penalty=1.7,
+        stopping_criteria=StoppingCriteriaList([stop]),
+    )
+
+
+    generator = model.generate(**generate_kwargs)
+    out = ""
+    for new_text in streamer:
+        out += new_text
+        yield new_text
+    return out
\ No newline at end of file
diff --git a/pilot/server/chat_adapter.py b/pilot/server/chat_adapter.py
index 86901dea3..63d922672 100644
--- a/pilot/server/chat_adapter.py
+++ b/pilot/server/chat_adapter.py
@@ -4,7 +4,7 @@
 from functools import cache
 from typing import List
 
-from pilot.model.llm_out.vicuna_base_llm import generate_stream
+from pilot.model.inference import generate_stream
 
 
 class BaseChatAdpter:
@@ -55,7 +55,7 @@ class ChatGLMChatAdapter(BaseChatAdpter):
         return "chatglm" in model_path
 
     def get_generate_stream_func(self):
-        from pilot.model.llm_out.chatglm_llm import chatglm_generate_stream
+        from pilot.model.chatglm_llm import chatglm_generate_stream
 
         return chatglm_generate_stream
 
@@ -85,15 +85,15 @@ class CodeGenChatAdapter(BaseChatAdpter):
 
 
 class GuanacoChatAdapter(BaseChatAdpter):
-    """Model chat adapter for Guanaco"""
-
+    """Model chat adapter for Guanaco """
+    
     def match(self, model_path: str):
         return "guanaco" in model_path
 
     def get_generate_stream_func(self):
-        from pilot.model.llm_out.guanaco_llm import guanaco_generate_output
-
-        return guanaco_generate_output
+        from pilot.model.guanaco_llm import guanaco_generate_stream
+        
+        return guanaco_generate_stream
 
 
 class ProxyllmChatAdapter(BaseChatAdpter):
     def match(self, model_path: str):
         return "proxyllm" in model_path
 
     def get_generate_stream_func(self):
-        from pilot.model.llm_out.proxy_llm import proxyllm_generate_stream
+        from pilot.model.proxy_llm import proxyllm_generate_stream
 
         return proxyllm_generate_stream

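A note on the streaming pattern: the guanaco_generate_stream added above builds a transformers TextIteratorStreamer but calls model.generate() on the calling thread, so the streamer is only drained after generation has already finished. TextIteratorStreamer is normally paired with the threading.Thread import this file keeps: generation runs on a worker thread while partial text is yielded concurrently. The sketch below is illustrative only and is not part of these patches; stream_response and its arguments are placeholder names.

from threading import Thread

from transformers import TextIteratorStreamer


def stream_response(model, tokenizer, prompt, max_new_tokens=512):
    # Illustrative sketch, not part of this patch series.
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
    streamer = TextIteratorStreamer(
        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )
    generate_kwargs = dict(
        input_ids=input_ids, max_new_tokens=max_new_tokens, streamer=streamer
    )

    # generate() blocks until it finishes, so it runs on a worker thread;
    # the streamer is drained concurrently on the caller's thread.
    Thread(target=model.generate, kwargs=generate_kwargs).start()

    out = ""
    for new_text in streamer:
        out += new_text
        yield out
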
From f7fe66b5e560aec07513d07273931cbd36d34550 Mon Sep 17 00:00:00 2001
From: csunny
Date: Sun, 4 Jun 2023 20:15:27 +0800
Subject: [PATCH 2/5] fix: guanaco output

---
 pilot/out_parser/base.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pilot/out_parser/base.py b/pilot/out_parser/base.py
index 0538aa54c..bb2d0b2b2 100644
--- a/pilot/out_parser/base.py
+++ b/pilot/out_parser/base.py
@@ -57,7 +57,7 @@ class BaseOutputParser(ABC):
             output = data["text"][skip_echo_len:].strip()
         elif "guanaco" in CFG.LLM_MODEL:
             # output = data["text"][skip_echo_len + 14:].replace("<s>", "").strip()
-            output = data["text"][skip_echo_len:].replace("<s>", "").strip()
+            output = data["text"][skip_echo_len + 2:].replace("<s>", "").strip()
         else:
             output = data["text"].strip()
 

From fe8291b198e0d91f1d57e81d1b44e1221a14c501 Mon Sep 17 00:00:00 2001
From: csunny
Date: Sun, 4 Jun 2023 20:38:34 +0800
Subject: [PATCH 3/5] feature: guanaco stream output

---
 pilot/model/loader.py        |  2 ++
 pilot/server/chat_adapter.py | 14 +++++++-------
 2 files changed, 9 insertions(+), 7 deletions(-)

diff --git a/pilot/model/loader.py b/pilot/model/loader.py
index 9fe6207c1..6fd6143ff 100644
--- a/pilot/model/loader.py
+++ b/pilot/model/loader.py
@@ -118,6 +118,8 @@ class ModelLoader(metaclass=Singleton):
                 model.to(self.device)
             except ValueError:
                 pass
+            except AttributeError:
+                pass
 
         if debug:
             print(model)
diff --git a/pilot/server/chat_adapter.py b/pilot/server/chat_adapter.py
index 63d922672..8db61d09f 100644
--- a/pilot/server/chat_adapter.py
+++ b/pilot/server/chat_adapter.py
@@ -4,7 +4,7 @@
 from functools import cache
 from typing import List
 
-from pilot.model.inference import generate_stream
+from pilot.model.llm_out.vicuna_base_llm import generate_stream
 
 
 class BaseChatAdpter:
@@ -55,7 +55,7 @@ class ChatGLMChatAdapter(BaseChatAdpter):
         return "chatglm" in model_path
 
     def get_generate_stream_func(self):
-        from pilot.model.chatglm_llm import chatglm_generate_stream
+        from pilot.model.llm_out.chatglm_llm import chatglm_generate_stream
 
         return chatglm_generate_stream
 
@@ -85,14 +85,14 @@ class CodeGenChatAdapter(BaseChatAdpter):
 
 
 class GuanacoChatAdapter(BaseChatAdpter):
-    """Model chat adapter for Guanaco """
-    
+    """Model chat adapter for Guanaco"""
+
    def match(self, model_path: str):
         return "guanaco" in model_path
 
     def get_generate_stream_func(self):
-        from pilot.model.guanaco_llm import guanaco_generate_stream
-        
+        from pilot.model.llm_out.guanaco_llm import  guanaco_generate_stream
+
         return guanaco_generate_stream
 
@@ -101,7 +101,7 @@ class ProxyllmChatAdapter(BaseChatAdpter):
         return "proxyllm" in model_path
 
     def get_generate_stream_func(self):
-        from pilot.model.proxy_llm import proxyllm_generate_stream
+        from pilot.model.llm_out.proxy_llm import proxyllm_generate_stream
 
         return proxyllm_generate_stream

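PATCH 3/5 also hardens ModelLoader: moving a model with model.to(self.device) can raise more than ValueError, so AttributeError is now swallowed as well. A minimal sketch of that idea, assuming nothing else about the loader (move_to_device is an illustrative helper, not DB-GPT code):

def move_to_device(model, device):
    # Illustrative sketch only, not the actual ModelLoader implementation.
    try:
        model.to(device)
    except ValueError:
        # transformers refuses .to() for models that must not be moved,
        # e.g. 8-bit quantized or device_map-dispatched checkpoints.
        pass
    except AttributeError:
        # some model objects (e.g. proxy or stub models) expose no .to() at all.
        pass
    return model
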
From ff6cc05e1146b723fd37e6fee85fecb09ca59333 Mon Sep 17 00:00:00 2001
From: csunny
Date: Sun, 4 Jun 2023 21:20:09 +0800
Subject: [PATCH 4/5] guanaco: add stream output func (#154)

---
 pilot/model/llm_out/guanaco_llm.py | 15 +++++++++------
 pilot/model/llm_out/proxy_llm.py   |  4 ----
 pilot/out_parser/base.py           |  8 ++++++--
 3 files changed, 15 insertions(+), 12 deletions(-)

diff --git a/pilot/model/llm_out/guanaco_llm.py b/pilot/model/llm_out/guanaco_llm.py
index 5b24e69ec..9b8008702 100644
--- a/pilot/model/llm_out/guanaco_llm.py
+++ b/pilot/model/llm_out/guanaco_llm.py
@@ -64,6 +64,9 @@ def guanaco_generate_stream(model, tokenizer, params, device, context_len=2048):
     print(params)
     stop = params.get("stop", "###")
     prompt = params["prompt"]
+    max_new_tokens = params.get("max_new_tokens", 512)
+    temperature = params.get("temperature", 1.0)
+
     query = prompt
     print("Query Message: ", query)
 
@@ -82,7 +85,7 @@ def guanaco_generate_stream(model, tokenizer, params, device, context_len=2048):
             self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
         ) -> bool:
             for stop_id in stop_token_ids:
-                if input_ids[0][-1] == stop_id:
+                if input_ids[-1][-1] == stop_id:
                     return True
             return False
 
@@ -90,8 +93,8 @@ def guanaco_generate_stream(model, tokenizer, params, device, context_len=2048):
 
     generate_kwargs = dict(
         input_ids=input_ids,
-        max_new_tokens=512,
-        temperature=1.0,
+        max_new_tokens=max_new_tokens,
+        temperature=temperature,
         do_sample=True,
         top_k=1,
         streamer=streamer,
@@ -100,9 +103,9 @@ def guanaco_generate_stream(model, tokenizer, params, device, context_len=2048):
     )
 
 
-    generator = model.generate(**generate_kwargs)
+    model.generate(**generate_kwargs)
+
     out = ""
     for new_text in streamer:
         out += new_text
-        yield new_text
-    return out
\ No newline at end of file
+        yield out
\ No newline at end of file
diff --git a/pilot/model/llm_out/proxy_llm.py b/pilot/model/llm_out/proxy_llm.py
index 92887cfc6..68512ec3c 100644
--- a/pilot/model/llm_out/proxy_llm.py
+++ b/pilot/model/llm_out/proxy_llm.py
@@ -68,15 +68,11 @@ def proxyllm_generate_stream(model, tokenizer, params, device, context_len=2048)
         "max_tokens": params.get("max_new_tokens"),
     }
 
-    print(payloads)
-    print(headers)
     res = requests.post(
         CFG.proxy_server_url, headers=headers, json=payloads, stream=True
     )
 
     text = ""
-    print("====================================res================")
-    print(res)
     for line in res.iter_lines():
         if line:
             decoded_line = line.decode("utf-8")
diff --git a/pilot/out_parser/base.py b/pilot/out_parser/base.py
index bb2d0b2b2..d1dee2e37 100644
--- a/pilot/out_parser/base.py
+++ b/pilot/out_parser/base.py
@@ -56,8 +56,12 @@ class BaseOutputParser(ABC):
             # output = data["text"][skip_echo_len + 11:].strip()
             output = data["text"][skip_echo_len:].strip()
         elif "guanaco" in CFG.LLM_MODEL:
-            # output = data["text"][skip_echo_len + 14:].replace("<s>", "").strip()
-            output = data["text"][skip_echo_len + 2:].replace("<s>", "").strip()
+
+            # NO stream output
+            # output = data["text"][skip_echo_len + 2:].replace("<s>", "").strip()
+
+            # stream out output
+            output = data["text"][11:].replace("<s>", "").strip()
         else:
             output = data["text"].strip()
 

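PATCH 4/5 changes what guanaco_generate_stream yields: PATCH 1/5 yielded each new_text delta, while this version yields the accumulated out string, so every chunk carries the full response generated so far. A hypothetical consumer (not part of the patches) therefore keeps only the latest chunk instead of concatenating:

def read_cumulative_stream(chunks):
    # Hypothetical consumer of guanaco_generate_stream after PATCH 4/5:
    # each chunk is the whole response so far, so keep the newest one.
    latest = ""
    for chunk in chunks:
        latest = chunk
    return latest


def read_delta_stream(chunks):
    # For comparison, a delta-style stream (the PATCH 1/5 behaviour)
    # would be reassembled by concatenation instead.
    return "".join(chunks)
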
From e8a193ef467bf31752c16840b3d662bb8edfe618 Mon Sep 17 00:00:00 2001
From: csunny
Date: Sun, 4 Jun 2023 21:47:21 +0800
Subject: [PATCH 5/5] feature: stream output for guanaco (#154)

---
 pilot/model/llm_out/guanaco_llm.py | 5 ++---
 pilot/out_parser/base.py           | 1 -
 pilot/server/chat_adapter.py       | 2 +-
 3 files changed, 3 insertions(+), 5 deletions(-)

diff --git a/pilot/model/llm_out/guanaco_llm.py b/pilot/model/llm_out/guanaco_llm.py
index 9b8008702..1a2d1ae8b 100644
--- a/pilot/model/llm_out/guanaco_llm.py
+++ b/pilot/model/llm_out/guanaco_llm.py
@@ -76,7 +76,7 @@ def guanaco_generate_stream(model, tokenizer, params, device, context_len=2048):
     streamer = TextIteratorStreamer(
         tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
     )
-    
+
     tokenizer.bos_token_id = 1
     stop_token_ids = [0]
 
@@ -102,10 +102,9 @@ def guanaco_generate_stream(model, tokenizer, params, device, context_len=2048):
         stopping_criteria=StoppingCriteriaList([stop]),
     )
 
-
     model.generate(**generate_kwargs)
 
     out = ""
     for new_text in streamer:
         out += new_text
-        yield out
\ No newline at end of file
+        yield out
diff --git a/pilot/out_parser/base.py b/pilot/out_parser/base.py
index d1dee2e37..909023f07 100644
--- a/pilot/out_parser/base.py
+++ b/pilot/out_parser/base.py
@@ -56,7 +56,6 @@ class BaseOutputParser(ABC):
             # output = data["text"][skip_echo_len + 11:].strip()
             output = data["text"][skip_echo_len:].strip()
         elif "guanaco" in CFG.LLM_MODEL:
-
             # NO stream output
             # output = data["text"][skip_echo_len + 2:].replace("<s>", "").strip()
 
             # stream out output
diff --git a/pilot/server/chat_adapter.py b/pilot/server/chat_adapter.py
index 8db61d09f..4dec22655 100644
--- a/pilot/server/chat_adapter.py
+++ b/pilot/server/chat_adapter.py
@@ -91,7 +91,7 @@ class GuanacoChatAdapter(BaseChatAdpter):
         return "guanaco" in model_path
 
     def get_generate_stream_func(self):
-        from pilot.model.llm_out.guanaco_llm import  guanaco_generate_stream
+        from pilot.model.llm_out.guanaco_llm import guanaco_generate_stream
 
         return guanaco_generate_stream
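With the series applied, GuanacoChatAdapter hands out guanaco_generate_stream as the streaming function. A sketch of how it could be driven, assuming a model and tokenizer have already been loaded elsewhere; the prompt text and parameter values are placeholders, not part of these patches:

params = {
    "prompt": "### Human: What is DB-GPT?\n### Response:",
    "stop": "###",
    "max_new_tokens": 256,
    "temperature": 0.7,
}

# Each yielded item is the full response generated so far, not a delta.
for partial_output in guanaco_generate_stream(model, tokenizer, params, "cuda"):
    print(partial_output)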