diff --git a/pilot/model/adapter.py b/pilot/model/adapter.py
index f132bedb9..64d3617bf 100644
--- a/pilot/model/adapter.py
+++ b/pilot/model/adapter.py
@@ -85,9 +85,10 @@ class ChatGLMAdapater(BaseLLMAdaper):
 
 class GuanacoAdapter(BaseLLMAdaper):
     """TODO Support guanaco"""
+
     def match(self, model_path: str):
         return "guanaco" in model_path
-    
+
     def loader(self, model_path: str, from_pretrained_kwargs: dict):
         tokenizer = LlamaTokenizer.from_pretrained(model_path)
         model = AutoModelForCausalLM.from_pretrained(
diff --git a/pilot/model/guanaco_llm.py b/pilot/model/guanaco_llm.py
index 5def47302..84fd3eede 100644
--- a/pilot/model/guanaco_llm.py
+++ b/pilot/model/guanaco_llm.py
@@ -6,73 +6,68 @@ import transformers
 from transformers import GenerationConfig
 from llm_utils import Iteratorize, Stream
 
+
 def guanaco_generate_output(model, tokenizer, params, device):
     """Fork from fastchat: https://github.com/KohakuBlueleaf/guanaco-lora/blob/main/generate.py"""
     prompt = params["prompt"]
-    inputs = tokenizer(prompt, return_tensors="pt") 
+    inputs = tokenizer(prompt, return_tensors="pt")
     input_ids = inputs["input_ids"].to(device)
-    temperature=0.5,
-    top_p=0.95,
-    top_k=45,
-    max_new_tokens=128,
-    stream_output=True
+    temperature = (0.5,)
+    top_p = (0.95,)
+    top_k = (45,)
+    max_new_tokens = (128,)
+    stream_output = True
 
     generation_config = GenerationConfig(
-            temperature=temperature,
-            top_p=top_p,
-            top_k=top_k,
-        )
-        
+        temperature=temperature,
+        top_p=top_p,
+        top_k=top_k,
+    )
+
     generate_params = {
-            "input_ids": input_ids,
-            "generation_config": generation_config,
-            "return_dict_in_generate": True,
-            "output_scores": True,
-            "max_new_tokens": max_new_tokens,
-        }
-        
+        "input_ids": input_ids,
+        "generation_config": generation_config,
+        "return_dict_in_generate": True,
+        "output_scores": True,
+        "max_new_tokens": max_new_tokens,
+    }
+
     if stream_output:
-            # Stream the reply 1 token at a time.
-            # This is based on the trick of using 'stopping_criteria' to create an iterator,
-            # from https://github.com/oobabooga/text-generation-webui/blob/ad37f396fc8bcbab90e11ecf17c56c97bfbd4a9c/modules/text_generation.py#L216-L243.
+        # Stream the reply 1 token at a time.
+        # This is based on the trick of using 'stopping_criteria' to create an iterator,
+        # from https://github.com/oobabooga/text-generation-webui/blob/ad37f396fc8bcbab90e11ecf17c56c97bfbd4a9c/modules/text_generation.py#L216-L243.
 
-            def generate_with_callback(callback=None, **kwargs):
-                kwargs.setdefault(
-                    "stopping_criteria", transformers.StoppingCriteriaList()
-                )
-                kwargs["stopping_criteria"].append(
-                    Stream(callback_func=callback)
-                )
-                with torch.no_grad():
-                    model.generate(**kwargs)
+        def generate_with_callback(callback=None, **kwargs):
+            kwargs.setdefault("stopping_criteria", transformers.StoppingCriteriaList())
+            kwargs["stopping_criteria"].append(Stream(callback_func=callback))
+            with torch.no_grad():
+                model.generate(**kwargs)
 
-            def generate_with_streaming(**kwargs):
-                return Iteratorize(
-                    generate_with_callback, kwargs, callback=None
-                )
-
-            with generate_with_streaming(**generate_params) as generator:
-                for output in generator:
-                    # new_tokens = len(output) - len(input_ids[0])
-                    decoded_output = tokenizer.decode(output)
+        def generate_with_streaming(**kwargs):
+            return Iteratorize(generate_with_callback, kwargs, callback=None)
 
-                    if output[-1] in [tokenizer.eos_token_id]:
-                        break
+        with generate_with_streaming(**generate_params) as generator:
+            for output in generator:
+                # new_tokens = len(output) - len(input_ids[0])
+                decoded_output = tokenizer.decode(output)
+
+                if output[-1] in [tokenizer.eos_token_id]:
+                    break
+
+                yield decoded_output.split("### Response:")[-1].strip()
+        return  # early return for stream_output
 
-                    yield decoded_output.split("### Response:")[-1].strip()
-            return # early return for stream_output
-    
     with torch.no_grad():
-            generation_output = model.generate(
-                input_ids=input_ids,
-                generation_config=generation_config,
-                return_dict_in_generate=True,
-                output_scores=True,
-                max_new_tokens=max_new_tokens,
-            )
-    
+        generation_output = model.generate(
+            input_ids=input_ids,
+            generation_config=generation_config,
+            return_dict_in_generate=True,
+            output_scores=True,
+            max_new_tokens=max_new_tokens,
+        )
+
     s = generation_output.sequences[0]
-    print(f"debug_sequences,{s}",s)
+    print(f"debug_sequences,{s}", s)
     output = tokenizer.decode(s)
-    print(f"debug_output,{output}",output)
-    yield output.split("### Response:")[-1].strip()
\ No newline at end of file
+    print(f"debug_output,{output}", output)
+    yield output.split("### Response:")[-1].strip()
diff --git a/pilot/model/llm_utils.py b/pilot/model/llm_utils.py
index a8b354055..359d478f8 100644
--- a/pilot/model/llm_utils.py
+++ b/pilot/model/llm_utils.py
@@ -53,6 +53,7 @@ def create_chat_completion(
     response = None
     # TODO impl this use vicuna server api
 
+
 class Stream(transformers.StoppingCriteria):
     def __init__(self, callback_func=None):
         self.callback_func = callback_func
@@ -113,4 +114,4 @@ class Iteratorize:
         return self
 
     def __exit__(self, exc_type, exc_val, exc_tb):
-        self.stop_now = True
\ No newline at end of file
+        self.stop_now = True
diff --git a/pilot/server/chat_adapter.py b/pilot/server/chat_adapter.py
index 2968161b9..39737112b 100644
--- a/pilot/server/chat_adapter.py
+++ b/pilot/server/chat_adapter.py
@@ -85,14 +85,14 @@ class CodeGenChatAdapter(BaseChatAdpter):
 
 
 class GuanacoChatAdapter(BaseChatAdpter):
-    """Model chat adapter for Guanaco """
-    
+    """Model chat adapter for Guanaco"""
+
     def match(self, model_path: str):
         return "guanaco" in model_path
 
     def get_generate_stream_func(self):
         from pilot.model.guanaco_llm import guanaco_generate_output
-        
+
         return guanaco_generate_output
 
 
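Side note on the reformatted constants: the trailing commas in the old `temperature=0.5,` lines already made each value a one-element tuple, so black's `temperature = (0.5,)` only makes that explicit rather than changing behavior. For context, below is a minimal usage sketch (not part of this diff) of how the pieces touched above fit together; the checkpoint name, device, and prompt template are assumptions for illustration.

# Hypothetical usage sketch -- not part of this PR. It wires together the
# GuanacoChatAdapter and guanaco_generate_output touched in the diff above,
# and assumes the pilot package import paths are set up as in the model server.
# The checkpoint name, device, and prompt template below are assumptions.
from transformers import AutoModelForCausalLM, LlamaTokenizer

from pilot.server.chat_adapter import GuanacoChatAdapter

model_path = "timdettmers/guanaco-33b-merged"  # hypothetical checkpoint
tokenizer = LlamaTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto")

adapter = GuanacoChatAdapter()
assert adapter.match(model_path)  # matches on the "guanaco" substring
generate_stream = adapter.get_generate_stream_func()  # -> guanaco_generate_output

params = {"prompt": "### Human: What is DB-GPT?\n### Response:"}
for partial in generate_stream(model, tokenizer, params, device="cuda"):
    # Each yielded string is everything after the last "### Response:" marker,
    # re-decoded from the sequence generated so far.
    print(partial)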