[Inference] Add Streaming LLM (#5745)
* Add Streaming LLM
* add some parameters to llama_generation.py
* verify streamingllm config
* add test_streamingllm.py
* modified according to review comments
* add Citation
* change _block_tables tolist
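StreamingLLM (the "attention sinks" technique of Xiao et al., "Efficient Streaming Language Models with Attention Sinks") bounds the KV cache by keeping the first few sink tokens plus a sliding window of the most recent tokens, which is what the start_token_size and generated_token_size parameters in this diff control. Below is a minimal sketch of that eviction policy, assuming a simple dense cache tensor; the helper name and layout are illustrative, not ColossalAI's actual cache code.

# Minimal sketch of StreamingLLM-style KV-cache eviction (attention
# sinks + sliding window). Hypothetical helper, not ColossalAI's code.
import torch

def evict_kv_cache(kv_cache: torch.Tensor, start_token_size: int, generated_token_size: int) -> torch.Tensor:
    """Keep the first `start_token_size` sink tokens and the most recent
    `generated_token_size` tokens along the sequence axis; drop the middle."""
    seq_len = kv_cache.size(-2)
    if seq_len <= start_token_size + generated_token_size:
        return kv_cache  # within budget, nothing to evict
    sinks = kv_cache[..., :start_token_size, :]
    recent = kv_cache[..., seq_len - generated_token_size:, :]
    return torch.cat([sinks, recent], dim=-2)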
@@ -48,6 +48,9 @@ def infer(args):
         block_size=16,
         tp_size=args.tp_size,
         use_cuda_kernel=args.use_cuda_kernel,
+        enable_streamingllm=args.enable_streamingllm,
+        start_token_size=args.start_token_size,
+        generated_token_size=args.generated_token_size,
     )
     coordinator.print_on_master(f"Initializing Inference Engine...")
     engine = InferenceEngine(model, tokenizer, inference_config, model_policy=POLICY_CLS(), verbose=True)
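For context, a hedged usage sketch of the new config knobs outside the diff; the import path follows ColossalAI's inference module, and the concrete values are placeholders mirroring the script's defaults.

from colossalai.inference.config import InferenceConfig

# StreamingLLM caps the KV cache at start_token_size + generated_token_size
# tokens per sequence; tp_size is a placeholder (the script passes args.tp_size).
inference_config = InferenceConfig(
    block_size=16,
    tp_size=1,
    use_cuda_kernel=True,
    enable_streamingllm=True,
    start_token_size=4,        # attention-sink tokens kept at the head of the cache
    generated_token_size=512,  # sliding window of the most recent tokens
)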
@@ -63,6 +66,8 @@ def infer(args):
         temperature=args.temperature,
         top_k=args.top_k,
         top_p=args.top_p,
+        no_repeat_ngram_size=args.no_repeat_ngram_size,
+        repetition_penalty=args.repetition_penalty,
     )
     coordinator.print_on_master(f"Generating...")
     out = engine.generate(prompts=[args.prompt], generation_config=generation_config)
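The two sampling controls added here can be summarized with small helpers in the style of Hugging Face's logits processors; these are hypothetical illustrations of the semantics, not ColossalAI's implementation.

import torch

def apply_repetition_penalty(logits: torch.Tensor, generated_ids: torch.Tensor, penalty: float) -> torch.Tensor:
    """Scale logits of already-seen token ids: negative logits are multiplied
    by `penalty`, positive ones divided, so penalty > 1 discourages repeats."""
    seen = logits.gather(-1, generated_ids)
    seen = torch.where(seen < 0, seen * penalty, seen / penalty)
    return logits.scatter(-1, generated_ids, seen)

def banned_ngram_tokens(generated: list, n: int) -> set:
    """Token ids that would complete an n-gram already present in `generated`
    (the no_repeat_ngram_size rule)."""
    if n <= 0 or len(generated) < n:
        return set()
    prefix = tuple(generated[len(generated) - (n - 1):]) if n > 1 else ()
    banned = set()
    for i in range(len(generated) - n + 1):
        if tuple(generated[i:i + n - 1]) == prefix:
            banned.add(generated[i + n - 1])
    return banned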
@@ -107,6 +112,25 @@ if __name__ == "__main__":
     parser.add_argument("--temperature", type=float, default=1.0, help="Temperature for generation")
     parser.add_argument("--top_k", type=int, default=50, help="Top k for generation")
     parser.add_argument("--top_p", type=float, default=1.0, help="Top p for generation")
+    parser.add_argument("--enable_streamingllm", action="store_true", help="Whether to use StreamingLLM")
+    parser.add_argument(
+        "--start_token_size", type=int, default=4, help="The number of start tokens to keep when using StreamingLLM"
+    )
+    parser.add_argument(
+        "--generated_token_size", type=int, default=512, help="The size of the recent-token window kept when using StreamingLLM"
+    )
+    parser.add_argument(
+        "--no_repeat_ngram_size",
+        type=int,
+        default=0,
+        help="If no_repeat_ngram_size > 0, n-grams of that size can appear at most once in the generated text.",
+    )
+    parser.add_argument(
+        "--repetition_penalty",
+        type=float,
+        default=1.0,
+        help="Penalizes new tokens based on their appearance in the prompt and the generated text. Values greater than 1 encourage the model to introduce new tokens; values less than 1 encourage repetition. Defaults to 1.0.",
+    )
     args = parser.parse_args()
 
     infer(args)
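A hypothetical invocation combining the new flags (the model and prompt arguments are omitted here since they are not shown in this diff):

python llama_generation.py --enable_streamingllm --start_token_size 4 --generated_token_size 512 --no_repeat_ngram_size 3 --repetition_penalty 1.2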