Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-16 14:41:53 +00:00
[Feature] The first PR to add the TP inference engine, kv-cache manager and related kernels for our inference system (#4577)
* [infer] Infer/llama demo (#4503)
* add
* add infer example
* finish
* finish
* stash
* fix
* [Kernels] add inference token attention kernel (#4505)
* add token forward
* fix tests
* fix comments
* add try import triton
* add adapted license
* add tests check
* [Kernels] add necessary kernels (llama & bloom) for attention forward and kv-cache manager (#4485)
* added _vllm_rms_norm
* change place
* added tests
* added tests
* modify
* adding kernels
* added tests:
* adding kernels
* modify
* added
* updating kernels
* adding tests
* added tests
* kernel change
* submit
* modify
* added
* edit comments
* change name
* change comments and fix import
* add
* added
* combine codes (#4509)
* [feature] add KV cache manager for llama & bloom inference (#4495)
* add kv cache memory manager
* add stateinfo during inference
* format
* format
* rename file
* add kv cache test
* revise on BatchInferState
* file dir change
* [Bug FIx] import llama context ops fix (#4524)
* added _vllm_rms_norm
* change place
* added tests
* added tests
* modify
* adding kernels
* added tests:
* adding kernels
* modify
* added
* updating kernels
* adding tests
* added tests
* kernel change
* submit
* modify
* added
* edit comments
* change name
* change comments and fix import
* add
* added
* fix
* add ops into init.py
* add
* [Infer] Add TPInferEngine and fix file path (#4532)
* add engine for TP inference
* move file path
* update path
* fix TPInferEngine
* remove unused file
* add engine test demo
* revise TPInferEngine
* fix TPInferEngine, add test
* fix
* Add Inference test for llama (#4508)
* add kv cache memory manager
* add stateinfo during inference
* add
* add infer example
* finish
* finish
* format
* format
* rename file
* add kv cache test
* revise on BatchInferState
* add inference test for llama
* fix conflict
* feature: add some new features for llama engine
* adapt colossalai triton interface
* Change the parent class of llama policy
* add nvtx
* move llama inference code to tensor_parallel
* fix __init__.py
* rm tensor_parallel
* fix: fix bugs in auto_policy.py
* fix:rm some unused codes
* mv colossalai/tpinference to colossalai/inference/tensor_parallel
* change __init__.py
* save change
* fix engine
* Bug fix: Fix hang
* remove llama_infer_engine.py
---------
Co-authored-by: yuanheng-zhao <jonathan.zhaoyh@gmail.com>
Co-authored-by: CjhHa1 <cjh18671720497@outlook.com>
* [infer] Add Bloom inference policy and replaced methods (#4512)
* add bloom inference methods and policy
* enable pass BatchInferState from model forward
* revise bloom infer layers/policies
* add engine for inference (draft)
* add test for bloom infer
* fix bloom infer policy and flow
* revise bloom test
* fix bloom file path
* remove unused codes
* fix bloom modeling
* fix dir typo
* fix trivial
* fix policy
* clean pr
* trivial fix
* Revert "[infer] Add Bloom inference policy and replaced methods (#4512)" (#4552)
This reverts commit 17cfa57140.
* [Doc] Add colossal inference doc (#4549)
* create readme
* add readme.md
* fix typos
* [infer] Add Bloom inference policy and replaced methods (#4553)
* add bloom inference methods and policy
* enable pass BatchInferState from model forward
* revise bloom infer layers/policies
* add engine for inference (draft)
* add test for bloom infer
* fix bloom infer policy and flow
* revise bloom test
* fix bloom file path
* remove unused codes
* fix bloom modeling
* fix dir typo
* fix trivial
* fix policy
* clean pr
* trivial fix
* trivial
* Fix Bugs In Llama Model Forward (#4550)
* add kv cache memory manager
* add stateinfo during inference
* add
* add infer example
* finish
* finish
* format
* format
* rename file
* add kv cache test
* revise on BatchInferState
* add inference test for llama
* fix conflict
* feature: add some new features for llama engine
* adapt colossalai triton interface
* Change the parent class of llama policy
* add nvtx
* move llama inference code to tensor_parallel
* fix __init__.py
* rm tensor_parallel
* fix: fix bugs in auto_policy.py
* fix:rm some unused codes
* mv colossalai/tpinference to colossalai/inference/tensor_parallel
* change __init__.py
* save change
* fix engine
* Bug fix: Fix hang
* remove llama_infer_engine.py
* bug fix: fix bugs about infer_state.is_context_stage
* remove policies
* fix: delete unused code
* fix: delete unused code
* remove unused code
* fix conflict
---------
Co-authored-by: yuanheng-zhao <jonathan.zhaoyh@gmail.com>
Co-authored-by: CjhHa1 <cjh18671720497@outlook.com>
* [doc] add colossal inference fig (#4554)
* create readme
* add readme.md
* fix typos
* upload fig
* [NFC] fix docstring for colossal inference (#4555)
Fix docstring and comments in kv cache manager and bloom modeling
* fix docstring in llama modeling (#4557)
* [Infer] check import vllm (#4559)
* change import vllm
* import apply_rotary_pos_emb
* change import location
* [DOC] add installation req (#4561)
* add installation req
* fix
* slight change
* remove empty
* [Feature] rms-norm transfer into inference llama.py (#4563)
* add installation req
* fix
* slight change
* remove empty
* add rmsnorm policy
* add
* clean codes
* [infer] Fix tp inference engine (#4564)
* fix engine prepare data
* add engine test
* use bloom for testing
* revise on test
* revise on test
* reset shardformer llama (#4569)
* [infer] Fix engine - tensors on different devices (#4570)
* fix diff device in engine
* [codefactor] Feature/colossal inference (#4579)
* code factors
* remove
* change coding (#4581)
* [doc] complete README of colossal inference (#4585)
* complete fig
* Update README.md
* [doc]update readme (#4586)
* update readme
* Update README.md
* bug fix: fix bugs in llama and bloom (#4588)
* [BUG FIX]Fix test engine in CI and non-vllm kernels llama forward (#4592)
* fix tests
* clean
* clean
* fix bugs
* add
* fix llama non-vllm kernels bug
* modify
* clean codes
* [Kernel]Rmsnorm fix (#4598)
* fix tests
* clean
* clean
* fix bugs
* add
* fix llama non-vllm kernels bug
* modify
* clean codes
* add triton rmsnorm
* delete vllm kernel flag
* [Bug Fix]Fix bugs in llama (#4601)
* fix tests
* clean
* clean
* fix bugs
* add
* fix llama non-vllm kernels bug
* modify
* clean codes
* bug fix: remove rotary_positions_ids
---------
Co-authored-by: cuiqing.li <lixx3527@gmail.com>
* [kernel] Add triton layer norm & replace norm for bloom (#4609)
* add layernorm for inference
* add test for layernorm kernel
* add bloom layernorm replacement policy
* trivial: path
* [Infer] Bug fix rotary embedding in llama (#4608)
* fix rotary embedding
* delete print
* fix init seq len bug
* rename pytest
* add benchmark for llama
* refactor codes
* delete useless code
* [bench] Add bloom inference benchmark (#4621)
* add bloom benchmark
* readme - update benchmark res
* trivial - uncomment for testing (#4622)
* [Infer] add check triton and cuda version for tests (#4627)
* fix rotary embedding
* delete print
* fix init seq len bug
* rename pytest
* add benchmark for llama
* refactor codes
* delete useless code
* add check triton and cuda
* Update sharder.py (#4629)
* [Inference] Hot fix some bugs and typos (#4632)
* fix
* fix test
* fix conflicts
* [typo]Comments fix (#4633)
* fallback
* fix comments
* bug fix: fix some bugs in test_llama and test_bloom (#4635)
* [Infer] delete benchmark in tests and fix bug for llama and bloom (#4636)
* fix rotary embedding
* delete print
* fix init seq len bug
* rename pytest
* add benchmark for llama
* refactor codes
* delete useless code
* add check triton and cuda
* delete benchmark and fix infer bugs
* delete benchmark for tests
* delete useless code
* delete bechmark function in utils
* [Fix] Revise TPInferEngine, inference tests and benchmarks (#4642)
* [Fix] revise TPInferEngine methods and inference tests
* fix llama/bloom infer benchmarks
* fix infer tests
* trivial fix: benchmarks
* trivial
* trivial: rm print
* modify utils filename for infer ops test (#4657)
* [Infer] Fix TPInferEngine init & inference tests, benchmarks (#4670)
* fix engine funcs
* TPInferEngine: receive shard config in init
* benchmarks: revise TPInferEngine init
* benchmarks: remove pytest decorator
* trivial fix
* use small model for tests
* [NFC] use args for infer benchmarks (#4674)
* revise infer default (#4683)
* [Fix] optimize/shard model in TPInferEngine init (#4684)
* remove using orig model in engine
* revise inference tests
* trivial: rename
---------
Co-authored-by: Jianghai <72591262+CjhHa1@users.noreply.github.com>
Co-authored-by: Xu Kai <xukai16@foxmail.com>
Co-authored-by: Yuanheng Zhao <54058983+yuanheng-zhao@users.noreply.github.com>
Co-authored-by: yuehuayingxueluo <867460659@qq.com>
Co-authored-by: yuanheng-zhao <jonathan.zhaoyh@gmail.com>
Co-authored-by: CjhHa1 <cjh18671720497@outlook.com>
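
For orientation, the sketch below condenses how the new TPInferEngine is driven in the tests added by this PR; the tiny Bloom config and the positional max-batch/input/output lengths are illustrative assumptions taken from those tests, not a canonical example.

import colossalai
import torch
from transformers import BloomConfig, BloomForCausalLM

from colossalai.inference.tensor_parallel import TPInferEngine
from colossalai.shardformer import ShardConfig
from colossalai.testing import spawn

TP_SIZE = 2


def run_engine(rank, world_size, port):
    # one process per GPU; TPInferEngine shards the model internally via ShardFormer
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

    model_config = BloomConfig(num_hidden_layers=4, hidden_size=128, intermediate_size=256, num_attention_heads=4)
    model = BloomForCausalLM(model_config).half().to(torch.cuda.current_device())

    # inference_only=True switches ShardFormer to the inference policies added in this PR
    shard_config = ShardConfig(enable_tensor_parallelism=True, inference_only=True)
    # positional args after the config: max batch size, max input length, max output length
    engine = TPInferEngine(model, shard_config, 4, 16, 8)

    input_ids = torch.randint(low=10, high=1000, size=(4, 16))
    outputs = engine.generate(input_ids, do_sample=False)
    assert outputs is not None


if __name__ == '__main__':
    spawn(run_engine, TP_SIZE)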
tests/test_infer/_utils.py (new file, 53 lines)

import copy

import torch
import torch.distributed as dist
from torch import Tensor
from torch import distributed as dist
from torch.distributed import ProcessGroup
from torch.nn import Module
from torch.optim import Adam, Optimizer

from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
from colossalai.booster.plugin.hybrid_parallel_plugin import HybridParallelModule
from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.shardformer._utils import getattr_
from colossalai.shardformer.policies.auto_policy import Policy
from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor


def build_model(
    model_fn,
    enable_fused_normalization=False,
    enable_tensor_parallelism=False,
    enable_flash_attention=False,
    enable_jit_fused=False,
):
    # create new model
    org_model = model_fn()

    # shard model
    shard_config = ShardConfig(enable_fused_normalization=enable_fused_normalization,
                               enable_tensor_parallelism=enable_tensor_parallelism,
                               enable_flash_attention=enable_flash_attention,
                               enable_jit_fused=enable_jit_fused,
                               inference_only=True)
    model_copy = copy.deepcopy(org_model)
    shard_former = ShardFormer(shard_config=shard_config)
    sharded_model, shared_params = shard_former.optimize(model_copy)
    return org_model.cuda(), sharded_model.cuda()


def run_infer(original_model, sharded_model, data_gen_fn, output_transform_fn):
    # prepare input
    data = data_gen_fn()
    data = {k: v.cuda() for k, v in data.items()}
    # run forward
    org_output = original_model(**data)
    org_output = output_transform_fn(org_output)

    shard_output = sharded_model(**data)
    shard_output = output_transform_fn(shard_output)

    return org_output, shard_output
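
A sketch of how these two helpers are meant to be combined in a test body; the model-zoo registry key and the position of the output-transform function in the registry tuple are assumptions carried over from the other tests in this PR.

from tests.kit.model_zoo import model_zoo
from tests.test_infer._utils import build_model, run_infer

# hypothetical driver: compare original vs. sharded forward outputs for one registry entry
sub_registry = model_zoo.get_sub_registry('transformers_bloom_for_causal_lm')
for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in sub_registry.items():
    org_model, sharded_model = build_model(model_fn, enable_tensor_parallelism=True)
    org_output, shard_output = run_infer(org_model, sharded_model, data_gen_fn, output_transform_fn)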
tests/test_infer/test_bloom_infer.py (new file, 58 lines)

import os

import pytest
import torch
from packaging import version

import colossalai
from colossalai.inference.tensor_parallel import TPInferEngine
from colossalai.logging import disable_existing_loggers
from colossalai.shardformer import ShardConfig
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo

TP_SIZE = 2
MAX_BATCH_SIZE = 4
MAX_INPUT_LEN = 16
MAX_OUTPUT_LEN = 32

CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.5')


@parameterize('test_config', [{
    'tp_size': TP_SIZE,
}])
def run(test_config):

    sub_model_zoo = model_zoo.get_sub_registry('transformers_bloom_for_causal_lm')
    for name, (model_fn, data_gen_fn, _, _, _) in sub_model_zoo.items():
        orig_model = model_fn()
        orig_model = orig_model.half()
        data = data_gen_fn()

        shard_config = ShardConfig(enable_tensor_parallelism=True if test_config['tp_size'] > 1 else False,
                                   inference_only=True)
        infer_engine = TPInferEngine(orig_model, shard_config, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)

        generate_kwargs = dict(do_sample=False)
        outputs = infer_engine.generate(data, **generate_kwargs)

        assert outputs is not None


def check_bloom(rank, world_size, port):
    disable_existing_loggers()
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    run()


@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5")
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_bloom_infer():
    spawn(check_bloom, TP_SIZE)


if __name__ == '__main__':
    test_bloom_infer()
tests/test_infer/test_infer_engine.py (new file, 94 lines)

from itertools import accumulate

import pytest
import torch
import torch.nn as nn
from packaging import version
from transformers import BloomConfig, BloomForCausalLM, LlamaConfig, LlamaForCausalLM
from transformers.tokenization_utils_base import BatchEncoding

import colossalai
from colossalai.inference.tensor_parallel import TPInferEngine
from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState
from colossalai.logging import disable_existing_loggers
from colossalai.shardformer import ShardConfig
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn

TP_SIZE = 2
MAX_BATCH_SIZE = 4
MAX_INPUT_LEN = 16
MAX_OUTPUT_LEN = 8

CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.5')


@parameterize('test_config', [{
    'tp_size': TP_SIZE,
}])
def run(test_config):
    model_config = BloomConfig(num_hidden_layers=4, hidden_size=128, intermediate_size=256, num_attention_heads=4)
    model = BloomForCausalLM(model_config)
    model = model.half()
    model.to(torch.cuda.current_device())

    # 1. check TPInferEngine init and model optimization
    shard_config = ShardConfig(enable_tensor_parallelism=True if test_config['tp_size'] > 1 else False,
                               inference_only=True)
    infer_engine = TPInferEngine(model, shard_config, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)

    assert infer_engine.cache_manager is not None
    assert infer_engine.tp_size == TP_SIZE
    assert infer_engine.head_num == model_config.num_attention_heads // TP_SIZE

    # 2. check data preparation
    input_ids_list = [[80540, 15473, 3331, 11970, 90472, 361, 61335], [80540, 15473, 3331, 11970],
                      [80540, 15473, 3331, 11970], [80540, 15473]]
    batch_size = len(input_ids_list)
    max_seq_len = max(len(li) for li in input_ids_list)
    attention_mask = [[0] * max_seq_len for _ in range(batch_size)]
    for i, li in enumerate(input_ids_list):
        attention_mask[i][max_seq_len - len(li):] = [1 for _ in range(len(li))]
    data = dict(input_ids=input_ids_list, attention_mask=attention_mask)
    inputs_batch_encoding = BatchEncoding(data=data)
    seq_lengths = [len(li) for li in input_ids_list]
    start_loc = list(accumulate([0] + seq_lengths[:-1]))
    seq_lengths = torch.tensor(seq_lengths, dtype=torch.int32)
    start_loc = torch.tensor(start_loc, dtype=torch.int32)
    # BatchEncoding as inputs
    batch_state_out1 = infer_engine.prepare_batch_state(inputs_batch_encoding)
    # input token id list as inputs
    batch_state_out2 = infer_engine.prepare_batch_state(input_ids_list)

    assert batch_state_out1.batch_size == batch_state_out2.batch_size == batch_size
    assert torch.equal(batch_state_out1.seq_len, batch_state_out2.seq_len)

    # The following tests are discarded for now, and will be reused after all features are added
    # assert torch.equal(batch_state_out1.seq_len.to(seq_lengths.device), seq_lengths)
    # assert torch.equal(batch_state_out2.seq_len.to(seq_lengths.device), seq_lengths)
    # assert torch.equal(batch_state_out1.start_loc.to(start_loc.device), start_loc)
    # assert torch.equal(batch_state_out2.start_loc.to(start_loc.device), start_loc)

    # 3. check optimized model generate
    input_ids = torch.randint(low=10, high=1000, size=(MAX_BATCH_SIZE, MAX_INPUT_LEN))
    generate_kwargs = dict(do_sample=False)
    infer_engine.generate(input_ids, **generate_kwargs)

    torch.cuda.empty_cache()


def check_engine(rank, world_size, port):
    disable_existing_loggers()
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    run()


@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5")
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_engine():
    spawn(check_engine, TP_SIZE)


if __name__ == '__main__':
    test_engine()
tests/test_infer/test_kvcache_manager.py (new file, 61 lines)

import os

from packaging import version
import pytest
import torch

from colossalai.inference.tensor_parallel import MemoryManager
from colossalai.logging import disable_existing_loggers
from colossalai.testing import rerun_if_address_is_in_use, spawn

BATCH_SIZE = 4
INPUT_LEN = 16
OUTPUT_LEN = 8
LAYER_NUM = 4
HEAD_NUM = 32
HEAD_DIM = 128

CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.5')


def create_cache_manager(rank, world_size, port, batch_size, input_len, output_len, layer_num, head_num, head_dim):
    os.environ['RANK'] = str(rank)
    os.environ['LOCAL_RANK'] = str(rank)
    os.environ['WORLD_SIZE'] = str(world_size)
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = str(port)
    disable_existing_loggers()

    size = batch_size * (input_len + output_len)
    kvcache_manager = MemoryManager(size, torch.float16, head_num // world_size, head_dim, layer_num, rank)
    key_buffers = kvcache_manager.key_buffer
    value_buffers = kvcache_manager.value_buffer
    assert len(key_buffers) == len(value_buffers) == layer_num
    assert key_buffers[0].shape == value_buffers[0].shape
    # required size exceeds the maximum allocated size
    invalid_locs = kvcache_manager.alloc_contiguous(size + 1)
    assert invalid_locs is None
    # for prefill stage, allocation via alloc and alloc_contiguous should be the same
    total_token_prefill = batch_size * input_len
    prefill_locs = kvcache_manager.alloc(total_token_prefill)
    kvcache_manager.free_all()
    prefill_locs_contiguous = kvcache_manager.alloc_contiguous(total_token_prefill)[0]
    assert torch.equal(prefill_locs, prefill_locs_contiguous)
    assert torch.sum(kvcache_manager.mem_state).item() == size - total_token_prefill
    kvcache_manager.alloc_contiguous(batch_size)
    assert torch.all(kvcache_manager.mem_state[:total_token_prefill + batch_size] == False)


@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5")
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_cache_manager_dist():
    spawn(create_cache_manager,
          4,
          batch_size=BATCH_SIZE,
          input_len=INPUT_LEN,
          output_len=OUTPUT_LEN,
          layer_num=LAYER_NUM,
          head_num=HEAD_NUM,
          head_dim=HEAD_DIM)


if __name__ == '__main__':
    test_cache_manager_dist()
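
As the test above shows, MemoryManager hands out token slots from preallocated per-layer key/value buffers; the condensed sketch below retraces that allocation lifecycle (constructor arguments are positional as in the test, and a single-process rank-0 setup is assumed).

import torch

from colossalai.inference.tensor_parallel import MemoryManager

# positional args mirror the test above: capacity, dtype, heads per TP rank, head dim, layer count, rank
size = 4 * (16 + 8)    # 4 sequences, 16 prompt + 8 generated tokens each
cache = MemoryManager(size, torch.float16, 8, 64, 4, 0)

# asking for more than the total capacity returns None instead of raising
assert cache.alloc_contiguous(size + 1) is None

# prefill: one contiguous span for the whole prompt batch; index 0 of the result holds the slot indices
prefill_locs = cache.alloc_contiguous(4 * 16)[0]

# decode: one extra slot per sequence per generation step; contiguity is not required here
step_locs = cache.alloc(4)

# when the batch finishes, return every slot to the pool
cache.free_all()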
tests/test_infer/test_llama_infer.py (new file, 84 lines)

import os
import warnings

import pytest
import torch
from packaging import version

import colossalai
from colossalai.inference.tensor_parallel.engine import TPInferEngine
from colossalai.logging import disable_existing_loggers
from colossalai.shardformer import ShardConfig
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo

os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
TPSIZE = 2
BATCH_SIZE = 8
MAX_INPUT_LEN = 12
MAX_OUTPUT_LEN = 100

CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.5')


def init_to_get_rotary(self, base=10000):
    self.config.head_dim_ = self.config.hidden_size // self.config.num_attention_heads
    if not hasattr(self.config, "rope_scaling"):
        rope_scaling_factor = 1.0
    else:
        rope_scaling_factor = self.config.rope_scaling.factor if self.config.rope_scaling is not None else 1.0
    if hasattr(self.config, "max_sequence_length"):
        max_seq_len = self.config.max_sequence_length
    elif hasattr(self.config, "max_position_embeddings"):
        max_seq_len = self.config.max_position_embeddings * rope_scaling_factor
    else:
        max_seq_len = 2048 * rope_scaling_factor
    base = float(base)
    inv_freq = 1.0 / (base**(torch.arange(0, self.config.head_dim_, 2, device="cpu", dtype=torch.float32) /
                             self.config.head_dim_))
    t = torch.arange(max_seq_len + 1024 * 64, device="cpu", dtype=torch.float32) / rope_scaling_factor
    freqs = torch.outer(t, inv_freq)

    self._cos_cached = torch.cos(freqs).to(torch.float16).cuda()
    self._sin_cached = torch.sin(freqs).to(torch.float16).cuda()
    return


@parameterize('test_config', [{
    'tp_size': TPSIZE,
}])
def run_llama_test(test_config):

    sub_model_zoo = model_zoo.get_sub_registry('transformers_llama_for_casual_lm')
    for name, (model_fn, data_gen_fn, _, _, _) in sub_model_zoo.items():
        orig_model = model_fn()
        init_to_get_rotary(orig_model.model, base=10000)
        orig_model = orig_model.half()
        data = data_gen_fn()

        shard_config = ShardConfig(enable_tensor_parallelism=True if test_config['tp_size'] > 1 else False,
                                   inference_only=True)
        infer_engine = TPInferEngine(orig_model, shard_config, BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)

        generate_kwargs = dict(do_sample=False)
        outputs = infer_engine.generate(data, **generate_kwargs)

        assert outputs is not None


def check_llama(rank, world_size, port):
    disable_existing_loggers()
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    run_llama_test()


@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5")
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_llama():
    spawn(check_llama, TPSIZE)


if __name__ == "__main__":
    test_llama()
tests/test_infer_ops/cuda/test_vllm_rmsnorm.py (new file, 60 lines)

#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import os
import pytest
import numpy as np
from packaging import version

import torch
from torch import nn
from torch.nn import functional as F

try:
    from vllm import layernorm_ops
    rms_norm = layernorm_ops.rms_norm
    HAS_VLLM_KERNERL = True
except:
    print("please install vllm kernels to use rmsnorm")
    print("install vllm from https://github.com/vllm-project/vllm to accelerate your inference")
    HAS_VLLM_KERNERL = False


class LlamaRMSNorm(nn.Module):

    def __init__(self, hidden_size, eps=1e-6):
        """
        LlamaRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)


def cuda_rmsnorm_forward(hidden_states, weight, variance_epsilon):
    x = hidden_states
    out = torch.empty_like(x)
    rms_norm(
        out,
        x,
        weight,
        variance_epsilon,
    )
    return out


@pytest.mark.skipif(not HAS_VLLM_KERNERL, reason="You need to install llama supported cuda kernels to run this test")
def test_rmsnorm():
    data = torch.randn((1024, 64), dtype=torch.float16, device="cuda")
    hg_rms = LlamaRMSNorm(64)
    hg_rms = hg_rms.half().cuda()
    out_torch = hg_rms(data)
    out_cuda = cuda_rmsnorm_forward(data, hg_rms.weight.data, hg_rms.variance_epsilon)

    check = torch.allclose(out_torch.cpu(), out_cuda.cpu(), rtol=1e-3, atol=1e-5)
    assert check is True, "cuda rmsnorm forward is not matched with torch rmsnorm forward"


if __name__ == "__main__":
    test_rmsnorm()
tests/test_infer_ops/cuda/test_vllm_rotary_embedding.py (new file, 156 lines)

#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import pytest
from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, rotate_half

try:
    from vllm import pos_encoding_ops
    rotary_embedding_neox = pos_encoding_ops.rotary_embedding_neox
    HAS_VLLM_KERNERL = True
except:
    print("fall back to original rotary_embedding_neox of huggingface")
    print("install vllm from https://github.com/vllm-project/vllm to accelerate your inference")
    HAS_VLLM_KERNERL = False


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    x1 = x[..., :x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2:]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(
    q: torch.Tensor,
    k: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


class RefRotaryEmbeddingNeox(nn.Module):
    """Reference implementation of the GPT-NeoX style rotary embedding."""

    def __init__(
        self,
        dim: int,
        max_position_embeddings: int = 2048,
        base: int = 10000,
    ) -> None:
        super().__init__()
        self.rotary_dim = dim
        self.max_position_embeddings = max_position_embeddings

        # Create cos and sin embeddings.
        inv_freq = 1.0 / (base**(torch.arange(0, dim, 2) / dim))
        t = torch.arange(max_position_embeddings).float()
        freqs = torch.einsum("i,j->ij", t, inv_freq.float())
        emb = torch.cat((freqs, freqs), dim=-1)
        cos = emb.cos().to(dtype=inv_freq.dtype)
        sin = emb.sin().to(dtype=inv_freq.dtype)
        self.register_buffer("cos_cached", cos, persistent=False)
        self.register_buffer("sin_cached", sin, persistent=False)

    def forward(
        self,
        positions: torch.Tensor,    # [num_tokens]
        query: torch.Tensor,    # [num_tokens, num_heads, head_size]
        key: torch.Tensor,    # [num_tokens, num_heads, head_size]
    ) -> Tuple[torch.Tensor, torch.Tensor]:

        query_rot = query[..., :self.rotary_dim]
        query_pass = query[..., self.rotary_dim:]
        key_rot = key[..., :self.rotary_dim]
        key_pass = key[..., self.rotary_dim:]

        query_rot = query_rot.transpose(0, 1)
        key_rot = key_rot.transpose(0, 1)
        cos = F.embedding(positions, self.cos_cached)
        sin = F.embedding(positions, self.sin_cached)
        query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin)
        query_rot = query_rot.transpose(0, 1).contiguous()
        key_rot = key_rot.transpose(0, 1).contiguous()

        query = torch.cat((query_rot, query_pass), dim=-1)
        key = torch.cat((key_rot, key_pass), dim=-1)

        # Output query/key shape: [num_tokens, num_heads, head_size]
        return query, key


def run_rotary_embedding_neox(
    num_tokens: int,
    num_heads: int,
    head_size: int,
    max_position: int,
    rotary_dim: int,
    dtype: torch.dtype,
    base: int = 10000,
) -> None:
    positions = torch.randint(0, max_position, (num_tokens,), device='cuda')
    query = torch.randn(num_tokens,
                        num_heads * head_size,
                        dtype=dtype,
                        device='cuda')
    key = torch.randn(num_tokens,
                      num_heads * head_size,
                      dtype=dtype,
                      device='cuda')

    # Create the rotary embedding.
    inv_freq = 1.0 / (base**(torch.arange(0, rotary_dim, 2) / rotary_dim))
    t = torch.arange(max_position).float()
    freqs = torch.einsum('i,j -> ij', t, inv_freq.float())
    cos = freqs.cos()
    sin = freqs.sin()
    cos_sin_cache = torch.cat((cos, sin), dim=-1)
    cos_sin_cache = cos_sin_cache.to(dtype=dtype, device='cuda')

    # Run the kernel. The kernel is in-place, so we need to clone the inputs.
    out_query = query.clone()
    out_key = key.clone()
    rotary_embedding_neox(
        positions,
        out_query,
        out_key,
        head_size,
        cos_sin_cache,
    )

    # Run the reference implementation.
    ref_rotary_embedding = RefRotaryEmbeddingNeox(
        dim=rotary_dim,
        max_position_embeddings=max_position,
        base=base,
    ).to(dtype=dtype, device='cuda')
    ref_query, ref_key = ref_rotary_embedding(
        positions,
        query.view(num_tokens, num_heads, head_size),
        key.view(num_tokens, num_heads, head_size),
    )
    ref_query = ref_query.view(num_tokens, num_heads * head_size)
    ref_key = ref_key.view(num_tokens, num_heads * head_size)

    # Compare the results.
    assert torch.allclose(out_query, ref_query, atol=1e-3, rtol=1e-5)
    assert torch.allclose(out_key, ref_key, atol=1e-3, rtol=1e-5)


@pytest.mark.skipif(not HAS_VLLM_KERNERL, reason="You need to install llama supported cuda kernels to run this test")
def test_rotary_embedding():
    run_rotary_embedding_neox(
        num_tokens=1024,
        num_heads=8,
        head_size=64,
        max_position=8192,
        rotary_dim=64,
        dtype=torch.float16,
    )


if __name__ == "__main__":
    test_rotary_embedding()
tests/test_infer_ops/triton/kernel_utils.py (new file, 28 lines)

import math

import numpy as np
import torch
from torch.nn import functional as F


def torch_context_attention(xq, xk, xv, bs, seqlen, num_head, head_dim):
    '''
    adapted from https://github.com/ModelTC/lightllm/blob/main/lightllm/models/bloom/triton_kernel/context_flashattention_nopad.py#L253
    '''
    xq = xq.view(bs, seqlen, num_head, head_dim)
    xk = xk.view(bs, seqlen, num_head, head_dim)
    xv = xv.view(bs, seqlen, num_head, head_dim)
    mask = torch.tril(torch.ones(seqlen, seqlen), diagonal=0).unsqueeze(0).unsqueeze(0).cuda()
    mask[mask == 0.] = -100000000.0
    mask = mask.repeat(bs, num_head, 1, 1)
    keys = xk
    values = xv
    xq = xq.transpose(1, 2)
    keys = keys.transpose(1, 2)
    values = values.transpose(1, 2)
    sm_scale = 1 / math.sqrt(head_dim)
    scores = torch.matmul(xq, keys.transpose(2, 3)) * sm_scale
    scores = F.softmax(scores.float() + mask, dim=-1).to(dtype=torch.float16)

    output = torch.matmul(scores, values).transpose(1, 2).contiguous().reshape(-1, num_head, head_dim)
    return output
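
Every Triton attention test below compares against this reference, which takes packed (batch * seq_len, head_num, head_dim) tensors and returns the same packed layout; a tiny shape check with illustrative sizes (CUDA is required because the reference builds its causal mask on the GPU):

import torch
from tests.test_infer_ops.triton.kernel_utils import torch_context_attention

bs, seqlen, heads, dim = 2, 4, 3, 8    # tiny illustrative sizes
q = torch.randn((bs * seqlen, heads, dim), dtype=torch.float16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)
out = torch_context_attention(q, k, v, bs, seqlen, heads, dim)
assert out.shape == (bs * seqlen, heads, dim)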
tests/test_infer_ops/triton/test_bloom_context_attention.py (new file, 54 lines)

import math

import pytest
import torch
from packaging import version
from torch import nn
from torch.nn import functional as F

try:
    import triton
    import triton.language as tl

    from colossalai.kernel.triton import bloom_context_attn_fwd
    from tests.test_infer_ops.triton.kernel_utils import torch_context_attention
    HAS_TRITON = True
except ImportError:
    HAS_TRITON = False
    print("please install triton from https://github.com/openai/triton")

TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4')


@pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON,
                    reason="triton requires cuda version to be higher than 11.4")
def test_bloom_context_attention():
    bs = 4
    head_num = 8
    seq_len = 1024
    head_dim = 64

    query = torch.randn((bs * seq_len, head_num, head_dim), dtype=torch.float16, device="cuda")
    k = torch.randn((bs * seq_len, head_num, head_dim), dtype=torch.float16, device="cuda")
    v = torch.randn((bs * seq_len, head_num, head_dim), dtype=torch.float16, device="cuda")

    max_input_len = seq_len
    b_start = torch.zeros((bs,), device="cuda", dtype=torch.int32)
    b_len = torch.zeros((bs,), device="cuda", dtype=torch.int32)

    for i in range(bs):
        b_start[i] = i * seq_len
        b_len[i] = seq_len

    o = torch.randn((bs * seq_len, head_num, head_dim), dtype=torch.float16, device="cuda")
    alibi = torch.zeros((head_num,), dtype=torch.float32, device="cuda")
    bloom_context_attn_fwd(query.clone(), k.clone(), v.clone(), o, b_start, b_len, max_input_len, alibi)

    torch_out = torch_context_attention(query.clone(), k.clone(), v.clone(), bs, seq_len, head_num, head_dim)

    assert torch.allclose(torch_out.cpu(), o.cpu(), rtol=1e-3,
                          atol=1e-2), "outputs from triton and torch are not matched"


if __name__ == "__main__":
    test_bloom_context_attention()
tests/test_infer_ops/triton/test_copy_kv_dest.py (new file, 39 lines)

import pytest
import torch
from packaging import version
from torch import nn

try:
    import triton
    import triton.language as tl

    from colossalai.kernel.triton.copy_kv_cache_dest import copy_kv_cache_to_dest
    HAS_TRITON = True
except ImportError:
    HAS_TRITON = False
    print("please install triton from https://github.com/openai/triton")

TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4')


@pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON,
                    reason="triton requires cuda version to be higher than 11.4")
def test_kv_cache_copy_op():

    B_NTX = 32 * 2048
    head_num = 8
    head_dim = 64

    cache = torch.randn((B_NTX, head_num, head_dim), device="cuda", dtype=torch.float16)
    dest_index = torch.arange(0, B_NTX, device="cuda", dtype=torch.int32)

    dest_data = torch.ones((B_NTX, head_num, head_dim), device="cuda", dtype=torch.float16)

    copy_kv_cache_to_dest(cache, dest_index, dest_data)

    assert torch.allclose(cache.cpu(), dest_data.cpu(), rtol=1e-3,
                          atol=1e-3), "copy_kv_cache_to_dest outputs from triton and torch are not matched"


if __name__ == "__main__":
    test_kv_cache_copy_op()
tests/test_infer_ops/triton/test_layernorm_triton.py (new file, 44 lines)

import pytest
import torch
from packaging import version

from colossalai.kernel.triton import layer_norm
from colossalai.testing.utils import parameterize

try:
    import triton
    import triton.language as tl

    from colossalai.kernel.triton.fused_layernorm import _layer_norm_fwd_fused
    HAS_TRITON = True
except ImportError:
    HAS_TRITON = False
    print("please install triton from https://github.com/openai/triton")

TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4')


@pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON,
                    reason="triton requires cuda version to be higher than 11.4")
@parameterize('M', [2, 4, 8, 16])
@parameterize('N', [64, 128])
def test_layer_norm(M, N):
    dtype = torch.float16
    eps = 1e-5
    x_shape = (M, N)
    w_shape = (x_shape[-1],)
    weight = torch.rand(w_shape, dtype=dtype, device='cuda')
    bias = torch.rand(w_shape, dtype=dtype, device='cuda')
    x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device='cuda')

    y_triton = layer_norm(x, weight, bias, eps)
    y_torch = torch.nn.functional.layer_norm(x, w_shape, weight, bias, eps).to(dtype)

    assert y_triton.shape == y_torch.shape
    assert y_triton.dtype == y_torch.dtype
    print("max delta: ", torch.max(torch.abs(y_triton - y_torch)))
    assert torch.allclose(y_triton, y_torch, atol=1e-2, rtol=0)


if __name__ == "__main__":
    test_layer_norm()
tests/test_infer_ops/triton/test_llama_context_attention.py (new file, 53 lines)

import math

import pytest
import torch
from packaging import version
from torch import nn
from torch.nn import functional as F

try:
    import triton
    import triton.language as tl

    from colossalai.kernel.triton import llama_context_attn_fwd
    from tests.test_infer_ops.triton.kernel_utils import torch_context_attention
    HAS_TRITON = True
except ImportError:
    HAS_TRITON = False
    print("please install triton from https://github.com/openai/triton")

TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4')


@pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON,
                    reason="triton requires cuda version to be higher than 11.4")
def test_llama_context_attention():
    bs = 4
    head_num = 8
    seq_len = 1024
    head_dim = 64

    query = torch.randn((bs * seq_len, head_num, head_dim), dtype=torch.float16, device="cuda")
    k = torch.randn((bs * seq_len, head_num, head_dim), dtype=torch.float16, device="cuda")
    v = torch.randn((bs * seq_len, head_num, head_dim), dtype=torch.float16, device="cuda")

    max_input_len = seq_len
    b_start = torch.zeros((bs,), device="cuda", dtype=torch.int32)
    b_len = torch.zeros((bs,), device="cuda", dtype=torch.int32)

    for i in range(bs):
        b_start[i] = i * seq_len
        b_len[i] = seq_len

    o = torch.randn((bs * seq_len, head_num, head_dim), dtype=torch.float16, device="cuda")
    llama_context_attn_fwd(query.clone(), k.clone(), v.clone(), o, b_start, b_len, max_input_len)

    torch_out = torch_context_attention(query.clone(), k.clone(), v.clone(), bs, seq_len, head_num, head_dim)

    assert torch.allclose(torch_out.cpu(), o.cpu(), rtol=1e-3,
                          atol=1e-3), "outputs from triton and torch are not matched"


if __name__ == "__main__":
    test_llama_context_attention()
tests/test_infer_ops/triton/test_rotary_embedding.py (new file, 56 lines)

# Adapted from ModelTC https://github.com/ModelTC/lightllm

import time

import pytest
import torch
from packaging import version

try:
    import triton
    import triton.language as tl

    from colossalai.kernel.triton.rotary_embedding_kernel import rotary_embedding_fwd

    HAS_TRITON = True
except ImportError:
    HAS_TRITON = False
    print("please install triton from https://github.com/openai/triton")

TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4')


def torch_rotary_emb(x, cos, sin):
    seq_len, h, dim = x.shape
    x0 = x[:, :, 0:dim // 2]
    x1 = x[:, :, dim // 2:dim]
    cos = cos.view((seq_len, 1, dim // 2))
    sin = sin.view((seq_len, 1, dim // 2))
    o0 = x0 * cos - x1 * sin
    o1 = x0 * sin + x1 * cos
    return torch.cat((o0, o1), dim=-1)


@pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON,
                    reason="triton requires cuda version to be higher than 11.4")
def test_rotary_emb():
    SEQ_LEN = 1
    HEAD_NUM = 32
    HEAD_DIM = 128
    dtype = torch.half
    # create data
    x_shape = (SEQ_LEN, HEAD_NUM, HEAD_DIM)
    x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device='cuda')
    cos_shape = (SEQ_LEN, HEAD_DIM // 2)
    cos = -1.2 + 0.5 * torch.randn(cos_shape, dtype=dtype, device='cuda')
    sin = -2.0 + 0.5 * torch.randn(cos_shape, dtype=dtype, device='cuda')
    # forward pass
    y_torch = torch_rotary_emb(x, cos, sin)
    rotary_embedding_fwd(x, cos, sin)
    y_triton = x
    # compare
    assert torch.allclose(y_torch, y_triton, atol=1e-2, rtol=0)


if __name__ == "__main__":
    test_rotary_emb()
@@ -4,12 +4,11 @@ import torch
 from torch import nn
 import torch.nn.functional as F
 
-from colossalai.kernel.triton.ops import self_attention_compute_using_triton
-from colossalai.kernel.triton.qkv_matmul_kernel import qkv_gemm_4d_kernel
-
 try:
     import triton
     import triton.language as tl
+    from colossalai.kernel.triton.self_attention_nofusion import self_attention_compute_using_triton
+    from colossalai.kernel.triton.qkv_matmul_kernel import qkv_gemm_4d_kernel
     HAS_TRITON = True
 except ImportError:
     HAS_TRITON = False
@@ -17,7 +16,7 @@ except ImportError:
 
 TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4')
 
-@pytest.mark.skipif(not TRITON_CUDA_SUPPORT, reason="triton requires cuda version to be higher than 11.4")
+@pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4")
 def test_qkv_matmul():
     qkv = torch.randn((4, 24, 64*3), device="cuda", dtype=torch.float16)
     scale = 1.2
@@ -106,7 +105,7 @@ def self_attention_compute_using_torch(qkv,
 
     return res.view(batches, -1, d_model), score_output, softmax_output
 
-@pytest.mark.skipif(not TRITON_CUDA_SUPPORT, reason="triton requires cuda version to be higher than 11.4")
+@pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4")
 def test_self_atttention_test():
 
     qkv = torch.randn((4, 24, 64*3), device="cuda", dtype=torch.float16)
@@ -3,11 +3,19 @@ from packaging import version
 import torch
 from torch import nn
 
-from colossalai.kernel.triton.ops import softmax
 
+try:
+    import triton
+    import triton.language as tl
+    from colossalai.kernel.triton.softmax import softmax
+    HAS_TRITON = True
+except ImportError:
+    HAS_TRITON = False
+    print("please install triton from https://github.com/openai/triton")
+
 TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4')
 
-@pytest.mark.skipif(not TRITON_CUDA_SUPPORT, reason="triton requires cuda version to be higher than 11.4")
+@pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4")
 def test_softmax_op():
     data_samples = [
         torch.randn((3, 4, 5, 32), device = "cuda", dtype = torch.float32),
tests/test_infer_ops/triton/test_token_attn_1.py (new file, 72 lines)

import math

import pytest
import torch
from packaging import version

try:
    import triton
    import triton.language as tl

    from colossalai.kernel.triton.token_attention_kernel import token_attn_fwd_1
    HAS_TRITON = True
except ImportError:
    HAS_TRITON = False
    print("please install triton from https://github.com/openai/triton")

TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4')


def torch_attn(xq, xk, bs, seqlen, num_head, head_dim):
    xq = xq.view(bs, 1, num_head, head_dim)
    xk = xk.view(bs, seqlen, num_head, head_dim)
    keys = xk
    xq = xq.transpose(1, 2)
    keys = keys.transpose(1, 2)
    scores = (torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(head_dim)).squeeze().transpose(0, 1).reshape(
        num_head, -1)
    return scores


def torch_attn_1(xq, xk, seqlen, num_head, head_dim):
    xq = xq.view(1, num_head, head_dim)
    xk = xk.view(seqlen, num_head, head_dim)
    logics = torch.sum(xq * xk, dim=-1, keepdim=False)

    logics = logics.transpose(0, 1) / math.sqrt(head_dim)
    return logics


@pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON,
                    reason="triton requires cuda version to be higher than 11.4")
def test_attn_1():
    import time

    batch_size, seq_len, head_num, head_dim = 17, 1025, 12, 128

    dtype = torch.float16

    q = torch.empty((batch_size, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2)
    k = torch.empty((batch_size * seq_len, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2)
    attn_out = torch.empty((head_num, batch_size * seq_len), dtype=dtype, device="cuda")

    b_loc = torch.zeros((batch_size, seq_len), dtype=torch.int32, device="cuda")
    kv_cache_start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
    kv_cache_seq_len = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")

    for i in range(batch_size):
        kv_cache_start_loc[i] = i * seq_len
        kv_cache_seq_len[i] = seq_len
        b_loc[i] = i * seq_len + torch.arange(0, seq_len, dtype=torch.int32, device="cuda")

    token_attn_fwd_1(q, k, attn_out, b_loc, kv_cache_start_loc, kv_cache_seq_len, seq_len)

    torch_out = torch_attn(q, k, batch_size, seq_len, head_num, head_dim).squeeze()
    o = attn_out.squeeze()
    print("max ", torch.max(torch.abs(torch_out - o)))
    print("mean ", torch.mean(torch.abs(torch_out - o)))
    assert torch.allclose(torch_out, o, atol=1e-2, rtol=0)


if __name__ == "__main__":
    test_attn_1()
tests/test_infer_ops/triton/test_token_attn_2.py (new file, 61 lines)

import math

import pytest
import torch
from packaging import version

try:
    import triton
    import triton.language as tl

    from colossalai.kernel.triton.token_attention_kernel import token_attn_fwd_2
    HAS_TRITON = True
except ImportError:
    HAS_TRITON = False
    print("please install triton from https://github.com/openai/triton")

TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4')


def torch_attn(V, P, bs, seqlen, num_head, head_dim):
    V = V.view(bs, seqlen, num_head, head_dim).transpose(1, 2)
    P = P.reshape(num_head, bs, 1, seqlen).transpose(0, 1)
    attn_out = torch.matmul(P, V)

    return attn_out


@pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON,
                    reason="triton requires cuda version to be higher than 11.4")
def test_token_attn_2():
    import time

    batch_size, seq_len, head_num, head_dim = 17, 1025, 12, 128
    dtype = torch.float16

    V = torch.empty((batch_size * seq_len, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.1, std=10)
    Prob = torch.empty(
        (head_num, batch_size * seq_len), dtype=dtype,
        device="cuda").normal_(mean=0.4, std=0.2).reshape(head_num, batch_size,
                                                          seq_len).softmax(-1).reshape(head_num, batch_size * seq_len)
    attn_out = torch.empty((batch_size, head_num, head_dim), dtype=dtype, device="cuda")

    kv_cache_start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
    kv_cache_seq_len = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
    kv_cache_loc = torch.zeros((batch_size, seq_len), dtype=torch.int32, device="cuda")
    for i in range(batch_size):
        kv_cache_start_loc[i] = i * seq_len
        kv_cache_seq_len[i] = seq_len
        kv_cache_loc[i] = i * seq_len + torch.arange(0, seq_len, dtype=torch.int32, device="cuda")

    token_attn_fwd_2(Prob, V, attn_out, kv_cache_loc, kv_cache_start_loc, kv_cache_seq_len, seq_len)

    torch_out = torch_attn(V, Prob, batch_size, seq_len, head_num, head_dim).squeeze()
    o = attn_out
    print("max ", torch.max(torch.abs(torch_out - o)))
    print("mean ", torch.mean(torch.abs(torch_out - o)))
    assert torch.allclose(torch_out, o, atol=1e-2, rtol=0)


if __name__ == "__main__":
    test_token_attn_2()
tests/test_infer_ops/triton/test_token_attn_fwd.py (new file, 67 lines)

import time

import pytest
import torch
from packaging import version

try:
    import triton
    import triton.language as tl

    from colossalai.kernel.triton.token_attention_kernel import token_attention_fwd
    HAS_TRITON = True
except ImportError:
    HAS_TRITON = False
    print("please install triton from https://github.com/openai/triton")

TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4')


def torch_att(xq, xk, xv, bs, seqlen, num_head, head_dim):
    xq = xq.view(bs, 1, num_head, head_dim)
    xk = xk.view(bs, seqlen, num_head, head_dim)
    xv = xv.view(bs, seqlen, num_head, head_dim)

    logics = torch.sum(xq * xk, dim=3, keepdim=False) * 1 / (head_dim**0.5)
    prob = torch.softmax(logics, dim=1)
    prob = prob.view(bs, seqlen, num_head, 1)

    return torch.sum(prob * xv, dim=1, keepdim=False)


@pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON,
                    reason="triton requires cuda version to be higher than 11.4")
def test():

    Z, head_num, seq_len, head_dim = 22, 112 // 8, 2048, 128
    dtype = torch.float16
    q = torch.empty((Z, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2)
    k = torch.empty((Z * seq_len, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.4, std=0.2)
    v = torch.empty((Z * seq_len, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.3, std=0.2)
    o = torch.empty((Z, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.3, std=0.2)
    alibi = torch.zeros((head_num,), dtype=torch.float32, device="cuda")

    max_kv_cache_len = seq_len
    kv_cache_start_loc = torch.zeros((Z,), dtype=torch.int32, device="cuda")
    kv_cache_loc = torch.zeros((Z, seq_len), dtype=torch.int32, device="cuda")
    kv_cache_seq_len = torch.ones((Z,), dtype=torch.int32, device="cuda")

    kv_cache_seq_len[:] = seq_len
    kv_cache_start_loc[0] = 0
    kv_cache_start_loc[1] = seq_len
    kv_cache_start_loc[2] = 2 * seq_len
    kv_cache_start_loc[3] = 3 * seq_len

    for i in range(Z):
        kv_cache_loc[i, :] = torch.arange(i * seq_len, (i + 1) * seq_len, dtype=torch.int32, device="cuda")

    token_attention_fwd(q, k, v, o, kv_cache_loc, kv_cache_start_loc, kv_cache_seq_len, max_kv_cache_len, alibi=alibi)
    torch_out = torch_att(q, k, v, Z, seq_len, head_num, head_dim)

    print("max ", torch.max(torch.abs(torch_out - o)))
    print("mean ", torch.mean(torch.abs(torch_out - o)))
    assert torch.allclose(torch_out, o, atol=1e-2, rtol=0)


if __name__ == "__main__":
    test()
tests/test_infer_ops/triton/test_token_softmax.py (new file, 48 lines)

import pytest
import torch
from packaging import version

try:
    import triton
    import triton.language as tl

    from colossalai.kernel.triton.token_attention_kernel import token_attn_softmax_fwd
    HAS_TRITON = True
except ImportError:
    HAS_TRITON = False
    print("please install triton from https://github.com/openai/triton")

TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4')


@pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON,
                    reason="triton requires cuda version to be higher than 11.4")
def test_softmax():

    import torch

    batch_size, seq_len, head_num, head_dim = 4, 1025, 12, 128

    dtype = torch.float16

    Logics = torch.empty((head_num, batch_size * seq_len), dtype=dtype, device="cuda").normal_(mean=0.1, std=10)
    ProbOut = torch.empty((head_num, batch_size * seq_len), dtype=dtype, device="cuda").normal_(mean=0.4, std=0.2)

    kv_cache_start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
    kv_cache_seq_len = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")

    for i in range(batch_size):
        kv_cache_start_loc[i] = i * seq_len
        kv_cache_seq_len[i] = seq_len

    token_attn_softmax_fwd(Logics, kv_cache_start_loc, kv_cache_seq_len, ProbOut, seq_len)

    torch_out = Logics.reshape(head_num * batch_size, -1).softmax(-1).reshape(head_num, batch_size * seq_len)
    o = ProbOut
    print("max ", torch.max(torch.abs(torch_out - o)))
    print("mean ", torch.mean(torch.abs(torch_out - o)))
    assert torch.allclose(torch_out, o, atol=1e-2, rtol=0)


if __name__ == "__main__":
    test_softmax()