Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-06-21 13:11:27 +00:00
[example] add vllm inference benchmark (#5080)
* [example] add vllm inference benchmark
* [example] update vllm inference benchmark
This commit is contained in:
parent dce05da535
commit 42b2d6f3a5
examples/inference/benchmark_vllm.py | 89 lines | new file
@@ -0,0 +1,89 @@
"""Benchmark the latency of processing a single batch of requests."""
"""Modified from https://github.com/vllm-project/vllm/blob/main/benchmarks/benchmark_latency.py
The only change is printing the throughput.
"""
import argparse
import time

import numpy as np
import torch
from tqdm import tqdm
from vllm import LLM, SamplingParams


def main(args: argparse.Namespace):
    # Process all the requests in a single batch if possible.
    # NOTE(woosuk): If the request cannot be processed in a single batch,
    # the engine will automatically process the request in multiple batches.
    llm = LLM(
        model=args.model,
        tokenizer=args.tokenizer,
        quantization=args.quantization,
        tensor_parallel_size=args.tensor_parallel_size,
        max_num_seqs=args.batch_size,
        max_num_batched_tokens=args.batch_size * args.input_len,
        trust_remote_code=args.trust_remote_code,
        dtype=args.dtype,
    )

    sampling_params = SamplingParams(
        n=args.n,
        temperature=0.0 if args.use_beam_search else 1.0,
        top_p=1.0,
        use_beam_search=args.use_beam_search,
        ignore_eos=True,
        max_tokens=args.output_len,
    )
    dummy_prompt_token_ids = [[0] * args.input_len] * args.batch_size

    def run_to_completion(profile: bool = False):
        if profile:
            torch.cuda.cudart().cudaProfilerStart()
        start_time = time.perf_counter()

        llm.generate(prompt_token_ids=dummy_prompt_token_ids, sampling_params=sampling_params, use_tqdm=False)

        end_time = time.perf_counter()
        latency = end_time - start_time
        if profile:
            torch.cuda.cudart().cudaProfilerStop()
        return latency

    print("Warming up...")
    run_to_completion(profile=False)

    # Benchmark.
    latencies = []
    for _ in tqdm(range(args.num_iters), desc="Profiling iterations"):
        latencies.append(run_to_completion(profile=False))
    print(f"Avg latency: {np.mean(latencies)} seconds")
    print(f"Avg throughput: {args.batch_size*args.output_len/np.mean(latencies)} tokens/s")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Benchmark the latency of processing a single batch of " "requests till completion."
    )
    parser.add_argument("--model", type=str, default="facebook/opt-125m")
    parser.add_argument("--tokenizer", type=str, default=None)
    parser.add_argument("--quantization", "-q", choices=["awq", "squeezellm", None], default=None)
    parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1)
    parser.add_argument("--input-len", type=int, default=32)
    parser.add_argument("--output-len", type=int, default=128)
    parser.add_argument("--batch-size", type=int, default=8)
    parser.add_argument("--n", type=int, default=1, help="Number of generated sequences per prompt.")
    parser.add_argument("--use-beam-search", action="store_true")
    parser.add_argument("--num-iters", type=int, default=3, help="Number of iterations to run.")
    parser.add_argument("--trust-remote-code", action="store_true", help="trust remote code from huggingface")
    parser.add_argument(
        "--dtype",
        type=str,
        default="auto",
        choices=["auto", "half", "float16", "bfloat16", "float", "float32"],
        help="data type for model weights and activations. "
        'The "auto" option will use FP16 precision '
        "for FP32 and FP16 models, and BF16 precision "
        "for BF16 models.",
    )
    args = parser.parse_args()
    main(args)
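For context on the numbers the script prints: the throughput figure is simply the count of generated tokens (batch size times output length) divided by the mean per-batch latency. A minimal standalone sketch of that arithmetic, using illustrative latencies rather than measured results:

import numpy as np

# Illustrative values only (not measured results): the script's defaults,
# i.e. a batch of 8 prompts each generating 128 tokens, timed over 3 iterations.
batch_size, output_len = 8, 128
latencies = [2.10, 2.05, 2.08]  # hypothetical per-batch latencies in seconds

avg_latency = np.mean(latencies)                    # mean end-to-end latency per batch
throughput = batch_size * output_len / avg_latency  # generated tokens per second
print(f"Avg latency: {avg_latency:.3f} seconds")
print(f"Avg throughput: {throughput:.1f} tokens/s")

Assuming the file lands at the path shown in the diff, a typical invocation would be `python examples/inference/benchmark_vllm.py --model facebook/opt-125m --batch-size 8 --input-len 32 --output-len 128`; any flag left out falls back to the defaults defined in the argument parser above.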