diff --git a/pilot/configs/model_config.py b/pilot/configs/model_config.py
index 33d96f1c2..4b8b85a62 100644
--- a/pilot/configs/model_config.py
+++ b/pilot/configs/model_config.py
@@ -35,6 +35,7 @@ LLM_MODEL_CONFIG = {
     "chatglm-6b": os.path.join(MODEL_PATH, "chatglm-6b"),
     "text2vec-base": os.path.join(MODEL_PATH, "text2vec-base-chinese"),
     "sentence-transforms": os.path.join(MODEL_PATH, "all-MiniLM-L6-v2"),
+    "guanaco-33b-merged": os.path.join(MODEL_PATH, "guanaco-33b-merged"),
     "proxyllm": "proxyllm",
 }
 
diff --git a/pilot/model/adapter.py b/pilot/model/adapter.py
index 0ff368c70..f132bedb9 100644
--- a/pilot/model/adapter.py
+++ b/pilot/model/adapter.py
@@ -3,7 +3,7 @@
 from functools import cache
 from typing import List
 
-from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer
 
 from pilot.configs.model_config import DEVICE
 
@@ -85,8 +85,15 @@ class ChatGLMAdapater(BaseLLMAdaper):
 
 class GuanacoAdapter(BaseLLMAdaper):
     """TODO Support guanaco"""
-
-    pass
+    def match(self, model_path: str):
+        return "guanaco" in model_path
+
+    def loader(self, model_path: str, from_pretrained_kwargs: dict):
+        tokenizer = LlamaTokenizer.from_pretrained(model_path)
+        model = AutoModelForCausalLM.from_pretrained(
+            model_path, load_in_4bit=True, device_map={"": 0}, **from_pretrained_kwargs
+        )
+        return model, tokenizer
 
 
 class CodeGenAdapter(BaseLLMAdaper):
@@ -143,6 +150,7 @@ class ProxyllmAdapter(BaseLLMAdaper):
 
 register_llm_model_adapters(VicunaLLMAdapater)
 register_llm_model_adapters(ChatGLMAdapater)
+register_llm_model_adapters(GuanacoAdapter)
 # TODO Default support vicuna, other model need to tests and Evaluate
 
 # just for test, remove this later
diff --git a/pilot/model/guanaco_llm.py b/pilot/model/guanaco_llm.py
new file mode 100644
index 000000000..5def47302
--- /dev/null
+++ b/pilot/model/guanaco_llm.py
@@ -0,0 +1,78 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+import torch
+import transformers
+from transformers import GenerationConfig
+from pilot.model.llm_utils import Iteratorize, Stream
+
+def guanaco_generate_output(model, tokenizer, params, device):
+    """Adapted from guanaco-lora: https://github.com/KohakuBlueleaf/guanaco-lora/blob/main/generate.py"""
+    prompt = params["prompt"]
+    inputs = tokenizer(prompt, return_tensors="pt")
+    input_ids = inputs["input_ids"].to(device)
+    temperature = 0.5
+    top_p = 0.95
+    top_k = 45
+    max_new_tokens = 128
+    stream_output = True
+
+    generation_config = GenerationConfig(
+        temperature=temperature,
+        top_p=top_p,
+        top_k=top_k,
+    )
+
+    generate_params = {
+        "input_ids": input_ids,
+        "generation_config": generation_config,
+        "return_dict_in_generate": True,
+        "output_scores": True,
+        "max_new_tokens": max_new_tokens,
+    }
+
+    if stream_output:
+        # Stream the reply 1 token at a time.
+        # This is based on the trick of using 'stopping_criteria' to create an iterator,
+        # from https://github.com/oobabooga/text-generation-webui/blob/ad37f396fc8bcbab90e11ecf17c56c97bfbd4a9c/modules/text_generation.py#L216-L243.
+
+        def generate_with_callback(callback=None, **kwargs):
+            kwargs.setdefault(
+                "stopping_criteria", transformers.StoppingCriteriaList()
+            )
+            kwargs["stopping_criteria"].append(
+                Stream(callback_func=callback)
+            )
+            with torch.no_grad():
+                model.generate(**kwargs)
+
+        def generate_with_streaming(**kwargs):
+            return Iteratorize(
+                generate_with_callback, kwargs, callback=None
+            )
+
+        with generate_with_streaming(**generate_params) as generator:
+            for output in generator:
+                # new_tokens = len(output) - len(input_ids[0])
+                decoded_output = tokenizer.decode(output)
+
+                if output[-1] in [tokenizer.eos_token_id]:
+                    break
+
+                yield decoded_output.split("### Response:")[-1].strip()
+        return  # early return for stream_output
+
+    with torch.no_grad():
+        generation_output = model.generate(
+            input_ids=input_ids,
+            generation_config=generation_config,
+            return_dict_in_generate=True,
+            output_scores=True,
+            max_new_tokens=max_new_tokens,
+        )
+
+    s = generation_output.sequences[0]
+    print(f"debug_sequences,{s}", s)
+    output = tokenizer.decode(s)
+    print(f"debug_output,{output}", output)
+    yield output.split("### Response:")[-1].strip()
\ No newline at end of file
diff --git a/pilot/model/llm_utils.py b/pilot/model/llm_utils.py
index 118d45f97..a8b354055 100644
--- a/pilot/model/llm_utils.py
+++ b/pilot/model/llm_utils.py
@@ -1,6 +1,11 @@
 #!/usr/bin/env python3
 # -*- coding:utf-8 -*-
 
