From c2bfab11e02d4581f24e77a3a77a3e3d18c64e0a Mon Sep 17 00:00:00 2001 From: zhanghy-sketchzh <1750410339@qq.com> Date: Wed, 7 Jun 2023 12:05:33 +0800 Subject: [PATCH 1/9] support gorilla --- pilot/configs/model_config.py | 1 + pilot/model/adapter.py | 15 ++++++++ pilot/model/llm_out/gorilla_llm.py | 58 ++++++++++++++++++++++++++++++ pilot/server/chat_adapter.py | 12 ++++++- 4 files changed, 85 insertions(+), 1 deletion(-) create mode 100644 pilot/model/llm_out/gorilla_llm.py diff --git a/pilot/configs/model_config.py b/pilot/configs/model_config.py index 759245864..4cef24489 100644 --- a/pilot/configs/model_config.py +++ b/pilot/configs/model_config.py @@ -35,6 +35,7 @@ LLM_MODEL_CONFIG = { "chatglm-6b": os.path.join(MODEL_PATH, "chatglm-6b"), "text2vec-base": os.path.join(MODEL_PATH, "text2vec-base-chinese"), "guanaco-33b-merged": os.path.join(MODEL_PATH, "guanaco-33b-merged"), + "gorilla-7b": os.path.join(MODEL_PATH, "gorilla-7b"), "proxyllm": "proxyllm", } diff --git a/pilot/model/adapter.py b/pilot/model/adapter.py index 05c55fa74..f5e5125cc 100644 --- a/pilot/model/adapter.py +++ b/pilot/model/adapter.py @@ -109,6 +109,20 @@ class GuanacoAdapter(BaseLLMAdaper): model_path, load_in_4bit=True, device_map={"": 0}, **from_pretrained_kwargs ) return model, tokenizer + + +class GorillaAdapter(BaseLLMAdaper): + """TODO Support guanaco""" + + def match(self, model_path: str): + return "gorilla" in model_path + + def loader(self, model_path: str, from_pretrained_kwargs: dict): + tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False) + model = AutoModelForCausalLM.from_pretrained( + model_path, low_cpu_mem_usage=True, **from_pretrained_kwargs + ) + return model, tokenizer class CodeGenAdapter(BaseLLMAdaper): @@ -166,6 +180,7 @@ class ProxyllmAdapter(BaseLLMAdaper): register_llm_model_adapters(VicunaLLMAdapater) register_llm_model_adapters(ChatGLMAdapater) register_llm_model_adapters(GuanacoAdapter) +register_llm_model_adapters(GorillaAdapter) # TODO Default support vicuna, other model need to tests and Evaluate # just for test, remove this later diff --git a/pilot/model/llm_out/gorilla_llm.py b/pilot/model/llm_out/gorilla_llm.py new file mode 100644 index 000000000..406cb97d2 --- /dev/null +++ b/pilot/model/llm_out/gorilla_llm.py @@ -0,0 +1,58 @@ +import torch + +@torch.inference_mode() +def generate_stream( + model, tokenizer, params, device, context_len=42048, stream_interval=2 +): + """Fork from https://github.com/ShishirPatil/gorilla/blob/main/inference/serve/gorilla_cli.py""" + prompt = params["prompt"] + l_prompt = len(prompt) + max_new_tokens = int(params.get("max_new_tokens", 1024)) + stop_str = params.get("stop", None) + + input_ids = tokenizer(prompt).input_ids + output_ids = list(input_ids) + input_echo_len = len(input_ids) + max_src_len = context_len - max_new_tokens - 8 + input_ids = input_ids[-max_src_len:] + past_key_values = out = None + + for i in range(max_new_tokens): + if i == 0: + out = model(torch.as_tensor([input_ids], device=device), use_cache=True) + logits = out.logits + past_key_values = out.past_key_values + else: + out = model( + input_ids=torch.as_tensor([[token]], device=device), + use_cache=True, + past_key_values=past_key_values, + ) + logits = out.logits + past_key_values = out.past_key_values + + last_token_logits = logits[0][-1] + + probs = torch.softmax(last_token_logits, dim=-1) + token = int(torch.multinomial(probs, num_samples=1)) + output_ids.append(token) + + + if token == tokenizer.eos_token_id: + stopped = True + else: + stopped = 
False + + if i % stream_interval == 0 or i == max_new_tokens - 1 or stopped: + tmp_output_ids = output_ids[input_echo_len:] + output = tokenizer.decode(tmp_output_ids, skip_special_tokens=True, spaces_between_special_tokens=False,) + pos = output.rfind(stop_str, l_prompt) + if pos != -1: + output = output[:pos] + stopped = True + yield output + + if stopped: + break + + del past_key_values \ No newline at end of file diff --git a/pilot/server/chat_adapter.py b/pilot/server/chat_adapter.py index d4ab8ae09..f87f0b24c 100644 --- a/pilot/server/chat_adapter.py +++ b/pilot/server/chat_adapter.py @@ -95,6 +95,16 @@ class GuanacoChatAdapter(BaseChatAdpter): return guanaco_generate_stream +class GorillaChatAdapter(BaseChatAdpter): + """Model chat adapter for Guanaco""" + + def match(self, model_path: str): + return "gorilla" in model_path + + def get_generate_stream_func(self): + from pilot.model.llm_out.gorilla_llm import generate_stream + + return generate_stream class ProxyllmChatAdapter(BaseChatAdpter): def match(self, model_path: str): @@ -109,7 +119,7 @@ class ProxyllmChatAdapter(BaseChatAdpter): register_llm_model_chat_adapter(VicunaChatAdapter) register_llm_model_chat_adapter(ChatGLMChatAdapter) register_llm_model_chat_adapter(GuanacoChatAdapter) - +register_llm_model_chat_adapter(GorillaChatAdapter) # Proxy model for test and develop, it's cheap for us now. register_llm_model_chat_adapter(ProxyllmChatAdapter) From b357fd9d0c4b42e576bfcc686470691a1aacf0a9 Mon Sep 17 00:00:00 2001 From: zhanghy-sketchzh <1750410339@qq.com> Date: Thu, 8 Jun 2023 12:17:13 +0800 Subject: [PATCH 2/9] Add support for falcon --- pilot/configs/model_config.py | 1 + pilot/model/adapter.py | 38 ++++++++++++++-------- pilot/model/llm_out/falcon_llm.py | 54 +++++++++++++++++++++++++++++++ pilot/server/chat_adapter.py | 13 ++++++-- 4 files changed, 91 insertions(+), 15 deletions(-) create mode 100644 pilot/model/llm_out/falcon_llm.py diff --git a/pilot/configs/model_config.py b/pilot/configs/model_config.py index 759245864..adfc62f1a 100644 --- a/pilot/configs/model_config.py +++ b/pilot/configs/model_config.py @@ -35,6 +35,7 @@ LLM_MODEL_CONFIG = { "chatglm-6b": os.path.join(MODEL_PATH, "chatglm-6b"), "text2vec-base": os.path.join(MODEL_PATH, "text2vec-base-chinese"), "guanaco-33b-merged": os.path.join(MODEL_PATH, "guanaco-33b-merged"), + "falcon-40b": os.path.join(MODEL_PATH, "falcon-40b"), "proxyllm": "proxyllm", } diff --git a/pilot/model/adapter.py b/pilot/model/adapter.py index 05c55fa74..f8c65af77 100644 --- a/pilot/model/adapter.py +++ b/pilot/model/adapter.py @@ -1,10 +1,10 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -from functools import cache + +import torch from typing import List - -from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer - +from functools import cache +from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer, BitsAndBytesConfig from pilot.configs.model_config import DEVICE @@ -95,19 +95,30 @@ class GuanacoAdapter(BaseLLMAdaper): model_path, load_in_4bit=True, device_map={"": 0}, **from_pretrained_kwargs ) return model, tokenizer - - -class GuanacoAdapter(BaseLLMAdaper): - """TODO Support guanaco""" + + +class FalconAdapater(BaseLLMAdaper): + """falcon Adapter""" def match(self, model_path: str): - return "guanaco" in model_path + return "falcon" in model_path - def loader(self, model_path: str, from_pretrained_kwargs: dict): - tokenizer = LlamaTokenizer.from_pretrained(model_path) + def loader(self, model_path: str, 
from_pretrained_kwagrs: dict): + tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False) + bnb_config = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_quant_type="nf4", + bnb_4bit_compute_dtype="bfloat16", + bnb_4bit_use_double_quant=False, + ) model = AutoModelForCausalLM.from_pretrained( - model_path, load_in_4bit=True, device_map={"": 0}, **from_pretrained_kwargs - ) + model_path, + #load_in_4bit=True, #quantize + quantization_config=bnb_config, + device_map={"": 0}, + trust_remote_code=True, + **from_pretrained_kwagrs + ) return model, tokenizer @@ -166,6 +177,7 @@ class ProxyllmAdapter(BaseLLMAdaper): register_llm_model_adapters(VicunaLLMAdapater) register_llm_model_adapters(ChatGLMAdapater) register_llm_model_adapters(GuanacoAdapter) +register_llm_model_adapters(FalconAdapater) # TODO Default support vicuna, other model need to tests and Evaluate # just for test, remove this later diff --git a/pilot/model/llm_out/falcon_llm.py b/pilot/model/llm_out/falcon_llm.py new file mode 100644 index 000000000..f4cb53eff --- /dev/null +++ b/pilot/model/llm_out/falcon_llm.py @@ -0,0 +1,54 @@ +import torch +import copy +from threading import Thread +from transformers import TextIteratorStreamer, StoppingCriteriaList, StoppingCriteria + + +def falcon_generate_output(model, tokenizer, params, device, context_len=2048): + """Fork from: https://github.com/KohakuBlueleaf/guanaco-lora/blob/main/generate.py""" + tokenizer.bos_token_id = 1 + print(params) + stop = params.get("stop", "###") + prompt = params["prompt"] + query = prompt + print("Query Message: ", query) + + input_ids = tokenizer(query, return_tensors="pt").input_ids + input_ids = input_ids.to(model.device) + + streamer = TextIteratorStreamer( + tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True + ) + + tokenizer.bos_token_id = 1 + stop_token_ids = [0] + + class StopOnTokens(StoppingCriteria): + def __call__( + self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs + ) -> bool: + for stop_id in stop_token_ids: + if input_ids[0][-1] == stop_id: + return True + return False + + stop = StopOnTokens() + + generate_kwargs = dict( + input_ids=input_ids, + max_new_tokens=512, + temperature=1.0, + do_sample=True, + top_k=1, + streamer=streamer, + repetition_penalty=1.7, + stopping_criteria=StoppingCriteriaList([stop]), + ) + + t = Thread(target=model.generate, kwargs=generate_kwargs) + t.start() + + out = "" + for new_text in streamer: + out += new_text + yield out diff --git a/pilot/server/chat_adapter.py b/pilot/server/chat_adapter.py index d4ab8ae09..1db3beee7 100644 --- a/pilot/server/chat_adapter.py +++ b/pilot/server/chat_adapter.py @@ -3,7 +3,6 @@ from functools import cache from typing import List - from pilot.model.llm_out.vicuna_base_llm import generate_stream @@ -95,7 +94,17 @@ class GuanacoChatAdapter(BaseChatAdpter): return guanaco_generate_stream +class FalconChatAdapter(BaseChatAdpter): + """Model chat adapter for Guanaco""" + def match(self, model_path: str): + return "falcon" in model_path + + def get_generate_stream_func(self): + from pilot.model.llm_out.falcon_llm import falcon_generate_output + + return falcon_generate_output + class ProxyllmChatAdapter(BaseChatAdpter): def match(self, model_path: str): return "proxyllm" in model_path @@ -109,7 +118,7 @@ class ProxyllmChatAdapter(BaseChatAdpter): register_llm_model_chat_adapter(VicunaChatAdapter) register_llm_model_chat_adapter(ChatGLMChatAdapter) register_llm_model_chat_adapter(GuanacoChatAdapter) - 
+register_llm_model_adapters(FalconChatAdapter) # Proxy model for test and develop, it's cheap for us now. register_llm_model_chat_adapter(ProxyllmChatAdapter) From bb9081e00fcc0df3519832ba305ee9abd341a7d8 Mon Sep 17 00:00:00 2001 From: zhanghy-sketchzh <1750410339@qq.com> Date: Thu, 8 Jun 2023 13:37:48 +0800 Subject: [PATCH 3/9] Add quantize_qlora support for falcon --- .env.template | 4 ++-- pilot/configs/model_config.py | 2 +- pilot/model/adapter.py | 27 +++++++++++++++------------ 3 files changed, 18 insertions(+), 15 deletions(-) diff --git a/.env.template b/.env.template index 234b12738..2fb5ff649 100644 --- a/.env.template +++ b/.env.template @@ -21,7 +21,7 @@ LLM_MODEL=vicuna-13b MODEL_SERVER=http://127.0.0.1:8000 LIMIT_MODEL_CONCURRENCY=5 MAX_POSITION_EMBEDDINGS=4096 - +QUANTIZE_QLORA=True ## SMART_LLM_MODEL - Smart language model (Default: vicuna-13b) ## FAST_LLM_MODEL - Fast language model (Default: chatglm-6b) # SMART_LLM_MODEL=vicuna-13b @@ -112,4 +112,4 @@ PROXY_SERVER_URL=http://127.0.0.1:3000/proxy_address #*******************************************************************# # ** SUMMARY_CONFIG #*******************************************************************# -SUMMARY_CONFIG=FAST \ No newline at end of file +SUMMARY_CONFIG=FAST diff --git a/pilot/configs/model_config.py b/pilot/configs/model_config.py index adfc62f1a..4f3e635d6 100644 --- a/pilot/configs/model_config.py +++ b/pilot/configs/model_config.py @@ -42,7 +42,7 @@ LLM_MODEL_CONFIG = { # Load model config ISLOAD_8BIT = True ISDEBUG = False - +QLORA = os.getenv("QUANTIZE_QLORA") == "True" VECTOR_SEARCH_TOP_K = 10 VS_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "vs_store") diff --git a/pilot/model/adapter.py b/pilot/model/adapter.py index f8c65af77..76eb51f26 100644 --- a/pilot/model/adapter.py +++ b/pilot/model/adapter.py @@ -7,6 +7,7 @@ from functools import cache from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer, BitsAndBytesConfig from pilot.configs.model_config import DEVICE +bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype="bfloat16", bnb_4bit_use_double_quant=False) class BaseLLMAdaper: """The Base class for multi model, in our project. 
@@ -105,19 +106,21 @@ class FalconAdapater(BaseLLMAdaper): def loader(self, model_path: str, from_pretrained_kwagrs: dict): tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False) - bnb_config = BitsAndBytesConfig( - load_in_4bit=True, - bnb_4bit_quant_type="nf4", - bnb_4bit_compute_dtype="bfloat16", - bnb_4bit_use_double_quant=False, + if QLORA == True: + model = AutoModelForCausalLM.from_pretrained( + model_path, + load_in_4bit=True, #quantize + quantization_config=bnb_config, + device_map={"": 0}, + trust_remote_code=True, + **from_pretrained_kwagrs ) - model = AutoModelForCausalLM.from_pretrained( - model_path, - #load_in_4bit=True, #quantize - quantization_config=bnb_config, - device_map={"": 0}, - trust_remote_code=True, - **from_pretrained_kwagrs + else: + model = AutoModelForCausalLM.from_pretrained( + model_path, + trust_remote_code=True, + device_map={"": 0}, + **from_pretrained_kwagrs ) return model, tokenizer From c022e70a5025aa9a287b27e0f8201a68a35a812c Mon Sep 17 00:00:00 2001 From: zhanghy-sketchzh <1750410339@qq.com> Date: Thu, 8 Jun 2023 15:02:37 +0800 Subject: [PATCH 4/9] fix problems --- pilot/model/adapter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pilot/model/adapter.py b/pilot/model/adapter.py index c9f5cb6f1..c914195d8 100644 --- a/pilot/model/adapter.py +++ b/pilot/model/adapter.py @@ -106,7 +106,7 @@ class FalconAdapater(BaseLLMAdaper): def loader(self, model_path: str, from_pretrained_kwagrs: dict): tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False) - if QLORA == True: + if QLORA: model = AutoModelForCausalLM.from_pretrained( model_path, load_in_4bit=True, #quantize From 0948bc45bcbbbc2771483ab15e2bee6952f9c04e Mon Sep 17 00:00:00 2001 From: csunny Date: Thu, 8 Jun 2023 17:35:17 +0800 Subject: [PATCH 5/9] fix: gorilla chat adapter and config --- pilot/configs/config.py | 3 +++ pilot/model/adapter.py | 5 ++++- pilot/out_parser/base.py | 2 ++ pilot/server/chat_adapter.py | 13 ++++++++++++- 4 files changed, 21 insertions(+), 2 deletions(-) diff --git a/pilot/configs/config.py b/pilot/configs/config.py index c4458eaf7..d14d16808 100644 --- a/pilot/configs/config.py +++ b/pilot/configs/config.py @@ -146,6 +146,9 @@ class Config(metaclass=Singleton): self.MILVUS_USERNAME = os.getenv("MILVUS_USERNAME", None) self.MILVUS_PASSWORD = os.getenv("MILVUS_PASSWORD", None) + # QLoRA + self.QLoRA = os.getenv("QUANTIZE_QLORA", "True") + ### EMBEDDING Configuration self.EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "text2vec") self.KNOWLEDGE_CHUNK_SIZE = int(os.getenv("KNOWLEDGE_CHUNK_SIZE", 500)) diff --git a/pilot/model/adapter.py b/pilot/model/adapter.py index c914195d8..9da5cbd04 100644 --- a/pilot/model/adapter.py +++ b/pilot/model/adapter.py @@ -6,8 +6,10 @@ from typing import List from functools import cache from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer, BitsAndBytesConfig from pilot.configs.model_config import DEVICE +from pilot.configs.config import Config bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype="bfloat16", bnb_4bit_use_double_quant=False) +CFG = Config() class BaseLLMAdaper: """The Base class for multi model, in our project. 
@@ -106,7 +108,8 @@ class FalconAdapater(BaseLLMAdaper): def loader(self, model_path: str, from_pretrained_kwagrs: dict): tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False) - if QLORA: + + if CFG.QLoRA: model = AutoModelForCausalLM.from_pretrained( model_path, load_in_4bit=True, #quantize diff --git a/pilot/out_parser/base.py b/pilot/out_parser/base.py index 909023f07..c91987579 100644 --- a/pilot/out_parser/base.py +++ b/pilot/out_parser/base.py @@ -61,6 +61,8 @@ class BaseOutputParser(ABC): # stream out output output = data["text"][11:].replace("", "").strip() + + # TODO gorilla and falcon output else: output = data["text"].strip() diff --git a/pilot/server/chat_adapter.py b/pilot/server/chat_adapter.py index a311312a2..b5c7128e7 100644 --- a/pilot/server/chat_adapter.py +++ b/pilot/server/chat_adapter.py @@ -116,10 +116,21 @@ class ProxyllmChatAdapter(BaseChatAdpter): return proxyllm_generate_stream +class GorillaChatAdapter(BaseChatAdpter): + + def match(self, model_path: str): + return "gorilla" in model_path + + def get_generate_stream_func(self): + from pilot.model.llm_out.gorilla_llm import generate_stream + + return generate_stream + + register_llm_model_chat_adapter(VicunaChatAdapter) register_llm_model_chat_adapter(ChatGLMChatAdapter) register_llm_model_chat_adapter(GuanacoChatAdapter) -register_llm_model_adapters(FalconChatAdapter) +register_llm_model_chat_adapter(FalconChatAdapter) register_llm_model_chat_adapter(GorillaChatAdapter) # Proxy model for test and develop, it's cheap for us now. From e91868770559eea55e7e53675dc2db384ad0148e Mon Sep 17 00:00:00 2001 From: csunny Date: Thu, 8 Jun 2023 20:07:36 +0800 Subject: [PATCH 6/9] fix: remove qlora config from model_config --- pilot/configs/model_config.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pilot/configs/model_config.py b/pilot/configs/model_config.py index 1997584a7..36d615043 100644 --- a/pilot/configs/model_config.py +++ b/pilot/configs/model_config.py @@ -43,7 +43,6 @@ LLM_MODEL_CONFIG = { # Load model config ISLOAD_8BIT = True ISDEBUG = False -QLORA = os.getenv("QUANTIZE_QLORA") == "True" VECTOR_SEARCH_TOP_K = 10 VS_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "vs_store") From bea3f7c8d261b07174debe64ca4470ebb4c175b0 Mon Sep 17 00:00:00 2001 From: csunny Date: Thu, 8 Jun 2023 21:29:02 +0800 Subject: [PATCH 7/9] fix: chatglm stream output --- pilot/model/llm_out/chatglm_llm.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/pilot/model/llm_out/chatglm_llm.py b/pilot/model/llm_out/chatglm_llm.py index 1a44678fe..690be0f06 100644 --- a/pilot/model/llm_out/chatglm_llm.py +++ b/pilot/model/llm_out/chatglm_llm.py @@ -51,7 +51,10 @@ def chatglm_generate_stream( # else: # once_conversation.append(f"""###system:{message} """) - query = messages[-2].split("human:")[1] + try: + query = messages[-2].split("human:")[1] + except IndexError: + query = messages[-3].split("human:")[1] print("Query Message: ", query) # output = "" # i = 0 From 716460460f37afff3b6db5e6138f7129af141ea7 Mon Sep 17 00:00:00 2001 From: csunny Date: Thu, 8 Jun 2023 21:33:22 +0800 Subject: [PATCH 8/9] feature: gorllia support (#173) --- pilot/configs/config.py | 2 +- pilot/model/adapter.py | 40 ++++++++++++++++--------- pilot/model/llm_out/falcon_llm.py | 2 +- pilot/model/llm_out/gorilla_llm.py | 12 +++++--- pilot/out_parser/base.py | 2 +- pilot/server/chat_adapter.py | 8 ++--- pilot/source_embedding/pdf_embedding.py | 4 ++- 7 files changed, 44 insertions(+), 26 deletions(-) diff --git 
a/pilot/configs/config.py b/pilot/configs/config.py index d14d16808..971be9170 100644 --- a/pilot/configs/config.py +++ b/pilot/configs/config.py @@ -146,7 +146,7 @@ class Config(metaclass=Singleton): self.MILVUS_USERNAME = os.getenv("MILVUS_USERNAME", None) self.MILVUS_PASSWORD = os.getenv("MILVUS_PASSWORD", None) - # QLoRA + # QLoRA self.QLoRA = os.getenv("QUANTIZE_QLORA", "True") ### EMBEDDING Configuration diff --git a/pilot/model/adapter.py b/pilot/model/adapter.py index 9da5cbd04..7892e4b1b 100644 --- a/pilot/model/adapter.py +++ b/pilot/model/adapter.py @@ -4,13 +4,25 @@ import torch from typing import List from functools import cache -from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer, BitsAndBytesConfig +from transformers import ( + AutoModel, + AutoModelForCausalLM, + AutoTokenizer, + LlamaTokenizer, + BitsAndBytesConfig, +) from pilot.configs.model_config import DEVICE from pilot.configs.config import Config -bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype="bfloat16", bnb_4bit_use_double_quant=False) +bnb_config = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_quant_type="nf4", + bnb_4bit_compute_dtype="bfloat16", + bnb_4bit_use_double_quant=False, +) CFG = Config() + class BaseLLMAdaper: """The Base class for multi model, in our project. We will support those model, which performance resemble ChatGPT""" @@ -98,8 +110,8 @@ class GuanacoAdapter(BaseLLMAdaper): model_path, load_in_4bit=True, device_map={"": 0}, **from_pretrained_kwargs ) return model, tokenizer - - + + class FalconAdapater(BaseLLMAdaper): """falcon Adapter""" @@ -111,23 +123,23 @@ class FalconAdapater(BaseLLMAdaper): if CFG.QLoRA: model = AutoModelForCausalLM.from_pretrained( - model_path, - load_in_4bit=True, #quantize - quantization_config=bnb_config, - device_map={"": 0}, - trust_remote_code=True, - **from_pretrained_kwagrs + model_path, + load_in_4bit=True, # quantize + quantization_config=bnb_config, + device_map={"": 0}, + trust_remote_code=True, + **from_pretrained_kwagrs, ) else: model = AutoModelForCausalLM.from_pretrained( - model_path, + model_path, trust_remote_code=True, device_map={"": 0}, - **from_pretrained_kwagrs + **from_pretrained_kwagrs, ) return model, tokenizer - - + + class GorillaAdapter(BaseLLMAdaper): """TODO Support guanaco""" diff --git a/pilot/model/llm_out/falcon_llm.py b/pilot/model/llm_out/falcon_llm.py index f4cb53eff..53eaffdfb 100644 --- a/pilot/model/llm_out/falcon_llm.py +++ b/pilot/model/llm_out/falcon_llm.py @@ -19,7 +19,7 @@ def falcon_generate_output(model, tokenizer, params, device, context_len=2048): streamer = TextIteratorStreamer( tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True ) - + tokenizer.bos_token_id = 1 stop_token_ids = [0] diff --git a/pilot/model/llm_out/gorilla_llm.py b/pilot/model/llm_out/gorilla_llm.py index 406cb97d2..30360da77 100644 --- a/pilot/model/llm_out/gorilla_llm.py +++ b/pilot/model/llm_out/gorilla_llm.py @@ -1,5 +1,6 @@ import torch + @torch.inference_mode() def generate_stream( model, tokenizer, params, device, context_len=42048, stream_interval=2 @@ -22,7 +23,7 @@ def generate_stream( out = model(torch.as_tensor([input_ids], device=device), use_cache=True) logits = out.logits past_key_values = out.past_key_values - else: + else: out = model( input_ids=torch.as_tensor([[token]], device=device), use_cache=True, @@ -37,7 +38,6 @@ def generate_stream( token = int(torch.multinomial(probs, num_samples=1)) output_ids.append(token) - if 
token == tokenizer.eos_token_id: stopped = True else: @@ -45,7 +45,11 @@ def generate_stream( if i % stream_interval == 0 or i == max_new_tokens - 1 or stopped: tmp_output_ids = output_ids[input_echo_len:] - output = tokenizer.decode(tmp_output_ids, skip_special_tokens=True, spaces_between_special_tokens=False,) + output = tokenizer.decode( + tmp_output_ids, + skip_special_tokens=True, + spaces_between_special_tokens=False, + ) pos = output.rfind(stop_str, l_prompt) if pos != -1: output = output[:pos] @@ -55,4 +59,4 @@ def generate_stream( if stopped: break - del past_key_values \ No newline at end of file + del past_key_values diff --git a/pilot/out_parser/base.py b/pilot/out_parser/base.py index c91987579..513c1d300 100644 --- a/pilot/out_parser/base.py +++ b/pilot/out_parser/base.py @@ -62,7 +62,7 @@ class BaseOutputParser(ABC): # stream out output output = data["text"][11:].replace("", "").strip() - # TODO gorilla and falcon output + # TODO gorilla and falcon output else: output = data["text"].strip() diff --git a/pilot/server/chat_adapter.py b/pilot/server/chat_adapter.py index b5c7128e7..e4f57cf46 100644 --- a/pilot/server/chat_adapter.py +++ b/pilot/server/chat_adapter.py @@ -94,7 +94,7 @@ class GuanacoChatAdapter(BaseChatAdpter): return guanaco_generate_stream - + class FalconChatAdapter(BaseChatAdpter): """Model chat adapter for Guanaco""" @@ -105,7 +105,8 @@ class FalconChatAdapter(BaseChatAdpter): from pilot.model.llm_out.falcon_llm import falcon_generate_output return falcon_generate_output - + + class ProxyllmChatAdapter(BaseChatAdpter): def match(self, model_path: str): return "proxyllm" in model_path @@ -116,8 +117,7 @@ class ProxyllmChatAdapter(BaseChatAdpter): return proxyllm_generate_stream -class GorillaChatAdapter(BaseChatAdpter): - +class GorillaChatAdapter(BaseChatAdpter): def match(self, model_path: str): return "gorilla" in model_path diff --git a/pilot/source_embedding/pdf_embedding.py b/pilot/source_embedding/pdf_embedding.py index aee498b31..ae8dde974 100644 --- a/pilot/source_embedding/pdf_embedding.py +++ b/pilot/source_embedding/pdf_embedding.py @@ -28,7 +28,9 @@ class PDFEmbedding(SourceEmbedding): # textsplitter = CHNDocumentSplitter( # pdf=True, sentence_size=CFG.KNOWLEDGE_CHUNK_SIZE # ) - textsplitter = SpacyTextSplitter(pipeline='zh_core_web_sm', chunk_size=1000, chunk_overlap=200) + textsplitter = SpacyTextSplitter( + pipeline="zh_core_web_sm", chunk_size=1000, chunk_overlap=200 + ) return loader.load_and_split(textsplitter) @register From 40b55a2d27c7d249e7a5e3fdfae542f332e4ae37 Mon Sep 17 00:00:00 2001 From: csunny Date: Fri, 9 Jun 2023 16:19:47 +0800 Subject: [PATCH 9/9] fix: next-ui output --- pilot/model/llm_out/proxy_llm.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/pilot/model/llm_out/proxy_llm.py b/pilot/model/llm_out/proxy_llm.py index 68512ec3c..ff00620f1 100644 --- a/pilot/model/llm_out/proxy_llm.py +++ b/pilot/model/llm_out/proxy_llm.py @@ -76,7 +76,13 @@ def proxyllm_generate_stream(model, tokenizer, params, device, context_len=2048) for line in res.iter_lines(): if line: decoded_line = line.decode("utf-8") - json_line = json.loads(decoded_line) - print(json_line) - text += json_line["choices"][0]["message"]["content"] - yield text + try: + json_line = json.loads(decoded_line) + print(json_line) + text += json_line["choices"][0]["message"]["content"] + yield text + except Exception as e: + text += decoded_line + yield json.loads(text)["choices"][0]["message"]["content"] + +
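Usage sketch: the adapters introduced in this series pair a model-side loader (pilot/model/adapter.py) with a server-side chat adapter (pilot/server/chat_adapter.py). Both are selected by a substring match on the model path, and the chat adapter hands back a streaming generation function that yields the accumulated output text. A minimal way to exercise the Gorilla pair directly, bypassing the worker server, might look like the code below; GorillaAdapter, GorillaChatAdapter and generate_stream come from the patches, while the device handling, prompt text, checkpoint path and from_pretrained keyword arguments are illustrative assumptions only.

import torch

from pilot.model.adapter import GorillaAdapter
from pilot.server.chat_adapter import GorillaChatAdapter

device = "cuda" if torch.cuda.is_available() else "cpu"
model_path = "/data/models/gorilla-7b"  # hypothetical checkpoint location

# Model-side adapter: match() keys off the path, loader() returns (model, tokenizer).
llm_adapter = GorillaAdapter()
assert llm_adapter.match(model_path)
model, tokenizer = llm_adapter.loader(model_path, {"torch_dtype": torch.float16})
model.to(device)

# Server-side adapter: returns the streaming generation function for this model.
chat_adapter = GorillaChatAdapter()
generate_stream = chat_adapter.get_generate_stream_func()

params = {
    "prompt": "I would like to translate English to French.",  # placeholder prompt
    "max_new_tokens": 256,
    "stop": "###",
}

# generate_stream re-decodes and yields the full output so far every
# stream_interval steps (and once more when generation stops), so the
# final yielded value is the complete answer.
output = ""
for output in generate_stream(model, tokenizer, params, device, context_len=2048):
    pass
print(output)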