diff --git a/pilot/model/guanaco_llm.py b/pilot/model/guanaco_llm.py
index 7c67dd22c..c6e91ee6f 100644
--- a/pilot/model/guanaco_llm.py
+++ b/pilot/model/guanaco_llm.py
@@ -1,73 +1,68 @@
-#!/usr/bin/env python3
-# -*- coding: utf-8 -*-
-
 import torch
-import transformers
-from transformers import GenerationConfig
-from pilot.model.llm_utils import Iteratorize, Stream
+from threading import Thread
+from transformers import TextIteratorStreamer, StoppingCriteriaList, StoppingCriteria
+from pilot.conversation import ROLE_ASSISTANT, ROLE_USER
 
-
-def guanaco_generate_output(model, tokenizer, params, device, context_len=2048, stream_interval=2):
+def guanaco_generate_output(model, tokenizer, params, device, context_len=2048):
     """Fork from fastchat: https://github.com/KohakuBlueleaf/guanaco-lora/blob/main/generate.py"""
-    prompt = params["prompt"]
-    inputs = tokenizer(prompt, return_tensors="pt")
-    input_ids = inputs["input_ids"].to(device)
-    temperature = (0.5,)
-    top_p = (0.95,)
-    top_k = (45,)
-    max_new_tokens = (128,)
-    stream_output = True
+    stop = params.get("stop", "###")
+    messages = params["prompt"].split(stop)
 
-    generation_config = GenerationConfig(
-        temperature=temperature,
-        top_p=top_p,
-        top_k=top_k,
+    # Rebuild the (user, assistant) history from the prompt segments.
+    hist = []
+    for i in range(1, len(messages) - 2, 2):
+        hist.append(
+            (
+                messages[i].split(ROLE_USER + ":")[1],
+                messages[i + 1].split(ROLE_ASSISTANT + ":")[1],
+            )
+        )
+
+    # Render the recovered history as plain text (not used further below).
+    text = "".join(
+        "".join([f"### USER: {item[0]}\n", f"### Assistant: {item[1]}\n"]) for item in hist
+    )
+
+    query = messages[-2].split(ROLE_USER + ":")[1]
+    print("Query Message: ", query)
+
+    input_ids = tokenizer(query, return_tensors="pt").input_ids
+    input_ids = input_ids.to(model.device)
+
+    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
+    stop_token_ids = [0]
+
+    class StopOnTokens(StoppingCriteria):
+        def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
+            for stop_id in stop_token_ids:
+                if input_ids[0][-1] == stop_id:
+                    return True
+            return False
+
+    stop = StopOnTokens()
+
+    generate_kwargs = dict(
+        input_ids=input_ids,
+        max_new_tokens=512,
+        temperature=1.0,
+        do_sample=True,
+        top_k=1,
+        streamer=streamer,
+        repetition_penalty=1.7,
+        stopping_criteria=StoppingCriteriaList([stop])
     )
-    generate_params = {
-        "input_ids": input_ids,
-        "generation_config": generation_config,
-        "return_dict_in_generate": True,
-        "output_scores": True,
-        "max_new_tokens": max_new_tokens,
-    }
-    # if stream_output:
-    #     # Stream the reply 1 token at a time.
-    #     # This is based on the trick of using 'stopping_criteria' to create an iterator,
-    #     # from https://github.com/oobabooga/text-generation-webui/blob/ad37f396fc8bcbab90e11ecf17c56c97bfbd4a9c/modules/text_generation.py#L216-L243.
+    # Generation runs in a background thread; decoded text arrives through the streamer.
+    t1 = Thread(target=model.generate, kwargs=generate_kwargs)
+    t1.start()
-    #     def generate_with_callback(callback=None, **kwargs):
-    #         kwargs.setdefault("stopping_criteria", transformers.StoppingCriteriaList())
-    #         kwargs["stopping_criteria"].append(Stream(callback_func=callback))
-    #         with torch.no_grad():
-    #             model.generate(**kwargs)
-
-    #     def generate_with_streaming(**kwargs):
-    #         return Iteratorize(generate_with_callback, kwargs, callback=None)
-
-    #     with generate_with_streaming(**generate_params) as generator:
-    #         for output in generator:
-    #             # new_tokens = len(output) - len(input_ids[0])
-    #             decoded_output = tokenizer.decode(output)
-
-    #             if output[-1] in [tokenizer.eos_token_id]:
-    #                 break
-
-    #             yield decoded_output.split("### Response:")[-1].strip()
-    #     return  # early return for stream_output
-
-    with torch.no_grad():
-        generation_output = model.generate(
-            input_ids=input_ids,
-            generation_config=generation_config,
-            return_dict_in_generate=True,
-            output_scores=True,
-            max_new_tokens=max_new_tokens,
-        )
-
-    s = generation_output.sequences[0]
-    print(f"debug_sequences,{s}", s)
-    output = tokenizer.decode(s)
-    print(f"debug_output,{output}", output)
-    yield output.split("### Response:")[-1].strip()
+
+    # Consume the streamer instead of calling model.generate() a second time;
+    # the background thread started above is the only producer.
+    out = ""
+    for new_text in streamer:
+        out += new_text
+        yield out.split("### Response:")[-1].strip()
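
Reviewer note: below is a minimal usage sketch of the refactored streaming generator; it is not part of the patch. The checkpoint id, the prompt layout, and the "###" separator are assumptions inferred from the split(stop) / messages[-2] logic above, and the role strings are taken from pilot.conversation so their exact values don't matter.

# Usage sketch (illustrative only): drive guanaco_generate_output() end to end.
from transformers import AutoModelForCausalLM, AutoTokenizer

from pilot.conversation import ROLE_ASSISTANT, ROLE_USER
from pilot.model.guanaco_llm import guanaco_generate_output

# Any merged guanaco-style checkpoint should work; this id is only an example.
model_name = "timdettmers/guanaco-33b-merged"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")

params = {
    # Turns are separated by the stop string; the user turn must sit second to last,
    # because guanaco_generate_output() reads the query from messages[-2].
    "prompt": f"A chat between a user and an AI assistant."
    f"###{ROLE_USER}: What is DB-GPT?"
    f"###{ROLE_ASSISTANT}:",
    "stop": "###",
}

# Each yielded item is the accumulated reply decoded so far.
for partial in guanaco_generate_output(model, tokenizer, params, device="cuda"):
    print(partial)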