From bea3f7c8d261b07174debe64ca4470ebb4c175b0 Mon Sep 17 00:00:00 2001
From: csunny
Date: Thu, 8 Jun 2023 21:29:02 +0800
Subject: [PATCH 1/2] fix: chatglm stream output

---
 pilot/model/llm_out/chatglm_llm.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/pilot/model/llm_out/chatglm_llm.py b/pilot/model/llm_out/chatglm_llm.py
index 1a44678fe..690be0f06 100644
--- a/pilot/model/llm_out/chatglm_llm.py
+++ b/pilot/model/llm_out/chatglm_llm.py
@@ -51,7 +51,10 @@ def chatglm_generate_stream(
     #     else:
     #         once_conversation.append(f"""###system:{message} """)
 
-    query = messages[-2].split("human:")[1]
+    try:
+        query = messages[-2].split("human:")[1]
+    except IndexError:
+        query = messages[-3].split("human:")[1]
     print("Query Message: ", query)
     # output = ""
     # i = 0

From 716460460f37afff3b6db5e6138f7129af141ea7 Mon Sep 17 00:00:00 2001
From: csunny
Date: Thu, 8 Jun 2023 21:33:22 +0800
Subject: [PATCH 2/2] feature: gorllia support (#173)

---
 pilot/configs/config.py                 |  2 +-
 pilot/model/adapter.py                  | 40 ++++++++++++++++---------
 pilot/model/llm_out/falcon_llm.py       |  2 +-
 pilot/model/llm_out/gorilla_llm.py      | 12 +++++---
 pilot/out_parser/base.py                |  2 +-
 pilot/server/chat_adapter.py            |  8 ++---
 pilot/source_embedding/pdf_embedding.py |  4 ++-
 7 files changed, 44 insertions(+), 26 deletions(-)

diff --git a/pilot/configs/config.py b/pilot/configs/config.py
index d14d16808..971be9170 100644
--- a/pilot/configs/config.py
+++ b/pilot/configs/config.py
@@ -146,7 +146,7 @@ class Config(metaclass=Singleton):
         self.MILVUS_USERNAME = os.getenv("MILVUS_USERNAME", None)
         self.MILVUS_PASSWORD = os.getenv("MILVUS_PASSWORD", None)
 
-        # QLoRA
+        # QLoRA
         self.QLoRA = os.getenv("QUANTIZE_QLORA", "True")
 
         ### EMBEDDING Configuration
diff --git a/pilot/model/adapter.py b/pilot/model/adapter.py
index 9da5cbd04..7892e4b1b 100644
--- a/pilot/model/adapter.py
+++ b/pilot/model/adapter.py
@@ -4,13 +4,25 @@ import torch
 
 from typing import List
 from functools import cache
-from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer, BitsAndBytesConfig
+from transformers import (
+    AutoModel,
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    LlamaTokenizer,
+    BitsAndBytesConfig,
+)
 from pilot.configs.model_config import DEVICE
 from pilot.configs.config import Config
 
-bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype="bfloat16", bnb_4bit_use_double_quant=False)
+bnb_config = BitsAndBytesConfig(
+    load_in_4bit=True,
+    bnb_4bit_quant_type="nf4",
+    bnb_4bit_compute_dtype="bfloat16",
+    bnb_4bit_use_double_quant=False,
+)
 CFG = Config()
 
+
 class BaseLLMAdaper:
     """The Base class for multi model, in our project.
     We will support those model, which performance resemble ChatGPT"""
@@ -98,8 +110,8 @@ class GuanacoAdapter(BaseLLMAdaper):
             model_path, load_in_4bit=True, device_map={"": 0}, **from_pretrained_kwargs
         )
         return model, tokenizer
-
-
+
+
 class FalconAdapater(BaseLLMAdaper):
     """falcon Adapter"""
 
@@ -111,23 +123,23 @@ class FalconAdapater(BaseLLMAdaper):
 
         if CFG.QLoRA:
             model = AutoModelForCausalLM.from_pretrained(
-                model_path,
-                load_in_4bit=True, #quantize
-                quantization_config=bnb_config,
-                device_map={"": 0},
-                trust_remote_code=True,
-                **from_pretrained_kwagrs
+                model_path,
+                load_in_4bit=True,  # quantize
+                quantization_config=bnb_config,
+                device_map={"": 0},
+                trust_remote_code=True,
+                **from_pretrained_kwagrs,
             )
         else:
             model = AutoModelForCausalLM.from_pretrained(
-                model_path,
+                model_path,
                 trust_remote_code=True,
                 device_map={"": 0},
-                **from_pretrained_kwagrs
+                **from_pretrained_kwagrs,
             )
         return model, tokenizer
-
-
+
+
 class GorillaAdapter(BaseLLMAdaper):
     """TODO Support guanaco"""
 
