fix: remove generate track

csunny 2023-05-31 14:25:06 +08:00
parent 4fb7ed5a4b
commit c84cf762cd


@@ -32,30 +32,30 @@ def guanaco_generate_output(model, tokenizer, params, device, context_len=2048,
         "max_new_tokens": max_new_tokens,
     }
-    if stream_output:
-        # Stream the reply 1 token at a time.
-        # This is based on the trick of using 'stopping_criteria' to create an iterator,
-        # from https://github.com/oobabooga/text-generation-webui/blob/ad37f396fc8bcbab90e11ecf17c56c97bfbd4a9c/modules/text_generation.py#L216-L243.
-        def generate_with_callback(callback=None, **kwargs):
-            kwargs.setdefault("stopping_criteria", transformers.StoppingCriteriaList())
-            kwargs["stopping_criteria"].append(Stream(callback_func=callback))
-            with torch.no_grad():
-                model.generate(**kwargs)
-        def generate_with_streaming(**kwargs):
-            return Iteratorize(generate_with_callback, kwargs, callback=None)
-        with generate_with_streaming(**generate_params) as generator:
-            for output in generator:
-                # new_tokens = len(output) - len(input_ids[0])
-                decoded_output = tokenizer.decode(output)
-                if output[-1] in [tokenizer.eos_token_id]:
-                    break
-                yield decoded_output.split("### Response:")[-1].strip()
-        return # early return for stream_output
+    # if stream_output:
+    #     # Stream the reply 1 token at a time.
+    #     # This is based on the trick of using 'stopping_criteria' to create an iterator,
+    #     # from https://github.com/oobabooga/text-generation-webui/blob/ad37f396fc8bcbab90e11ecf17c56c97bfbd4a9c/modules/text_generation.py#L216-L243.
+    #     def generate_with_callback(callback=None, **kwargs):
+    #         kwargs.setdefault("stopping_criteria", transformers.StoppingCriteriaList())
+    #         kwargs["stopping_criteria"].append(Stream(callback_func=callback))
+    #         with torch.no_grad():
+    #             model.generate(**kwargs)
+    #     def generate_with_streaming(**kwargs):
+    #         return Iteratorize(generate_with_callback, kwargs, callback=None)
+    #     with generate_with_streaming(**generate_params) as generator:
+    #         for output in generator:
+    #             # new_tokens = len(output) - len(input_ids[0])
+    #             decoded_output = tokenizer.decode(output)
+    #             if output[-1] in [tokenizer.eos_token_id]:
+    #                 break
+    #             yield decoded_output.split("### Response:")[-1].strip()
+    # return # early return for stream_output
     with torch.no_grad():
         generation_output = model.generate(
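
For reference, the branch commented out above relies on the stopping-criteria streaming trick from text-generation-webui: a StoppingCriteria subclass that never actually stops generation but forwards each partial token sequence to a callback, plus a wrapper that runs model.generate on a background thread and exposes those callbacks as a generator. The Stream and Iteratorize helpers are imported elsewhere in this file and do not appear in this diff; the sketch below is a minimal, assumed reimplementation of how they are typically written, not the exact code this repository ships.

import queue
import threading

import transformers


class Stream(transformers.StoppingCriteria):
    # A "stopping criterion" that never stops generation; it is only a hook
    # that hands each partially generated sequence to a callback.
    # (Assumption: simplified relative to the text-generation-webui original.)
    def __init__(self, callback_func=None):
        self.callback_func = callback_func

    def __call__(self, input_ids, scores, **kwargs):
        if self.callback_func is not None:
            self.callback_func(input_ids[0])
        return False  # never signal a stop; generation ends on EOS or max_new_tokens


class Iteratorize:
    # Runs a callback-style function (here: generate_with_callback) in a
    # background thread and exposes the values passed to its callback as an
    # iterator. The third argument mirrors the call site above and is unused
    # when None is passed.
    def __init__(self, func, kwargs, callback=None):
        self.q = queue.Queue()
        self.sentinel = object()

        def _enqueue(value):
            self.q.put(value)

        def _run():
            try:
                func(callback=_enqueue, **kwargs)
            finally:
                self.q.put(self.sentinel)  # always unblock the consumer

        threading.Thread(target=_run, daemon=True).start()

    def __iter__(self):
        return self

    def __next__(self):
        value = self.q.get()
        if value is self.sentinel:
            raise StopIteration
        return value

    # Context-manager support, so the removed branch could write
    # `with generate_with_streaming(**generate_params) as generator:`.
    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        return False

The queue-plus-sentinel design is what makes the trick work: model.generate is a blocking call with no native per-token iteration in this code path, so the criterion's callback pushes each partial sequence onto a thread-safe queue, and the sentinel guarantees the consuming loop terminates even if generation raises.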