mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-01 09:07:51 +00:00
[Refactor] Integrated some lightllm kernels into token-attention (#4946)
* add some req for inference * clean codes * add codes * add some lightllm deps * clean codes * hello * delete rms files * add some comments * add comments * add doc * add lightllm deps * add lightllm cahtglm2 kernels * add lightllm cahtglm2 kernels * replace rotary embedding with lightllm kernel * add some commnets * add some comments * add some comments * add * replace fwd kernel att1 * fix a arg * add * add * fix token attention * add some comments * clean codes * modify comments * fix readme * fix bug * fix bug --------- Co-authored-by: cuiqing.li <lixx336@gmail.com> Co-authored-by: CjhHa1 <cjh18671720497@outlook.com>
This commit is contained in:
@@ -3,7 +3,6 @@ import os
|
||||
import time
|
||||
|
||||
import torch
|
||||
from torch.profiler import ProfilerActivity, profile, record_function
|
||||
from transformers import LlamaForCausalLM, LlamaTokenizer
|
||||
|
||||
import colossalai
|
||||
@@ -16,6 +15,7 @@ os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"
|
||||
|
||||
|
||||
def print_perf_stats(latency_set, config, bs, warmup=3):
|
||||
torch.cuda.empty_cache()
|
||||
# trim warmup queries
|
||||
latency_set = list(latency_set)
|
||||
latency_set = latency_set[warmup:]
|
||||
@@ -38,24 +38,29 @@ def run_llama_test(args):
|
||||
max_batch_size = args.batch_size
|
||||
max_input_len = args.input_len
|
||||
max_output_len = args.output_len
|
||||
args.test_mode
|
||||
|
||||
print("max_batch_size : " + str(max_batch_size))
|
||||
|
||||
tokenizer = LlamaTokenizer.from_pretrained(llama_model_path)
|
||||
tokenizer.pad_token_id = tokenizer.unk_token_id
|
||||
model = LlamaForCausalLM.from_pretrained(llama_model_path, pad_token_id=tokenizer.eos_token_id)
|
||||
model = model.half()
|
||||
model_config = model.config
|
||||
model.config
|
||||
|
||||
shard_config = ShardConfig(enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True)
|
||||
infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len)
|
||||
|
||||
generate_kwargs = dict(max_new_tokens=max_output_len, do_sample=False)
|
||||
generate_kwargs = dict(max_new_tokens=1, do_sample=False)
|
||||
input_tokens = {
|
||||
"input_ids": torch.randint(1, 1000, (max_batch_size, max_input_len), device="cuda"),
|
||||
"attention_mask": torch.ones((max_batch_size, max_input_len), device="cuda"),
|
||||
}
|
||||
|
||||
iters = 10
|
||||
times = []
|
||||
prefill_times = []
|
||||
|
||||
warmup = 3
|
||||
|
||||
for i in range(iters):
|
||||
torch.cuda.synchronize()
|
||||
@@ -65,17 +70,39 @@ def run_llama_test(args):
|
||||
end = time.time()
|
||||
out_len = outputs.shape[1]
|
||||
print("generation time {} s".format(str(end - start)))
|
||||
print(out_len - max_input_len)
|
||||
prefill_times.append((end - start) / (out_len - max_input_len))
|
||||
|
||||
prefill_times = prefill_times[warmup:]
|
||||
prefill_time_avg = sum(prefill_times) / len(prefill_times)
|
||||
generate_kwargs = dict(max_new_tokens=max_output_len, do_sample=False)
|
||||
|
||||
times = []
|
||||
decoder_times = []
|
||||
for i in range(iters):
|
||||
torch.cuda.synchronize()
|
||||
start = time.time()
|
||||
outputs = infer_engine.generate(input_tokens, **generate_kwargs)
|
||||
torch.cuda.synchronize()
|
||||
end = time.time()
|
||||
out_len = outputs.shape[1]
|
||||
print("generation time {} s".format(str(end - start)))
|
||||
print(out_len - max_input_len)
|
||||
times.append((end - start) / (out_len - max_input_len))
|
||||
if args.test_mode == "decoder_test":
|
||||
decoder_times.append((end - start - prefill_time_avg) / (out_len - max_input_len - 1))
|
||||
|
||||
print("outputs, ", len(outputs))
|
||||
print_perf_stats(times, model_config, max_batch_size)
|
||||
times = times[warmup:]
|
||||
latency = sum(times) / len(times)
|
||||
print("total process latency is : " + str(latency) + " s")
|
||||
print("total throughput is : " + str(1 / latency * max_batch_size))
|
||||
|
||||
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True) as prof:
|
||||
with record_function("model_inference"):
|
||||
torch.cuda.synchronize()
|
||||
outputs = infer_engine.generate(input_tokens, **generate_kwargs)
|
||||
torch.cuda.synchronize()
|
||||
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
|
||||
if args.test_mode == "decoder_test":
|
||||
decoder_times = decoder_times[warmup:]
|
||||
latency = sum(decoder_times) / len(decoder_times)
|
||||
|
||||
print("decoder process latency is : " + str(latency) + " s")
|
||||
print("decoder throughput is : " + str(1 / latency * max_batch_size))
|
||||
|
||||
|
||||
def check_llama(rank, world_size, port, args):
|
||||
@@ -95,8 +122,11 @@ if __name__ == "__main__":
|
||||
parser.add_argument("-p", "--path", type=str, help="Model path", required=True)
|
||||
parser.add_argument("-tp", "--tp_size", type=int, default=1, help="Tensor parallel size")
|
||||
parser.add_argument("-b", "--batch_size", type=int, default=16, help="Maximum batch size")
|
||||
parser.add_argument("--input_len", type=int, default=1024, help="Maximum input length")
|
||||
parser.add_argument("--input_len", type=int, default=256, help="Maximum input length")
|
||||
parser.add_argument("--output_len", type=int, default=128, help="Maximum output length")
|
||||
parser.add_argument(
|
||||
"--test_mode", type=str, help="Test mode", default="e2e_test", choices=["e2e_test", "decoder_test"]
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
|
Reference in New Issue
Block a user