+import traceback
+from queue import Queue
+from threading import Thread
+import transformers
+
 from typing import List, Optional
 
 from pilot.configs.config import Config
@@ -47,3 +52,65 @@ def create_chat_completion(
     response = None
 
     # TODO impl this use vicuna server api
+
+class Stream(transformers.StoppingCriteria):
+    def __init__(self, callback_func=None):
+        self.callback_func = callback_func
+
+    def __call__(self, input_ids, scores) -> bool:
+        if self.callback_func is not None:
+            self.callback_func(input_ids[0])
+        return False
+
+
+class Iteratorize:
+
+    """
+    Transforms a function that takes a callback
+    into a lazy iterator (generator).
+ """ + + def __init__(self, func, kwargs={}, callback=None): + self.mfunc = func + self.c_callback = callback + self.q = Queue() + self.sentinel = object() + self.kwargs = kwargs + self.stop_now = False + + def _callback(val): + if self.stop_now: + raise ValueError + self.q.put(val) + + def gentask(): + try: + ret = self.mfunc(callback=_callback, **self.kwargs) + except ValueError: + pass + except: + traceback.print_exc() + pass + + self.q.put(self.sentinel) + if self.c_callback: + self.c_callback(ret) + + self.thread = Thread(target=gentask) + self.thread.start() + + def __iter__(self): + return self + + def __next__(self): + obj = self.q.get(True, None) + if obj is self.sentinel: + raise StopIteration + else: + return obj + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.stop_now = True \ No newline at end of file diff --git a/pilot/server/chat_adapter.py b/pilot/server/chat_adapter.py index 17d2f95a8..2968161b9 100644 --- a/pilot/server/chat_adapter.py +++ b/pilot/server/chat_adapter.py @@ -85,14 +85,15 @@ class CodeGenChatAdapter(BaseChatAdpter): class GuanacoChatAdapter(BaseChatAdpter): - """Model chat adapter for Guanaco""" - + """Model chat adapter for Guanaco """ + def match(self, model_path: str): return "guanaco" in model_path def get_generate_stream_func(self): - # TODO - pass + from pilot.model.guanaco_llm import guanaco_generate_output + + return guanaco_generate_output class ProxyllmChatAdapter(BaseChatAdpter): @@ -107,6 +108,7 @@ class ProxyllmChatAdapter(BaseChatAdpter): register_llm_model_chat_adapter(VicunaChatAdapter) register_llm_model_chat_adapter(ChatGLMChatAdapter) +register_llm_model_chat_adapter(GuanacoChatAdapter) # Proxy model for test and develop, it's cheap for us now. register_llm_model_chat_adapter(ProxyllmChatAdapter)