From c2bfab11e02d4581f24e77a3a77a3e3d18c64e0a Mon Sep 17 00:00:00 2001
From: zhanghy-sketchzh <1750410339@qq.com>
Date: Wed, 7 Jun 2023 12:05:33 +0800
Subject: [PATCH] support gorilla

---
 pilot/configs/model_config.py      |  1 +
 pilot/model/adapter.py             | 15 ++++++++
 pilot/model/llm_out/gorilla_llm.py | 58 ++++++++++++++++++++++++++++++
 pilot/server/chat_adapter.py       | 12 ++++++-
 4 files changed, 85 insertions(+), 1 deletion(-)
 create mode 100644 pilot/model/llm_out/gorilla_llm.py

diff --git a/pilot/configs/model_config.py b/pilot/configs/model_config.py
index 759245864..4cef24489 100644
--- a/pilot/configs/model_config.py
+++ b/pilot/configs/model_config.py
@@ -35,6 +35,7 @@ LLM_MODEL_CONFIG = {
     "chatglm-6b": os.path.join(MODEL_PATH, "chatglm-6b"),
     "text2vec-base": os.path.join(MODEL_PATH, "text2vec-base-chinese"),
     "guanaco-33b-merged": os.path.join(MODEL_PATH, "guanaco-33b-merged"),
+    "gorilla-7b": os.path.join(MODEL_PATH, "gorilla-7b"),
     "proxyllm": "proxyllm",
 }
 
diff --git a/pilot/model/adapter.py b/pilot/model/adapter.py
index 05c55fa74..f5e5125cc 100644
--- a/pilot/model/adapter.py
+++ b/pilot/model/adapter.py
@@ -109,6 +109,20 @@ class GuanacoAdapter(BaseLLMAdaper):
             model_path, load_in_4bit=True, device_map={"": 0}, **from_pretrained_kwargs
         )
         return model, tokenizer
+
+
+class GorillaAdapter(BaseLLMAdaper):
+    """Loader adapter for the Gorilla model"""
+
+    def match(self, model_path: str):
+        return "gorilla" in model_path
+
+    def loader(self, model_path: str, from_pretrained_kwargs: dict):
+        tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
+        model = AutoModelForCausalLM.from_pretrained(
+            model_path, low_cpu_mem_usage=True, **from_pretrained_kwargs
+        )
+        return model, tokenizer
 
 
 class CodeGenAdapter(BaseLLMAdaper):
@@ -166,6 +180,7 @@ class ProxyllmAdapter(BaseLLMAdaper):
 register_llm_model_adapters(VicunaLLMAdapater)
 register_llm_model_adapters(ChatGLMAdapater)
 register_llm_model_adapters(GuanacoAdapter)
+register_llm_model_adapters(GorillaAdapter)
 
 # TODO Default support vicuna, other model need to tests and Evaluate
 # just for test, remove this later
diff --git a/pilot/model/llm_out/gorilla_llm.py b/pilot/model/llm_out/gorilla_llm.py
new file mode 100644
index 000000000..406cb97d2
--- /dev/null
+++ b/pilot/model/llm_out/gorilla_llm.py
@@ -0,0 +1,58 @@
+import torch
+
+@torch.inference_mode()
+def generate_stream(
+    model, tokenizer, params, device, context_len=2048, stream_interval=2
+):
+    """Fork from https://github.com/ShishirPatil/gorilla/blob/main/inference/serve/gorilla_cli.py"""
+    prompt = params["prompt"]
+    l_prompt = len(prompt)
+    max_new_tokens = int(params.get("max_new_tokens", 1024))
+    stop_str = params.get("stop", None)
+
+    input_ids = tokenizer(prompt).input_ids
+    output_ids = list(input_ids)
+    input_echo_len = len(input_ids)
+    max_src_len = context_len - max_new_tokens - 8
+    input_ids = input_ids[-max_src_len:]
+    past_key_values = out = None
+
+    for i in range(max_new_tokens):
+        if i == 0:
+            out = model(torch.as_tensor([input_ids], device=device), use_cache=True)
+            logits = out.logits
+            past_key_values = out.past_key_values
+        else:
+            out = model(
+                input_ids=torch.as_tensor([[token]], device=device),
+                use_cache=True,
+                past_key_values=past_key_values,
+            )
+            logits = out.logits
+            past_key_values = out.past_key_values
+
+        last_token_logits = logits[0][-1]
+
+        probs = torch.softmax(last_token_logits, dim=-1)
+        token = int(torch.multinomial(probs, num_samples=1))
+        output_ids.append(token)
+
+
+        if token == tokenizer.eos_token_id:
+            stopped = True
+        else:
+            stopped = False
+
+        if i % stream_interval == 0 or i == max_new_tokens - 1 or stopped:
+            tmp_output_ids = output_ids[input_echo_len:]
+            output = tokenizer.decode(tmp_output_ids, skip_special_tokens=True, spaces_between_special_tokens=False)
+            pos = output.rfind(stop_str, l_prompt) if stop_str else -1
+            if pos != -1:
+                output = output[:pos]
+                stopped = True
+            yield output
+
+        if stopped:
+            break
+
+    del past_key_values
\ No newline at end of file
diff --git a/pilot/server/chat_adapter.py b/pilot/server/chat_adapter.py
index d4ab8ae09..f87f0b24c 100644
--- a/pilot/server/chat_adapter.py
+++ b/pilot/server/chat_adapter.py
@@ -95,6 +95,16 @@ class GuanacoChatAdapter(BaseChatAdpter):
 
         return guanaco_generate_stream
 
+class GorillaChatAdapter(BaseChatAdpter):
+    """Model chat adapter for Gorilla"""
+
+    def match(self, model_path: str):
+        return "gorilla" in model_path
+
+    def get_generate_stream_func(self):
+        from pilot.model.llm_out.gorilla_llm import generate_stream
+
+        return generate_stream
 
 class ProxyllmChatAdapter(BaseChatAdpter):
     def match(self, model_path: str):
@@ -109,7 +119,7 @@ class ProxyllmChatAdapter(BaseChatAdpter):
 register_llm_model_chat_adapter(VicunaChatAdapter)
 register_llm_model_chat_adapter(ChatGLMChatAdapter)
 register_llm_model_chat_adapter(GuanacoChatAdapter)
-
+register_llm_model_chat_adapter(GorillaChatAdapter)
 
 # Proxy model for test and develop, it's cheap for us now.
 register_llm_model_chat_adapter(ProxyllmChatAdapter)
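
Reviewer note: a minimal usage sketch (not part of the patch) of how the pieces added here fit together. The `gorilla-7b` entry in `LLM_MODEL_CONFIG` resolves the weights path, `GorillaAdapter.loader` builds the Hugging Face model and tokenizer, and `GorillaChatAdapter.get_generate_stream_func` hands back `generate_stream`. The prompt, stop string, and `torch_dtype` kwarg below are illustrative assumptions; in the running server the adapters are resolved through `register_llm_model_adapters` / `register_llm_model_chat_adapter` rather than instantiated directly.

```python
# Usage sketch only -- not part of the patch. Paths and params are assumptions.
import torch

from pilot.configs.model_config import LLM_MODEL_CONFIG
from pilot.model.adapter import GorillaAdapter
from pilot.server.chat_adapter import GorillaChatAdapter

model_path = LLM_MODEL_CONFIG["gorilla-7b"]   # <MODEL_PATH>/gorilla-7b
device = "cuda" if torch.cuda.is_available() else "cpu"

llm_adapter = GorillaAdapter()
assert llm_adapter.match(model_path)          # matches on the "gorilla" substring

# Extra kwargs are passed straight through to AutoModelForCausalLM.from_pretrained.
model, tokenizer = llm_adapter.loader(model_path, {"torch_dtype": torch.float16})
model.to(device).eval()

chat_adapter = GorillaChatAdapter()
assert chat_adapter.match(model_path)
generate_stream = chat_adapter.get_generate_stream_func()

params = {
    "prompt": "I want to translate English to Chinese, which API should I use?",
    "max_new_tokens": 512,
    "stop": "###",   # assumed stop string; the server fills this from its prompt template
}
for partial in generate_stream(model, tokenizer, params, device):
    # Each yield is the completion decoded so far (the prompt is not echoed back).
    print(partial)
```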
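Design note on the streaming contract: like the other `llm_out` generators, `generate_stream` re-decodes and yields the cumulative completion every `stream_interval` steps (and on the final or stopped step), so a caller that only wants the final answer can keep the last yielded value. A sketch, reusing the `model`, `tokenizer`, and `params` objects from the example above:

```python
# Non-streaming consumption sketch: keep only the last yielded value.
output = ""
for output in generate_stream(model, tokenizer, params, device):
    pass
print(output)  # final completion, truncated at the stop string if one was found
```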