diff --git a/pilot/model/llm_out/falcon_llm.py b/pilot/model/llm_out/falcon_llm.py
index f4cb53eff..53eaffdfb 100644
--- a/pilot/model/llm_out/falcon_llm.py
+++ b/pilot/model/llm_out/falcon_llm.py
@@ -19,7 +19,7 @@ def falcon_generate_output(model, tokenizer, params, device, context_len=2048):
     streamer = TextIteratorStreamer(
         tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
     )
-
+
     tokenizer.bos_token_id = 1
     stop_token_ids = [0]
 
diff --git a/pilot/model/llm_out/gorilla_llm.py b/pilot/model/llm_out/gorilla_llm.py
index 406cb97d2..30360da77 100644
--- a/pilot/model/llm_out/gorilla_llm.py
+++ b/pilot/model/llm_out/gorilla_llm.py
@@ -1,5 +1,6 @@
 import torch
 
+
 @torch.inference_mode()
 def generate_stream(
     model, tokenizer, params, device, context_len=42048, stream_interval=2
@@ -22,7 +23,7 @@ def generate_stream(
             out = model(torch.as_tensor([input_ids], device=device), use_cache=True)
             logits = out.logits
             past_key_values = out.past_key_values
-        else: 
+        else:
             out = model(
                 input_ids=torch.as_tensor([[token]], device=device),
                 use_cache=True,
@@ -37,7 +38,6 @@ def generate_stream(
             token = int(torch.multinomial(probs, num_samples=1))
 
         output_ids.append(token)
-
         if token == tokenizer.eos_token_id:
             stopped = True
         else:
         if i % stream_interval == 0 or i == max_new_tokens - 1 or stopped:
             tmp_output_ids = output_ids[input_echo_len:]
-            output = tokenizer.decode(tmp_output_ids, skip_special_tokens=True, spaces_between_special_tokens=False,)
+            output = tokenizer.decode(
+                tmp_output_ids,
+                skip_special_tokens=True,
+                spaces_between_special_tokens=False,
+            )
             pos = output.rfind(stop_str, l_prompt)
             if pos != -1:
                 output = output[:pos]
@@ -55,4 +59,4 @@ def generate_stream(
         if stopped:
             break
 
-    del past_key_values
\ No newline at end of file
+    del past_key_values
diff --git a/pilot/out_parser/base.py b/pilot/out_parser/base.py
index c91987579..513c1d300 100644
--- a/pilot/out_parser/base.py
+++ b/pilot/out_parser/base.py
@@ -62,7 +62,7 @@ class BaseOutputParser(ABC):
             # stream out output
             output = data["text"][11:].replace("", "").strip()
 
-        # TODO gorilla and falcon output
+        # TODO gorilla and falcon output
         else:
             output = data["text"].strip()
 
diff --git a/pilot/server/chat_adapter.py b/pilot/server/chat_adapter.py
index b5c7128e7..e4f57cf46 100644
--- a/pilot/server/chat_adapter.py
+++ b/pilot/server/chat_adapter.py
@@ -94,7 +94,7 @@ class GuanacoChatAdapter(BaseChatAdpter):
 
         return guanaco_generate_stream
 
-
+
 class FalconChatAdapter(BaseChatAdpter):
     """Model chat adapter for Guanaco"""
 
@@ -105,7 +105,8 @@ class FalconChatAdapter(BaseChatAdpter):
         from pilot.model.llm_out.falcon_llm import falcon_generate_output
 
         return falcon_generate_output
-
+
+
 class ProxyllmChatAdapter(BaseChatAdpter):
     def match(self, model_path: str):
         return "proxyllm" in model_path
@@ -116,8 +117,7 @@ class ProxyllmChatAdapter(BaseChatAdpter):
         return proxyllm_generate_stream
 
 
-class GorillaChatAdapter(BaseChatAdpter):
-
+class GorillaChatAdapter(BaseChatAdpter):
     def match(self, model_path: str):
         return "gorilla" in model_path
 
diff --git a/pilot/source_embedding/pdf_embedding.py b/pilot/source_embedding/pdf_embedding.py
index aee498b31..ae8dde974 100644
--- a/pilot/source_embedding/pdf_embedding.py
+++ b/pilot/source_embedding/pdf_embedding.py
@@ -28,7 +28,9 @@ class PDFEmbedding(SourceEmbedding):
         # textsplitter = CHNDocumentSplitter(
         #     pdf=True, sentence_size=CFG.KNOWLEDGE_CHUNK_SIZE
        # )
-        textsplitter = SpacyTextSplitter(pipeline='zh_core_web_sm', chunk_size=1000, chunk_overlap=200)
+        textsplitter = SpacyTextSplitter(
+            pipeline="zh_core_web_sm", chunk_size=1000, chunk_overlap=200
+        )
         return loader.load_and_split(textsplitter)
 
     @register