diff --git a/pilot/model/guanaco_llm.py b/pilot/model/guanaco_llm.py index df4f86dce..ba10b4f56 100644 --- a/pilot/model/guanaco_llm.py +++ b/pilot/model/guanaco_llm.py @@ -7,7 +7,7 @@ from transformers import GenerationConfig from pilot.model.llm_utils import Iteratorize, Stream -def guanaco_generate_output(model, tokenizer, params, device): +def guanaco_generate_output(model, tokenizer, params, device, context_len=2048, stream_interval=2): """Fork from fastchat: https://github.com/KohakuBlueleaf/guanaco-lora/blob/main/generate.py""" prompt = params["prompt"] inputs = tokenizer(prompt, return_tensors="pt")