mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-03 01:55:12 +00:00
[Kernels]Update triton kernels into 2.1.0 (#5046)
* update flash-context-attention * adding kernels * fix * reset * add build script * add building process * add llama2 exmaple * add colossal-llama2 test * clean * fall back test setting * fix test file * clean * clean * clean --------- Co-authored-by: cuiqing.li <lixx336@gmail.com>
This commit is contained in:
@@ -28,7 +28,6 @@ def run_llama_test(args):
|
||||
tokenizer.pad_token_id = tokenizer.unk_token_id
|
||||
model = LlamaForCausalLM.from_pretrained(llama_model_path, pad_token_id=tokenizer.eos_token_id)
|
||||
model = model.half()
|
||||
model.config
|
||||
|
||||
shard_config = ShardConfig(
|
||||
enable_tensor_parallelism=True if args.tp_size > 1 else False, extra_kwargs={"inference_only": True}
|
||||
|
81
examples/inference/colossal_llama2_demo.py
Normal file
81
examples/inference/colossal_llama2_demo.py
Normal file
@@ -0,0 +1,81 @@
|
||||
import os
|
||||
import warnings
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import argparse
|
||||
from packaging import version
|
||||
|
||||
import colossalai
|
||||
from colossalai.inference.tensor_parallel.engine import TPInferEngine
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.shardformer import ShardConfig
|
||||
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
|
||||
os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
|
||||
TPSIZE = 1
|
||||
BATCH_SIZE = 4
|
||||
MAX_INPUT_LEN = 32
|
||||
MAX_OUTPUT_LEN = 128
|
||||
|
||||
CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.5')
|
||||
|
||||
|
||||
@parameterize('test_config', [{
|
||||
'tp_size': TPSIZE,
|
||||
}])
|
||||
def run_llama_test(test_config, args):
|
||||
|
||||
model_path = args.path
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
|
||||
tokenizer.pad_token_id = tokenizer.unk_token_id
|
||||
model = AutoModelForCausalLM.from_pretrained(model_path, pad_token_id=tokenizer.eos_token_id)
|
||||
model = model.half()
|
||||
|
||||
text = ["Introduce London.", "What is the genus of Poodle?"]
|
||||
input_ids = tokenizer.batch_encode_plus(text, return_tensors='pt', padding=True)
|
||||
|
||||
print(input_ids)
|
||||
|
||||
shard_config = ShardConfig(enable_tensor_parallelism=True if test_config['tp_size'] > 1 else False,
|
||||
extra_kwargs={"inference_only": True})
|
||||
infer_engine = TPInferEngine(model, shard_config, BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
|
||||
|
||||
generate_kwargs = dict(max_new_tokens=MAX_OUTPUT_LEN, do_sample=False)
|
||||
outputs = infer_engine.generate(input_ids, **generate_kwargs)
|
||||
|
||||
assert outputs is not None
|
||||
|
||||
if not dist.is_initialized() or dist.get_rank() == 0:
|
||||
for o in outputs:
|
||||
output_text = tokenizer.decode(o)
|
||||
print(output_text)
|
||||
|
||||
|
||||
def check_llama(rank, world_size, port, args):
|
||||
disable_existing_loggers()
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
run_llama_test(args=args)
|
||||
|
||||
|
||||
@rerun_if_address_is_in_use()
|
||||
@clear_cache_before_run()
|
||||
def test_llama(args):
|
||||
spawn(check_llama, args.tp_size, args=args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("-p", "--path", type=str, default = "hpcai-tech/Colossal-LLaMA-2-7b-base", help="Model path")
|
||||
parser.add_argument("-tp", "--tp_size", type=int, default=1, help="Tensor parallel size")
|
||||
parser.add_argument("-b", "--batch_size", type=int, default=32, help="Maximum batch size")
|
||||
parser.add_argument("--input_len", type=int, default=1024, help="Maximum input length")
|
||||
parser.add_argument("--output_len", type=int, default=128, help="Maximum output length")
|
||||
parser.add_argument(
|
||||
"--test_mode", type=str, help="Test mode", default="e2e_test", choices=["e2e_test", "decoder_test"]
|
||||
)
|
||||
args = parser.parse_args()
|
||||
test_llama(args)
|
Reference in New Issue
Block a user