From c84cf762cd55d6e3ccd0429ccbf53e67e1e3a88b Mon Sep 17 00:00:00 2001
From: csunny
Date: Wed, 31 May 2023 14:25:06 +0800
Subject: [PATCH] fix: remove generate track

---
 pilot/model/guanaco_llm.py | 38 +++++++++++++++++++-------------------
 1 file changed, 19 insertions(+), 19 deletions(-)

diff --git a/pilot/model/guanaco_llm.py b/pilot/model/guanaco_llm.py
index ba10b4f56..7c67dd22c 100644
--- a/pilot/model/guanaco_llm.py
+++ b/pilot/model/guanaco_llm.py
@@ -32,30 +32,30 @@ def guanaco_generate_output(model, tokenizer, params, device, context_len=2048,
         "max_new_tokens": max_new_tokens,
     }
 
-    if stream_output:
-        # Stream the reply 1 token at a time.
-        # This is based on the trick of using 'stopping_criteria' to create an iterator,
-        # from https://github.com/oobabooga/text-generation-webui/blob/ad37f396fc8bcbab90e11ecf17c56c97bfbd4a9c/modules/text_generation.py#L216-L243.
+    # if stream_output:
+    #     # Stream the reply 1 token at a time.
+    #     # This is based on the trick of using 'stopping_criteria' to create an iterator,
+    #     # from https://github.com/oobabooga/text-generation-webui/blob/ad37f396fc8bcbab90e11ecf17c56c97bfbd4a9c/modules/text_generation.py#L216-L243.
 
-        def generate_with_callback(callback=None, **kwargs):
-            kwargs.setdefault("stopping_criteria", transformers.StoppingCriteriaList())
-            kwargs["stopping_criteria"].append(Stream(callback_func=callback))
-            with torch.no_grad():
-                model.generate(**kwargs)
+    #     def generate_with_callback(callback=None, **kwargs):
+    #         kwargs.setdefault("stopping_criteria", transformers.StoppingCriteriaList())
+    #         kwargs["stopping_criteria"].append(Stream(callback_func=callback))
+    #         with torch.no_grad():
+    #             model.generate(**kwargs)
 
-        def generate_with_streaming(**kwargs):
-            return Iteratorize(generate_with_callback, kwargs, callback=None)
+    #     def generate_with_streaming(**kwargs):
+    #         return Iteratorize(generate_with_callback, kwargs, callback=None)
 
-        with generate_with_streaming(**generate_params) as generator:
-            for output in generator:
-                # new_tokens = len(output) - len(input_ids[0])
-                decoded_output = tokenizer.decode(output)
+    #     with generate_with_streaming(**generate_params) as generator:
+    #         for output in generator:
+    #             # new_tokens = len(output) - len(input_ids[0])
+    #             decoded_output = tokenizer.decode(output)
 
-                if output[-1] in [tokenizer.eos_token_id]:
-                    break
+    #             if output[-1] in [tokenizer.eos_token_id]:
+    #                 break
 
-                yield decoded_output.split("### Response:")[-1].strip()
-        return  # early return for stream_output
+    #             yield decoded_output.split("### Response:")[-1].strip()
+    #     return  # early return for stream_output
 
     with torch.no_grad():
         generation_output = model.generate(
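
Note (not part of the patch): the block being commented out streamed tokens by using `stopping_criteria` as a per-step callback and then wrapping that callback in an iterator, a trick the original comment credits to text-generation-webui. The sketch below is a minimal, self-contained illustration of that trick, not the project's actual code: `Stream` and `Iteratorize` here are simplified stand-ins for the classes guanaco_llm.py presumably imports from elsewhere in the repo, and `stream_generate` is a hypothetical helper name introduced only for this note.

import queue
import threading

import torch
import transformers


class Stream(transformers.StoppingCriteria):
    """Simplified stand-in: fires a callback on every generation step, never stops."""

    def __init__(self, callback_func=None):
        self.callback_func = callback_func

    def __call__(self, input_ids, scores, **kwargs):
        if self.callback_func is not None:
            self.callback_func(input_ids[0])  # full token sequence generated so far
        return False  # False means "do not stop generation"


class Iteratorize:
    """Simplified stand-in: runs func(callback=..., **kwargs) in a background
    thread and exposes each callback value through a blocking iterator."""

    def __init__(self, func, kwargs):
        self.q = queue.Queue()
        self.sentinel = object()

        def _run():
            try:
                func(callback=self.q.put, **kwargs)
            finally:
                self.q.put(self.sentinel)  # always signal completion

        threading.Thread(target=_run, daemon=True).start()

    def __iter__(self):
        return self

    def __next__(self):
        item = self.q.get()
        if item is self.sentinel:
            raise StopIteration
        return item

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        return False


def stream_generate(model, tokenizer, generate_params):
    """Hypothetical helper: yield the decoded text after each newly generated token."""

    def generate_with_callback(callback=None, **kwargs):
        kwargs.setdefault("stopping_criteria", transformers.StoppingCriteriaList())
        kwargs["stopping_criteria"].append(Stream(callback_func=callback))
        with torch.no_grad():
            model.generate(**kwargs)

    with Iteratorize(generate_with_callback, generate_params) as generator:
        for output in generator:
            decoded = tokenizer.decode(output)
            if output[-1] == tokenizer.eos_token_id:
                break
            yield decoded.split("### Response:")[-1].strip()

The thread-plus-queue wrapper is what turns the blocking model.generate() call into a lazy generator: generate() runs in the background, the Stream criterion pushes the partial sequence into the queue on every step, and the consumer pulls items off as they arrive. The real implementations also handle early termination and cleanup, which this sketch omits.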