diff --git a/.env.template b/.env.template
index 2ed7932f3..b06ef215e 100644
--- a/.env.template
+++ b/.env.template
@@ -101,8 +101,8 @@ LANGUAGE=en
 #*******************************************************************#
 # ** PROXY_SERVER
 #*******************************************************************#
-PROXY_API_KEY=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
-PROXY_SERVER_URL=http://127.0.0.1:3000/api/openai/v1/chat/completions
+PROXY_API_KEY=
+PROXY_SERVER_URL=http://127.0.0.1:3000/proxy_address
 
 
 #*******************************************************************#
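The proxy settings above point DB-GPT's `proxyllm` backend at an OpenAI-compatible chat-completions endpoint. A minimal sketch of the kind of request such an endpoint serves; the payload shape and model name follow the OpenAI chat format and are assumptions, not code from this repository:

```python
import os
import requests

# Values declared in .env.template; PROXY_SERVER_URL should point at an
# OpenAI-compatible /chat/completions route, PROXY_API_KEY at whatever key that proxy expects.
url = os.getenv("PROXY_SERVER_URL", "http://127.0.0.1:3000/proxy_address")
api_key = os.getenv("PROXY_API_KEY", "")

resp = requests.post(
    url,
    headers={"Authorization": f"Bearer {api_key}"},
    json={
        "model": "gpt-3.5-turbo",  # placeholder model name
        "messages": [{"role": "user", "content": "Hello"}],
    },
    timeout=60,
)
print(resp.json())
```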

diff --git a/README.zh.md b/README.zh.md
index 3db20702c..06260c9b1 100644
--- a/README.zh.md
+++ b/README.zh.md
@@ -260,7 +260,7 @@ Run the Python interpreter and type the commands:
 这是一个用于数据库的复杂且创新的工具, 我们的项目也在紧急的开发当中, 会陆续发布一些新的feature。如在使用当中有任何具体问题, 优先在项目下提issue, 如有需要, 请联系如下微信，我会尽力提供帮助，同时也非常欢迎大家参与到项目建设中。
 
-
+

 ## Licence
diff --git a/assets/DB_GPT_wechat.png b/assets/DB_GPT_wechat.png
deleted file mode 100644
index a1d7f7558..000000000
Binary files a/assets/DB_GPT_wechat.png and /dev/null differ
diff --git a/assets/wechat.jpg b/assets/wechat.jpg
index 77ddbc89a..0a0c3c003 100644
Binary files a/assets/wechat.jpg and b/assets/wechat.jpg differ
diff --git a/pilot/configs/model_config.py b/pilot/configs/model_config.py
index 33d96f1c2..4b8b85a62 100644
--- a/pilot/configs/model_config.py
+++ b/pilot/configs/model_config.py
@@ -35,6 +35,7 @@ LLM_MODEL_CONFIG = {
     "chatglm-6b": os.path.join(MODEL_PATH, "chatglm-6b"),
     "text2vec-base": os.path.join(MODEL_PATH, "text2vec-base-chinese"),
     "sentence-transforms": os.path.join(MODEL_PATH, "all-MiniLM-L6-v2"),
+    "guanaco-33b-merged": os.path.join(MODEL_PATH, "guanaco-33b-merged"),
     "proxyllm": "proxyllm",
 }
diff --git a/pilot/model/adapter.py b/pilot/model/adapter.py
index 0ff368c70..64d3617bf 100644
--- a/pilot/model/adapter.py
+++ b/pilot/model/adapter.py
@@ -3,7 +3,7 @@
 from functools import cache
 from typing import List
 
-from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer
 
 from pilot.configs.model_config import DEVICE
 
@@ -86,7 +86,15 @@ class ChatGLMAdapater(BaseLLMAdaper):
 class GuanacoAdapter(BaseLLMAdaper):
     """TODO Support guanaco"""
 
-    pass
+    def match(self, model_path: str):
+        return "guanaco" in model_path
+
+    def loader(self, model_path: str, from_pretrained_kwargs: dict):
+        tokenizer = LlamaTokenizer.from_pretrained(model_path)
+        model = AutoModelForCausalLM.from_pretrained(
+            model_path, load_in_4bit=True, device_map={"": 0}, **from_pretrained_kwargs
+        )
+        return model, tokenizer
 
 
 class CodeGenAdapter(BaseLLMAdaper):
@@ -143,6 +151,7 @@ class ProxyllmAdapter(BaseLLMAdaper):
 
 register_llm_model_adapters(VicunaLLMAdapater)
 register_llm_model_adapters(ChatGLMAdapater)
+register_llm_model_adapters(GuanacoAdapter)
 # TODO Default support vicuna, other model need to tests and Evaluate
 
 # just for test, remove this later
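The new `GuanacoAdapter.loader` above loads the merged Guanaco checkpoint in 4-bit through the bitsandbytes integration. A standalone sketch of that loading path, assuming `bitsandbytes` and `accelerate` are installed and a CUDA device is available; the model path is a placeholder mirroring the new `LLM_MODEL_CONFIG` entry:

```python
from transformers import AutoModelForCausalLM, LlamaTokenizer

# Placeholder local path; mirrors the "guanaco-33b-merged" entry added to LLM_MODEL_CONFIG.
model_path = "./models/guanaco-33b-merged"

tokenizer = LlamaTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    load_in_4bit=True,   # bitsandbytes 4-bit quantization; the model cannot be moved with .to() afterwards
    device_map={"": 0},  # place all weights on GPU 0
)

inputs = tokenizer("Hello, who are you?", return_tensors="pt").to(model.device)
output_ids = model.generate(**inputs, max_new_tokens=64)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```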
diff --git a/pilot/model/guanaco_llm.py b/pilot/model/guanaco_llm.py
new file mode 100644
index 000000000..03f2d1687
--- /dev/null
+++ b/pilot/model/guanaco_llm.py
@@ -0,0 +1,65 @@
+import torch
+from threading import Thread
+from transformers import TextIteratorStreamer, StoppingCriteriaList, StoppingCriteria
+
+from pilot.conversation import ROLE_ASSISTANT, ROLE_USER
+
+
+def guanaco_generate_output(model, tokenizer, params, device, context_len=2048):
+    """Fork from: https://github.com/KohakuBlueleaf/guanaco-lora/blob/main/generate.py"""
+    print(params)
+    stop = params.get("stop", "###")
+    messages = params["prompt"]
+
+    # Rebuild the (user, assistant) turn history from the role-prefixed messages.
+    hist = []
+    for i in range(1, len(messages) - 2, 2):
+        hist.append(
+            (
+                messages[i].split(ROLE_USER + ":")[1],
+                messages[i + 1].split(ROLE_ASSISTANT + ":")[1],
+            )
+        )
+
+    query = messages[-2].split(ROLE_USER + ":")[1]
+    print("Query Message: ", query)
+
+    input_ids = tokenizer(query, return_tensors="pt").input_ids
+    input_ids = input_ids.to(model.device)
+
+    streamer = TextIteratorStreamer(
+        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
+    )
+
+    stop_token_ids = [0]
+
+    class StopOnTokens(StoppingCriteria):
+        def __call__(
+            self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
+        ) -> bool:
+            for stop_id in stop_token_ids:
+                if input_ids[0][-1] == stop_id:
+                    return True
+            return False
+
+    stopping_criteria = StopOnTokens()
+
+    generate_kwargs = dict(
+        input_ids=input_ids,
+        max_new_tokens=512,
+        temperature=1.0,
+        do_sample=True,
+        top_k=1,
+        streamer=streamer,
+        repetition_penalty=1.7,
+        stopping_criteria=StoppingCriteriaList([stopping_criteria]),
+    )
+
+    # Run generation in a background thread and consume the streamer in the
+    # foreground, so decoded text is yielded incrementally as it is produced.
+    t = Thread(target=model.generate, kwargs=generate_kwargs)
+    t.start()
+
+    out = ""
+    for new_text in streamer:
+        out += new_text
+        yield out.split("### Response:")[-1].strip()
diff --git a/pilot/model/llm_utils.py b/pilot/model/llm_utils.py
index 118d45f97..359d478f8 100644
--- a/pilot/model/llm_utils.py
+++ b/pilot/model/llm_utils.py
@@ -1,6 +1,11 @@
 #!/usr/bin/env python3
 # -*- coding:utf-8 -*-
 
+import traceback
+from queue import Queue
+from threading import Thread
+import transformers
+
 from typing import List, Optional
 
 from pilot.configs.config import Config
@@ -47,3 +52,66 @@ def create_chat_completion(
 
     response = None
     # TODO impl this use vicuna server api
+
+
+class Stream(transformers.StoppingCriteria):
+    """Stopping criterion that never stops; it only reports progress via a callback."""
+
+    def __init__(self, callback_func=None):
+        self.callback_func = callback_func
+
+    def __call__(self, input_ids, scores) -> bool:
+        if self.callback_func is not None:
+            self.callback_func(input_ids[0])
+        return False
+
+
+class Iteratorize:
+    """
+    Transforms a function that takes a callback
+    into a lazy iterator (generator).
+    """
+
+    def __init__(self, func, kwargs=None, callback=None):
+        self.mfunc = func
+        self.c_callback = callback
+        self.q = Queue()
+        self.sentinel = object()
+        self.kwargs = kwargs or {}
+        self.stop_now = False
+
+        def _callback(val):
+            if self.stop_now:
+                raise ValueError
+            self.q.put(val)
+
+        def gentask():
+            ret = None
+            try:
+                ret = self.mfunc(callback=_callback, **self.kwargs)
+            except ValueError:
+                pass
+            except Exception:
+                traceback.print_exc()
+
+            self.q.put(self.sentinel)
+            if self.c_callback:
+                self.c_callback(ret)
+
+        self.thread = Thread(target=gentask)
+        self.thread.start()
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        obj = self.q.get(True, None)
+        if obj is self.sentinel:
+            raise StopIteration
+        else:
+            return obj
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.stop_now = True
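`Stream` and `Iteratorize` bridge `model.generate`'s callback-only progress reporting into a plain Python iterator. A hypothetical usage sketch; `model`, `tokenizer`, and `generate_kwargs` are assumed to already be in scope (for example, produced by the Guanaco loader above) and are not defined by this patch:

```python
import transformers

from pilot.model.llm_utils import Iteratorize, Stream


def generate_with_callback(callback=None, **kwargs):
    # Report every partial token sequence to the callback via the Stream criterion.
    kwargs["stopping_criteria"] = transformers.StoppingCriteriaList([Stream(callback_func=callback)])
    return model.generate(**kwargs)


with Iteratorize(generate_with_callback, kwargs=generate_kwargs) as output_ids:
    for partial_ids in output_ids:
        # Each item is the full token sequence generated so far.
        print(tokenizer.decode(partial_ids, skip_special_tokens=True))
```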
diff --git a/pilot/model/loader.py b/pilot/model/loader.py
index a3d443da8..9fe6207c1 100644
--- a/pilot/model/loader.py
+++ b/pilot/model/loader.py
@@ -113,7 +113,11 @@ class ModelLoader(metaclass=Singleton):
             or self.device == "mps"
             and tokenizer
         ):
-            model.to(self.device)
+            # 4-bit quantized models do not support .to(device); skip the move in that case.
+            try:
+                model.to(self.device)
+            except ValueError:
+                pass
 
         if debug:
             print(model)
diff --git a/pilot/out_parser/base.py b/pilot/out_parser/base.py
index eb992c8eb..7a61c788f 100644
--- a/pilot/out_parser/base.py
+++ b/pilot/out_parser/base.py
@@ -108,7 +108,7 @@ class BaseOutputParser(ABC):
         if not self.is_stream_out:
             return self._parse_model_nostream_resp(response, self.sep)
         else:
-            return self._parse_model_stream_resp(response, self.sep, skip_echo_len)
+            return self._parse_model_stream_resp(response, self.sep)
 
     def parse_prompt_response(self, model_out_text) -> T:
         """
diff --git a/pilot/server/chat_adapter.py b/pilot/server/chat_adapter.py
index 17d2f95a8..39737112b 100644
--- a/pilot/server/chat_adapter.py
+++ b/pilot/server/chat_adapter.py
@@ -91,8 +91,9 @@ class GuanacoChatAdapter(BaseChatAdpter):
         return "guanaco" in model_path
 
     def get_generate_stream_func(self):
-        # TODO
-        pass
+        from pilot.model.guanaco_llm import guanaco_generate_output
+
+        return guanaco_generate_output
 
 
 class ProxyllmChatAdapter(BaseChatAdpter):
@@ -107,6 +108,7 @@ class ProxyllmChatAdapter(BaseChatAdpter):
 
 register_llm_model_chat_adapter(VicunaChatAdapter)
 register_llm_model_chat_adapter(ChatGLMChatAdapter)
+register_llm_model_chat_adapter(GuanacoChatAdapter)
 
 # Proxy model for test and develop, it's cheap for us now.
 register_llm_model_chat_adapter(ProxyllmChatAdapter)
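With `GuanacoAdapter` and `GuanacoChatAdapter` registered, any model path containing "guanaco" resolves to the `guanaco_generate_output` streaming function. A hypothetical driving sketch, assuming `model` and `tokenizer` are the objects returned by `GuanacoAdapter.loader` and using the role-prefixed message list that `guanaco_generate_output` expects:

```python
from pilot.conversation import ROLE_ASSISTANT, ROLE_USER
from pilot.model.guanaco_llm import guanaco_generate_output

# `model` and `tokenizer` are assumed to come from GuanacoAdapter.loader(...).
params = {
    "prompt": [
        "A chat between a curious user and an assistant.",  # header / system-style message
        f"{ROLE_USER}: What is DB-GPT?",
        f"{ROLE_ASSISTANT}:",  # empty slot the model will fill
    ],
    "stop": "###",
}

# The generator yields the accumulated answer text, so each value supersedes the previous one.
for partial_answer in guanaco_generate_output(model, tokenizer, params, device="cuda"):
    print(partial_answer)
```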