Mirror of https://github.com/csunny/DB-GPT.git, synced 2025-08-09 20:28:07 +00:00
fix: remove generate track
parent 4fb7ed5a4b
commit c84cf762cd
@@ -32,30 +32,30 @@ def guanaco_generate_output(model, tokenizer, params, device, context_len=2048,
         "max_new_tokens": max_new_tokens,
     }

-    if stream_output:
-        # Stream the reply 1 token at a time.
-        # This is based on the trick of using 'stopping_criteria' to create an iterator,
-        # from https://github.com/oobabooga/text-generation-webui/blob/ad37f396fc8bcbab90e11ecf17c56c97bfbd4a9c/modules/text_generation.py#L216-L243.
+    # if stream_output:
+    #     # Stream the reply 1 token at a time.
+    #     # This is based on the trick of using 'stopping_criteria' to create an iterator,
+    #     # from https://github.com/oobabooga/text-generation-webui/blob/ad37f396fc8bcbab90e11ecf17c56c97bfbd4a9c/modules/text_generation.py#L216-L243.

-        def generate_with_callback(callback=None, **kwargs):
-            kwargs.setdefault("stopping_criteria", transformers.StoppingCriteriaList())
-            kwargs["stopping_criteria"].append(Stream(callback_func=callback))
-            with torch.no_grad():
-                model.generate(**kwargs)
+    #     def generate_with_callback(callback=None, **kwargs):
+    #         kwargs.setdefault("stopping_criteria", transformers.StoppingCriteriaList())
+    #         kwargs["stopping_criteria"].append(Stream(callback_func=callback))
+    #         with torch.no_grad():
+    #             model.generate(**kwargs)

-        def generate_with_streaming(**kwargs):
-            return Iteratorize(generate_with_callback, kwargs, callback=None)
+    #     def generate_with_streaming(**kwargs):
+    #         return Iteratorize(generate_with_callback, kwargs, callback=None)

-        with generate_with_streaming(**generate_params) as generator:
-            for output in generator:
-                # new_tokens = len(output) - len(input_ids[0])
-                decoded_output = tokenizer.decode(output)
+    #     with generate_with_streaming(**generate_params) as generator:
+    #         for output in generator:
+    #             # new_tokens = len(output) - len(input_ids[0])
+    #             decoded_output = tokenizer.decode(output)

-                if output[-1] in [tokenizer.eos_token_id]:
-                    break
+    #             if output[-1] in [tokenizer.eos_token_id]:
+    #                 break

-                yield decoded_output.split("### Response:")[-1].strip()
-        return  # early return for stream_output
+    #             yield decoded_output.split("### Response:")[-1].strip()
+    #     return  # early return for stream_output

     with torch.no_grad():
         generation_output = model.generate(
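For context, the block being commented out streamed the reply one step at a time by abusing a HuggingFace stopping criterion as a per-step callback and wrapping that callback in an iterator, as the linked text-generation-webui code does. The sketch below shows roughly how such Stream and Iteratorize helpers can be written; the names mirror the diff, but this is a minimal illustration under that assumption, not the project's actual implementation.

# Minimal sketch (assumed shapes, not the exact DB-GPT code): a StoppingCriteria
# subclass that never stops generation but reports each partial output, plus a
# wrapper that turns the callback-driven model.generate() call into a blocking iterator.
import threading
from queue import Queue

import transformers


class Stream(transformers.StoppingCriteria):
    """Fires a callback with the tokens generated so far; never asks generate() to stop."""

    def __init__(self, callback_func=None):
        self.callback_func = callback_func

    def __call__(self, input_ids, scores, **kwargs):
        if self.callback_func is not None:
            self.callback_func(input_ids[0])
        return False  # keep generating; the consumer decides when to stop reading


class Iteratorize:
    """Run a callback-style function in a thread and expose its callbacks as an iterator."""

    def __init__(self, func, kwargs=None, callback=None):
        # `callback` is accepted only to mirror the call site in the diff; it is unused here.
        self.queue = Queue()
        self.sentinel = object()

        def _enqueue(value):
            self.queue.put(value)

        def _run():
            try:
                func(callback=_enqueue, **(kwargs or {}))
            finally:
                self.queue.put(self.sentinel)  # signal end of generation

        threading.Thread(target=_run, daemon=True).start()

    def __iter__(self):
        return self

    def __next__(self):
        item = self.queue.get()
        if item is self.sentinel:
            raise StopIteration
        return item

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Real implementations usually also signal the worker thread to abort
        # generation here, so that breaking out of the loop stops the model.
        return False

Used the way the removed block did, generate_with_callback appends Stream to the stopping_criteria list passed to model.generate, and Iteratorize(generate_with_callback, generate_params) then yields each partial token sequence for decoding; the commit drops this path and keeps only the non-streaming model.generate call that follows.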