[Inference] Add Streaming LLM (#5745)

* Add Streaming LLM

* add some parameters to llama_generation.py

* verify streamingllm config

* add test_streamingllm.py

* modified according to review comments

* add Citation

* change _block_tables to use tolist()

Author: yuehuayingxueluo
Date: 2024-06-05 10:51:19 +08:00
Committed by: GitHub
Parent: ee6fd38373
Commit: b45000f839
8 changed files with 276 additions and 12 deletions

llama_generation.py

@@ -48,6 +48,9 @@ def infer(args):
         block_size=16,
         tp_size=args.tp_size,
         use_cuda_kernel=args.use_cuda_kernel,
+        enable_streamingllm=args.enable_streamingllm,
+        start_token_size=args.start_token_size,
+        generated_token_size=args.generated_token_size,
     )
     coordinator.print_on_master(f"Initializing Inference Engine...")
     engine = InferenceEngine(model, tokenizer, inference_config, model_policy=POLICY_CLS(), verbose=True)
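
For context, StreamingLLM bounds the KV cache by keeping the first start_token_size tokens (the "attention sinks") plus the most recent generated_token_size tokens, and evicting everything in between. A minimal sketch of that eviction rule, with illustrative names rather than the engine's actual internals:

def streamingllm_window(token_ids, start_token_size=4, generated_token_size=512):
    # Cache budget is the sum of the two window sizes.
    budget = start_token_size + generated_token_size
    if len(token_ids) <= budget:
        return token_ids  # cache still fits; nothing to evict
    # Keep the attention sinks at the front plus a sliding window of recent tokens.
    return token_ids[:start_token_size] + token_ids[-generated_token_size:]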
@@ -63,6 +66,8 @@ def infer(args):
         temperature=args.temperature,
         top_k=args.top_k,
         top_p=args.top_p,
+        no_repeat_ngram_size=args.no_repeat_ngram_size,
+        repetition_penalty=args.repetition_penalty,
     )
     coordinator.print_on_master(f"Generating...")
     out = engine.generate(prompts=[args.prompt], generation_config=generation_config)
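
The two new generation knobs are standard logit controls. A simplified sketch of what they do, assuming a 1-D logits tensor and a list of previously generated token ids (not the engine's implementation; the helper names are made up):

import torch

def apply_repetition_penalty(logits, prev_ids, penalty=1.0):
    # penalty > 1 makes already-seen tokens less likely; penalty < 1 favors repeats.
    scores = logits[prev_ids]
    logits[prev_ids] = torch.where(scores > 0, scores / penalty, scores * penalty)
    return logits

def banned_ngram_completions(prev_ids, n):
    # Token ids that would complete an n-gram already present in prev_ids.
    if n <= 0 or len(prev_ids) < n:
        return []
    prefix = tuple(prev_ids[len(prev_ids) - n + 1:])  # last n-1 generated tokens
    return [
        prev_ids[i + n - 1]
        for i in range(len(prev_ids) - n + 1)
        if tuple(prev_ids[i:i + n - 1]) == prefix
    ]

During sampling, the banned ids would have their logits set to -inf before the softmax.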
@@ -107,6 +112,25 @@ if __name__ == "__main__":
     parser.add_argument("--temperature", type=float, default=1.0, help="Temperature for generation")
     parser.add_argument("--top_k", type=int, default=50, help="Top k for generation")
     parser.add_argument("--top_p", type=float, default=1.0, help="Top p for generation")
+    parser.add_argument("--enable_streamingllm", action="store_true", help="Whether to use StreamingLLM")
+    parser.add_argument(
+        "--start_token_size", type=int, default=4, help="The number of start tokens to keep when using StreamingLLM"
+    )
+    parser.add_argument(
+        "--generated_token_size", type=int, default=512, help="The size of the recent-token window kept when using StreamingLLM"
+    )
+    parser.add_argument(
+        "--no_repeat_ngram_size",
+        type=int,
+        default=0,
+        help="If no_repeat_ngram_size > 0, any ngram of that size can appear at most once in the generated text.",
+    )
+    parser.add_argument(
+        "--repetition_penalty",
+        type=float,
+        default=1.0,
+        help="Penalty for tokens that already appear in the prompt or the generated text. Values greater than 1 encourage the model to introduce new tokens; values less than 1 encourage repetition. Defaults to 1.0.",
+    )
     args = parser.parse_args()
     infer(args)
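
Taken together, a run exercising the new flags could look like the following; the script's pre-existing required arguments (model, prompt, and so on) are elided, and the values shown are only illustrative:

python llama_generation.py ... \
    --enable_streamingllm --start_token_size 4 --generated_token_size 512 \
    --no_repeat_ngram_size 3 --repetition_penalty 1.2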