[Feature] The first PR to Add TP inference engine, kv-cache manager and related kernels for our inference system (#4577)

* [infer] Infer/llama demo (#4503)

* add

* add infer example

* finish

* finish

* stash

* fix

* [Kernels]  add inference token attention kernel (#4505)

* add token forward

* fix tests

* fix comments

* add try import triton

* add adapted license

* add tests check

* [Kernels] add necessary kernels (llama & bloom) for attention forward and kv-cache manager  (#4485)

* added _vllm_rms_norm

* change place

* added tests

* added tests

* modify

* adding kernels

* added tests:

* adding kernels

* modify

* added

* updating kernels

* adding tests

* added tests

* kernel change

* submit

* modify

* added

* edit comments

* change name

* change commnets and fix import

* add

* added

* combine codes (#4509)

* [feature] add KV cache manager for llama & bloom inference (#4495)

* add kv cache memory manager

* add stateinfo during inference

* format

* format

* rename file

* add kv cache test

* revise on BatchInferState

* file dir change

* [Bug FIx] import llama context ops fix (#4524)

* added _vllm_rms_norm

* change place

* added tests

* added tests

* modify

* adding kernels

* added tests:

* adding kernels

* modify

* added

* updating kernels

* adding tests

* added tests

* kernel change

* submit

* modify

* added

* edit comments

* change name

* change commnets and fix import

* add

* added

* fix

* add ops into init.py

* add

* [Infer] Add TPInferEngine and fix file path (#4532)

* add engine for TP inference

* move file path

* update path

* fix TPInferEngine

* remove unused file

* add engine test demo

* revise TPInferEngine

* fix TPInferEngine, add test

* fix

* Add Inference test for llama (#4508)

* add kv cache memory manager

* add stateinfo during inference

* add

* add infer example

* finish

* finish

* format

* format

* rename file

* add kv cache test

* revise on BatchInferState

* add inference test for llama

* fix conflict

* feature: add some new features for llama engine

* adapt colossalai triton interface

* Change the parent class of llama  policy

* add nvtx

* move llama inference code to tensor_parallel

* fix __init__.py

* rm tensor_parallel

* fix: fix bugs in auto_policy.py

* fix:rm some unused codes

* mv colossalai/tpinference to colossalai/inference/tensor_parallel

* change __init__.py

* save change

* fix engine

* Bug fix: Fix hang

* remove llama_infer_engine.py

---------

Co-authored-by: yuanheng-zhao <jonathan.zhaoyh@gmail.com>
Co-authored-by: CjhHa1 <cjh18671720497@outlook.com>

* [infer] Add Bloom inference policy and replaced methods (#4512)

* add bloom inference methods and policy

* enable pass BatchInferState from model forward

* revise bloom infer layers/policies

* add engine for inference (draft)

* add test for bloom infer

* fix bloom infer policy and flow

* revise bloom test

* fix bloom file path

* remove unused codes

* fix bloom modeling

* fix dir typo

* fix trivial

* fix policy

* clean pr

* trivial fix

* Revert "[infer] Add Bloom inference policy and replaced methods (#4512)" (#4552)

This reverts commit 17cfa57140.

* [Doc] Add colossal inference doc (#4549)

* create readme

* add readme.md

* fix typos

* [infer] Add Bloom inference policy and replaced methods (#4553)

* add bloom inference methods and policy

* enable pass BatchInferState from model forward

* revise bloom infer layers/policies

* add engine for inference (draft)

* add test for bloom infer

* fix bloom infer policy and flow

* revise bloom test

* fix bloom file path

* remove unused codes

* fix bloom modeling

* fix dir typo

* fix trivial

* fix policy

* clean pr

* trivial fix

* trivial

* Fix Bugs In Llama Model Forward (#4550)

* add kv cache memory manager

* add stateinfo during inference

* add

* add infer example

* finish

* finish

* format

* format

* rename file

* add kv cache test

* revise on BatchInferState

* add inference test for llama

* fix conflict

* feature: add some new features for llama engine

* adapt colossalai triton interface

* Change the parent class of llama  policy

* add nvtx

* move llama inference code to tensor_parallel

* fix __init__.py

* rm tensor_parallel

* fix: fix bugs in auto_policy.py

* fix:rm some unused codes

* mv colossalai/tpinference to colossalai/inference/tensor_parallel

* change __init__.py

* save change

* fix engine

* Bug fix: Fix hang

* remove llama_infer_engine.py

* bug fix: fix bugs about infer_state.is_context_stage

* remove pollcies

* fix: delete unused code

* fix: delete unused code

* remove unused coda

* fix conflict

---------

Co-authored-by: yuanheng-zhao <jonathan.zhaoyh@gmail.com>
Co-authored-by: CjhHa1 <cjh18671720497@outlook.com>

* [doc] add colossal inference fig (#4554)

* create readme

* add readme.md

* fix typos

* upload fig

* [NFC] fix docstring for colossal inference (#4555)

Fix docstring and comments in kv cache manager and bloom modeling

* fix docstring in llama modeling (#4557)

* [Infer] check import vllm (#4559)

* change import vllm

* import apply_rotary_pos_emb

* change import location

* [DOC] add installation req (#4561)

* add installation req

* fix

* slight change

* remove empty

* [Feature] rms-norm transfer into inference llama.py  (#4563)

* add installation req

* fix

* slight change

* remove empty

* add rmsnorm polciy

* add

* clean codes

* [infer] Fix tp inference engine (#4564)

* fix engine prepare data

* add engine test

* use bloom for testing

* revise on test

* revise on test

* reset shardformer llama (#4569)

* [infer] Fix engine - tensors on different devices (#4570)


* fix diff device in engine

* [codefactor] Feature/colossal inference (#4579)

* code factors

* remove

* change coding (#4581)

* [doc] complete README of colossal inference (#4585)

* complete fig

* Update README.md

* [doc]update readme (#4586)

* update readme

* Update README.md

* bug fix: fix bus in llama and bloom (#4588)

* [BUG FIX]Fix test engine in CI and non-vllm kernels llama forward  (#4592)

* fix tests

* clean

* clean

* fix bugs

* add

* fix llama non-vllm kernels bug

* modify

* clean codes

* [Kernel]Rmsnorm fix (#4598)

* fix tests

* clean

* clean

* fix bugs

* add

* fix llama non-vllm kernels bug

* modify

* clean codes

* add triton rmsnorm

* delete vllm kernel flag

* [Bug Fix]Fix bugs in llama (#4601)

* fix tests

* clean

* clean

* fix bugs

* add

* fix llama non-vllm kernels bug

* modify

* clean codes

* bug fix: remove rotary_positions_ids

---------

Co-authored-by: cuiqing.li <lixx3527@gmail.com>

* [kernel] Add triton layer norm & replace norm for bloom (#4609)

* add layernorm for inference

* add test for layernorm kernel

* add bloom layernorm replacement policy

* trivial: path

* [Infer] Bug fix rotary embedding in llama (#4608)

* fix rotary embedding

* delete print

* fix init seq len bug

* rename pytest

* add benchmark for llama

* refactor codes

* delete useless code

* [bench] Add bloom inference benchmark (#4621)

* add bloom benchmark

* readme - update benchmark res

* trivial - uncomment for testing (#4622)

* [Infer] add check triton and cuda version for tests (#4627)

* fix rotary embedding

* delete print

* fix init seq len bug

* rename pytest

* add benchmark for llama

* refactor codes

* delete useless code

* add check triton and cuda

* Update sharder.py (#4629)

* [Inference] Hot fix some bugs and typos (#4632)

* fix

* fix test

* fix conflicts

* [typo]Comments fix (#4633)

* fallback

* fix commnets

* bug fix: fix some bugs in test_llama and test_bloom (#4635)

* [Infer] delete benchmark in tests and fix bug for llama and bloom (#4636)

* fix rotary embedding

* delete print

* fix init seq len bug

* rename pytest

* add benchmark for llama

* refactor codes

* delete useless code

* add check triton and cuda

* delete benchmark and fix infer bugs

* delete benchmark for tests

* delete useless code

* delete bechmark function in utils

* [Fix] Revise TPInferEngine, inference tests and benchmarks (#4642)

* [Fix] revise TPInferEngine methods and inference tests

* fix llama/bloom infer benchmarks

* fix infer tests

* trivial fix: benchmakrs

* trivial

* trivial: rm print

* modify utils filename for infer ops test (#4657)

* [Infer] Fix TPInferEngine init & inference tests, benchmarks (#4670)

* fix engine funcs

* TPInferEngine: receive shard config in init

* benchmarks: revise TPInferEngine init

* benchmarks: remove pytest decorator

* trivial fix

* use small model for tests

* [NFC] use args for infer benchmarks (#4674)

* revise infer default (#4683)

* [Fix] optimize/shard model in TPInferEngine init (#4684)

* remove using orig model in engine

* revise inference tests

* trivial: rename

---------

Co-authored-by: Jianghai <72591262+CjhHa1@users.noreply.github.com>
Co-authored-by: Xu Kai <xukai16@foxmail.com>
Co-authored-by: Yuanheng Zhao <54058983+yuanheng-zhao@users.noreply.github.com>
Co-authored-by: yuehuayingxueluo <867460659@qq.com>
Co-authored-by: yuanheng-zhao <jonathan.zhaoyh@gmail.com>
Co-authored-by: CjhHa1 <cjh18671720497@outlook.com>
This commit is contained in:
Cuiqing Li
2023-09-12 01:22:56 +08:00
committed by GitHub
parent eedaa3e1ef
commit bce0f16702
49 changed files with 3980 additions and 137 deletions

View File

@@ -0,0 +1,53 @@
import copy
import torch
import torch.distributed as dist
from torch import Tensor
from torch import distributed as dist
from torch.distributed import ProcessGroup
from torch.nn import Module
from torch.optim import Adam, Optimizer
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
from colossalai.booster.plugin.hybrid_parallel_plugin import HybridParallelModule
from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.shardformer._utils import getattr_
from colossalai.shardformer.policies.auto_policy import Policy
from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor
def build_model(
model_fn,
enable_fused_normalization=False,
enable_tensor_parallelism=False,
enable_flash_attention=False,
enable_jit_fused=False,
):
# create new model
org_model = model_fn()
# shard model
shard_config = ShardConfig(enable_fused_normalization=enable_fused_normalization,
enable_tensor_parallelism=enable_tensor_parallelism,
enable_flash_attention=enable_flash_attention,
enable_jit_fused=enable_jit_fused,
inference_only=True)
model_copy = copy.deepcopy(org_model)
shard_former = ShardFormer(shard_config=shard_config)
sharded_model, shared_params = shard_former.optimize(model_copy)
return org_model.cuda(), sharded_model.cuda()
def run_infer(original_model, sharded_model, data_gen_fn, output_transform_fn):
# prepare input
data = data_gen_fn()
data = {k: v.cuda() for k, v in data.items()}
# run forward
org_output = original_model(**data)
org_output = output_transform_fn(org_output)
shard_output = sharded_model(**data)
shard_output = output_transform_fn(shard_output)
return org_output, shard_output

View File

@@ -0,0 +1,58 @@
import os
import pytest
import torch
from packaging import version
import colossalai
from colossalai.inference.tensor_parallel import TPInferEngine
from colossalai.logging import disable_existing_loggers
from colossalai.shardformer import ShardConfig
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
TP_SIZE = 2
MAX_BATCH_SIZE = 4
MAX_INPUT_LEN = 16
MAX_OUTPUT_LEN = 32
CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.5')
@parameterize('test_config', [{
'tp_size': TP_SIZE,
}])
def run(test_config):
sub_model_zoo = model_zoo.get_sub_registry('transformers_bloom_for_causal_lm')
for name, (model_fn, data_gen_fn, _, _, _) in sub_model_zoo.items():
orig_model = model_fn()
orig_model = orig_model.half()
data = data_gen_fn()
shard_config = ShardConfig(enable_tensor_parallelism=True if test_config['tp_size'] > 1 else False,
inference_only=True)
infer_engine = TPInferEngine(orig_model, shard_config, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
generate_kwargs = dict(do_sample=False)
outputs = infer_engine.generate(data, **generate_kwargs)
assert outputs is not None
def check_bloom(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
run()
@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5")
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_bloom_infer():
spawn(check_bloom, TP_SIZE)
if __name__ == '__main__':
test_bloom_infer()

View File

@@ -0,0 +1,94 @@
from itertools import accumulate
import pytest
import torch
import torch.nn as nn
from packaging import version
from transformers import BloomConfig, BloomForCausalLM, LlamaConfig, LlamaForCausalLM
from transformers.tokenization_utils_base import BatchEncoding
import colossalai
from colossalai.inference.tensor_parallel import TPInferEngine
from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState
from colossalai.logging import disable_existing_loggers
from colossalai.shardformer import ShardConfig
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
TP_SIZE = 2
MAX_BATCH_SIZE = 4
MAX_INPUT_LEN = 16
MAX_OUTPUT_LEN = 8
CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.5')
@parameterize('test_config', [{
'tp_size': TP_SIZE,
}])
def run(test_config):
model_config = BloomConfig(num_hidden_layers=4, hidden_size=128, intermediate_size=256, num_attention_heads=4)
model = BloomForCausalLM(model_config)
model = model.half()
model.to(torch.cuda.current_device())
# 1. check TPInferEngine init and model optimization
shard_config = ShardConfig(enable_tensor_parallelism=True if test_config['tp_size'] > 1 else False,
inference_only=True)
infer_engine = TPInferEngine(model, shard_config, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
assert infer_engine.cache_manager is not None
assert infer_engine.tp_size == TP_SIZE
assert infer_engine.head_num == model_config.num_attention_heads // TP_SIZE
# 2. check data preparation
input_ids_list = [[80540, 15473, 3331, 11970, 90472, 361, 61335], [80540, 15473, 3331, 11970],
[80540, 15473, 3331, 11970], [80540, 15473]]
batch_size = len(input_ids_list)
max_seq_len = max(len(li) for li in input_ids_list)
attention_mask = [[0] * max_seq_len for _ in range(batch_size)]
for i, li in enumerate(input_ids_list):
attention_mask[i][max_seq_len - len(li):] = [1 for _ in range(len(li))]
data = dict(input_ids=input_ids_list, attention_mask=attention_mask)
inputs_batch_encoding = BatchEncoding(data=data)
seq_lengths = [len(li) for li in input_ids_list]
start_loc = list(accumulate([0] + seq_lengths[:-1]))
seq_lengths = torch.tensor(seq_lengths, dtype=torch.int32)
start_loc = torch.tensor(start_loc, dtype=torch.int32)
# input token id list as inputs
batch_state_out1 = infer_engine.prepare_batch_state(inputs_batch_encoding)
# BatchEncoding as inputs
batch_state_out2 = infer_engine.prepare_batch_state(input_ids_list)
assert batch_state_out1.batch_size == batch_state_out2.batch_size == batch_size
assert torch.equal(batch_state_out1.seq_len, batch_state_out2.seq_len)
# The following tests are discarded for now, and will be reused after all features are added
# assert torch.equal(batch_state_out1.seq_len.to(seq_lengths.device), seq_lengths)
# assert torch.equal(batch_state_out2.seq_len.to(seq_lengths.device), seq_lengths)
# assert torch.equal(batch_state_out1.start_loc.to(start_loc.device), start_loc)
# assert torch.equal(batch_state_out2.start_loc.to(start_loc.device), start_loc)
# 3. check optimized model generate
input_ids = torch.randint(low=10, high=1000, size=(MAX_BATCH_SIZE, MAX_INPUT_LEN))
generate_kwargs = dict(do_sample=False)
infer_engine.generate(input_ids, **generate_kwargs)
torch.cuda.empty_cache()
def check_engine(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
run()
@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5")
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_engine():
spawn(check_engine, TP_SIZE)
if __name__ == '__main__':
test_engine()

View File

@@ -0,0 +1,61 @@
import os
from packaging import version
import pytest
import torch
from colossalai.inference.tensor_parallel import MemoryManager
from colossalai.logging import disable_existing_loggers
from colossalai.testing import rerun_if_address_is_in_use, spawn
BATCH_SIZE = 4
INPUT_LEN = 16
OUTPUT_LEN = 8
LAYER_NUM = 4
HEAD_NUM = 32
HEAD_DIM = 128
CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.5')
def create_cache_manager(rank, world_size, port, batch_size, input_len, output_len, layer_num, head_num, head_dim):
os.environ['RANK'] = str(rank)
os.environ['LOCAL_RANK'] = str(rank)
os.environ['WORLD_SIZE'] = str(world_size)
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = str(port)
disable_existing_loggers()
size = batch_size * (input_len + output_len)
kvcache_manager = MemoryManager(size, torch.float16, head_num // world_size, head_dim, layer_num, rank)
key_buffers = kvcache_manager.key_buffer
value_buffers = kvcache_manager.value_buffer
assert len(key_buffers) == len(value_buffers) == layer_num
assert key_buffers[0].shape == value_buffers[0].shape
# required size exceeds the maximum allocated size
invalid_locs = kvcache_manager.alloc_contiguous(size + 1)
assert invalid_locs is None
# for prefill stage, allocation via alloc and alloc_contiguous should be the same
total_token_prefill = batch_size * input_len
prefill_locs = kvcache_manager.alloc(total_token_prefill)
kvcache_manager.free_all()
prefill_locs_contiguous = kvcache_manager.alloc_contiguous(total_token_prefill)[0]
assert torch.equal(prefill_locs, prefill_locs_contiguous)
assert torch.sum(kvcache_manager.mem_state).item() == size - total_token_prefill
kvcache_manager.alloc_contiguous(batch_size)
assert torch.all(kvcache_manager.mem_state[:total_token_prefill + batch_size] == False)
@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5")
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_cache_manager_dist():
spawn(create_cache_manager,
4,
batch_size=BATCH_SIZE,
input_len=INPUT_LEN,
output_len=OUTPUT_LEN,
layer_num=LAYER_NUM,
head_num=HEAD_NUM,
head_dim=HEAD_DIM)
if __name__ == '__main__':
test_cache_manager_dist()

View File

@@ -0,0 +1,84 @@
import os
import warnings
import pytest
import torch
from packaging import version
import colossalai
from colossalai.inference.tensor_parallel.engine import TPInferEngine
from colossalai.logging import disable_existing_loggers
from colossalai.shardformer import ShardConfig
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
TPSIZE = 2
BATCH_SIZE = 8
MAX_INPUT_LEN = 12
MAX_OUTPUT_LEN = 100
CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.5')
def init_to_get_rotary(self, base=10000):
self.config.head_dim_ = self.config.hidden_size // self.config.num_attention_heads
if not hasattr(self.config, "rope_scaling"):
rope_scaling_factor = 1.0
else:
rope_scaling_factor = self.config.rope_scaling.factor if self.config.rope_scaling is not None else 1.0
if hasattr(self.config, "max_sequence_length"):
max_seq_len = self.config.max_sequence_length
elif hasattr(self.config, "max_position_embeddings"):
max_seq_len = self.config.max_position_embeddings * rope_scaling_factor
else:
max_seq_len = 2048 * rope_scaling_factor
base = float(base)
inv_freq = 1.0 / (base**(torch.arange(0, self.config.head_dim_, 2, device="cpu", dtype=torch.float32) /
self.config.head_dim_))
t = torch.arange(max_seq_len + 1024 * 64, device="cpu", dtype=torch.float32) / rope_scaling_factor
freqs = torch.outer(t, inv_freq)
self._cos_cached = torch.cos(freqs).to(torch.float16).cuda()
self._sin_cached = torch.sin(freqs).to(torch.float16).cuda()
return
@parameterize('test_config', [{
'tp_size': TPSIZE,
}])
def run_llama_test(test_config):
sub_model_zoo = model_zoo.get_sub_registry('transformers_llama_for_casual_lm')
for name, (model_fn, data_gen_fn, _, _, _) in sub_model_zoo.items():
orig_model = model_fn()
init_to_get_rotary(orig_model.model, base=10000)
orig_model = orig_model.half()
data = data_gen_fn()
shard_config = ShardConfig(enable_tensor_parallelism=True if test_config['tp_size'] > 1 else False,
inference_only=True)
infer_engine = TPInferEngine(orig_model, shard_config, BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
generate_kwargs = dict(do_sample=False)
outputs = infer_engine.generate(data, **generate_kwargs)
assert outputs is not None
def check_llama(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
run_llama_test()
@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5")
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_llama():
spawn(check_llama, TPSIZE)
if __name__ == "__main__":
test_llama()

View File

@@ -0,0 +1,60 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import os
import pytest
import numpy as np
from packaging import version
import torch
from torch import nn
from torch.nn import functional as F
try:
from vllm import layernorm_ops
rms_norm = layernorm_ops.rms_norm
HAS_VLLM_KERNERL = True
except:
print("please install vllm kernels to install rmsnorm")
print("install vllm from https://github.com/vllm-project/vllm to accelerate your inference")
HAS_VLLM_KERNERL = False
class LlamaRMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
LlamaRMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)
def cuda_rmsnorm_forward(hidden_states, weight, variance_epsilon):
x = hidden_states
out = torch.empty_like(x)
rms_norm(
out,
x,
weight,
variance_epsilon,
)
return out
@pytest.mark.skipif(not HAS_VLLM_KERNERL, reason="You need to install llama supported cuda kernels to run this test")
def test_rmsnorm():
data = torch.randn((1024, 64), dtype=torch.float16, device="cuda")
hg_rms = LlamaRMSNorm(64)
hg_rms = hg_rms.half().cuda()
out_torch = hg_rms(data)
out_cuda = cuda_rmsnorm_forward(data, hg_rms.weight.data, hg_rms.variance_epsilon)
check = torch.allclose(out_torch.cpu(), out_cuda.cpu(), rtol=1e-3, atol=1e-5)
assert check is True, "cuda rmsnorm forward is not matched with torch rmsnorm forward"
if __name__ == "__main__":
test_rmsnorm()

View File

@@ -0,0 +1,156 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import pytest
from typing import Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, rotate_half
try:
from vllm import pos_encoding_ops
rotary_embedding_neox = pos_encoding_ops.rotary_embedding_neox
HAS_VLLM_KERNERL = True
except:
print("fall back to original rotary_embedding_neox of huggingface")
print("install vllm from https://github.com/vllm-project/vllm to accelerate your inference")
HAS_VLLM_KERNERL = False
def rotate_half(x: torch.Tensor) -> torch.Tensor:
x1 = x[..., :x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2:]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(
q: torch.Tensor,
k: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
class RefRotaryEmbeddingNeox(nn.Module):
"""Reference implementation of the GPT-NeoX style rotary embedding."""
def __init__(
self,
dim: int,
max_position_embeddings: int = 2048,
base: int = 10000,
) -> None:
super().__init__()
self.rotary_dim = dim
self.max_position_embeddings = max_position_embeddings
# Create cos and sin embeddings.
inv_freq = 1.0 / (base**(torch.arange(0, dim, 2) / dim))
t = torch.arange(max_position_embeddings).float()
freqs = torch.einsum("i,j->ij", t, inv_freq.float())
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos().to(dtype=inv_freq.dtype)
sin = emb.sin().to(dtype=inv_freq.dtype)
self.register_buffer("cos_cached", cos, persistent=False)
self.register_buffer("sin_cached", sin, persistent=False)
def forward(
self,
positions: torch.Tensor, # [num_tokens]
query: torch.Tensor, # [num_tokens, num_heads, head_size]
key: torch.Tensor, # [num_tokens, num_heads, head_size]
) -> Tuple[torch.Tensor, torch.Tensor]:
query_rot = query[..., :self.rotary_dim]
query_pass = query[..., self.rotary_dim:]
key_rot = key[..., :self.rotary_dim]
key_pass = key[..., self.rotary_dim:]
query_rot = query_rot.transpose(0, 1)
key_rot = key_rot.transpose(0, 1)
cos = F.embedding(positions, self.cos_cached)
sin = F.embedding(positions, self.sin_cached)
query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin)
query_rot = query_rot.transpose(0, 1).contiguous()
key_rot = key_rot.transpose(0, 1).contiguous()
query = torch.cat((query_rot, query_pass), dim=-1)
key = torch.cat((key_rot, key_pass), dim=-1)
# Output query/key shape: [num_tokens, num_tokens, head_size]
return query, key
def run_rotary_embedding_neox(
num_tokens: int,
num_heads: int,
head_size: int,
max_position: int,
rotary_dim: int,
dtype: torch.dtype,
base: int = 10000,
) -> None:
positions = torch.randint(0, max_position, (num_tokens, ), device='cuda')
query = torch.randn(num_tokens,
num_heads * head_size,
dtype=dtype,
device='cuda')
key = torch.randn(num_tokens,
num_heads * head_size,
dtype=dtype,
device='cuda')
# Create the rotary embedding.
inv_freq = 1.0 / (base**(torch.arange(0, rotary_dim, 2) / rotary_dim))
t = torch.arange(max_position).float()
freqs = torch.einsum('i,j -> ij', t, inv_freq.float())
cos = freqs.cos()
sin = freqs.sin()
cos_sin_cache = torch.cat((cos, sin), dim=-1)
cos_sin_cache = cos_sin_cache.to(dtype=dtype, device='cuda')
# Run the kernel. The kernel is in-place, so we need to clone the inputs.
out_query = query.clone()
out_key = key.clone()
rotary_embedding_neox(
positions,
out_query,
out_key,
head_size,
cos_sin_cache,
)
# Run the reference implementation.
ref_rotary_embedding = RefRotaryEmbeddingNeox(
dim=rotary_dim,
max_position_embeddings=max_position,
base=base,
).to(dtype=dtype, device='cuda')
ref_query, ref_key = ref_rotary_embedding(
positions,
query.view(num_tokens, num_heads, head_size),
key.view(num_tokens, num_heads, head_size),
)
ref_query = ref_query.view(num_tokens, num_heads * head_size)
ref_key = ref_key.view(num_tokens, num_heads * head_size)
# Compare the results.
assert torch.allclose(out_query, ref_query, atol=1e-3, rtol=1e-5)
assert torch.allclose(out_key, ref_key, atol=1e-3, rtol=1e-5)
@pytest.mark.skipif(not HAS_VLLM_KERNERL, reason="You need to install llama supported cuda kernels to run this test")
def test_rotary_embedding():
run_rotary_embedding_neox(
num_tokens=1024,
num_heads=8,
head_size=64,
max_position=8192,
rotary_dim=64,
dtype=torch.float16,
)
if __name__ == "__main__":
test_rotary_embedding()

View File

@@ -0,0 +1,28 @@
import math
import numpy as np
import torch
from torch.nn import functional as F
def torch_context_attention(xq, xk, xv, bs, seqlen, num_head, head_dim):
'''
adepted from https://github.com/ModelTC/lightllm/blob/main/lightllm/models/bloom/triton_kernel/context_flashattention_nopad.py#L253
'''
xq = xq.view(bs, seqlen, num_head, head_dim)
xk = xk.view(bs, seqlen, num_head, head_dim)
xv = xv.view(bs, seqlen, num_head, head_dim)
mask = torch.tril(torch.ones(seqlen, seqlen), diagonal=0).unsqueeze(0).unsqueeze(0).cuda()
mask[mask == 0.] = -100000000.0
mask = mask.repeat(bs, num_head, 1, 1)
keys = xk
values = xv
xq = xq.transpose(1, 2)
keys = keys.transpose(1, 2)
values = values.transpose(1, 2)
sm_scale = 1 / math.sqrt(head_dim)
scores = torch.matmul(xq, keys.transpose(2, 3)) * sm_scale
scores = F.softmax(scores.float() + mask, dim=-1).to(dtype=torch.float16)
output = torch.matmul(scores, values).transpose(1, 2).contiguous().reshape(-1, num_head, head_dim)
return output

View File

@@ -0,0 +1,54 @@
import math
import pytest
import torch
from packaging import version
from torch import nn
from torch.nn import functional as F
try:
import triton
import triton.language as tl
from colossalai.kernel.triton import bloom_context_attn_fwd
from tests.test_infer_ops.triton.kernel_utils import torch_context_attention
HAS_TRITON = True
except ImportError:
HAS_TRITON = False
print("please install triton from https://github.com/openai/triton")
TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4')
@pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON,
reason="triton requires cuda version to be higher than 11.4")
def test_bloom_context_attention():
bs = 4
head_num = 8
seq_len = 1024
head_dim = 64
query = torch.randn((bs * seq_len, head_num, head_dim), dtype=torch.float16, device="cuda")
k = torch.randn((bs * seq_len, head_num, head_dim), dtype=torch.float16, device="cuda")
v = torch.randn((bs * seq_len, head_num, head_dim), dtype=torch.float16, device="cuda")
max_input_len = seq_len
b_start = torch.zeros((bs,), device="cuda", dtype=torch.int32)
b_len = torch.zeros((bs,), device="cuda", dtype=torch.int32)
for i in range(bs):
b_start[i] = i * seq_len
b_len[i] = seq_len
o = torch.randn((bs * seq_len, head_num, head_dim), dtype=torch.float16, device="cuda")
alibi = torch.zeros((head_num,), dtype=torch.float32, device="cuda")
bloom_context_attn_fwd(query.clone(), k.clone(), v.clone(), o, b_start, b_len, max_input_len, alibi)
torch_out = torch_context_attention(query.clone(), k.clone(), v.clone(), bs, seq_len, head_num, head_dim)
assert torch.allclose(torch_out.cpu(), o.cpu(), rtol=1e-3,
atol=1e-2), "outputs from triton and torch are not matched"
if __name__ == "__main__":
test_bloom_context_attention()

View File

@@ -0,0 +1,39 @@
import pytest
import torch
from packaging import version
from torch import nn
try:
import triton
import triton.language as tl
from colossalai.kernel.triton.copy_kv_cache_dest import copy_kv_cache_to_dest
HAS_TRITON = True
except ImportError:
HAS_TRITON = False
print("please install triton from https://github.com/openai/triton")
TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4')
@pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON,
reason="triton requires cuda version to be higher than 11.4")
def test_kv_cache_copy_op():
B_NTX = 32 * 2048
head_num = 8
head_dim = 64
cache = torch.randn((B_NTX, head_num, head_dim), device="cuda", dtype=torch.float16)
dest_index = torch.arange(0, B_NTX, device="cuda", dtype=torch.int32)
dest_data = torch.ones((B_NTX, head_num, head_dim), device="cuda", dtype=torch.float16)
copy_kv_cache_to_dest(cache, dest_index, dest_data)
assert torch.allclose(cache.cpu(), dest_data.cpu(), rtol=1e-3,
atol=1e-3), "copy_kv_cache_to_dest outputs from triton and torch are not matched"
if __name__ == "__main__":
test_kv_cache_copy_op()

View File

@@ -0,0 +1,44 @@
import pytest
import torch
from packaging import version
from colossalai.kernel.triton import layer_norm
from colossalai.testing.utils import parameterize
try:
import triton
import triton.language as tl
from colossalai.kernel.triton.fused_layernorm import _layer_norm_fwd_fused
HAS_TRITON = True
except ImportError:
HAS_TRITON = False
print("please install triton from https://github.com/openai/triton")
TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4')
@pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON,
reason="triton requires cuda version to be higher than 11.4")
@parameterize('M', [2, 4, 8, 16])
@parameterize('N', [64, 128])
def test_layer_norm(M, N):
dtype = torch.float16
eps = 1e-5
x_shape = (M, N)
w_shape = (x_shape[-1],)
weight = torch.rand(w_shape, dtype=dtype, device='cuda')
bias = torch.rand(w_shape, dtype=dtype, device='cuda')
x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device='cuda')
y_triton = layer_norm(x, weight, bias, eps)
y_torch = torch.nn.functional.layer_norm(x, w_shape, weight, bias, eps).to(dtype)
assert y_triton.shape == y_torch.shape
assert y_triton.dtype == y_torch.dtype
print("max delta: ", torch.max(torch.abs(y_triton - y_torch)))
assert torch.allclose(y_triton, y_torch, atol=1e-2, rtol=0)
if __name__ == "__main__":
test_layer_norm()

View File

@@ -0,0 +1,53 @@
import math
import pytest
import torch
from packaging import version
from torch import nn
from torch.nn import functional as F
try:
import triton
import triton.language as tl
from colossalai.kernel.triton import llama_context_attn_fwd
from tests.test_infer_ops.triton.kernel_utils import torch_context_attention
HAS_TRITON = True
except ImportError:
HAS_TRITON = False
print("please install triton from https://github.com/openai/triton")
TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4')
@pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON,
reason="triton requires cuda version to be higher than 11.4")
def test_llama_context_attention():
bs = 4
head_num = 8
seq_len = 1024
head_dim = 64
query = torch.randn((bs * seq_len, head_num, head_dim), dtype=torch.float16, device="cuda")
k = torch.randn((bs * seq_len, head_num, head_dim), dtype=torch.float16, device="cuda")
v = torch.randn((bs * seq_len, head_num, head_dim), dtype=torch.float16, device="cuda")
max_input_len = seq_len
b_start = torch.zeros((bs,), device="cuda", dtype=torch.int32)
b_len = torch.zeros((bs,), device="cuda", dtype=torch.int32)
for i in range(bs):
b_start[i] = i * seq_len
b_len[i] = seq_len
o = torch.randn((bs * seq_len, head_num, head_dim), dtype=torch.float16, device="cuda")
llama_context_attn_fwd(query.clone(), k.clone(), v.clone(), o, b_start, b_len, max_input_len)
torch_out = torch_context_attention(query.clone(), k.clone(), v.clone(), bs, seq_len, head_num, head_dim)
assert torch.allclose(torch_out.cpu(), o.cpu(), rtol=1e-3,
atol=1e-3), "outputs from triton and torch are not matched"
if __name__ == "__main__":
test_llama_context_attention()

View File

@@ -0,0 +1,56 @@
# Adapted from ModelTC https://github.com/ModelTC/lightllm
import time
import pytest
import torch
from packaging import version
try:
import triton
import triton.language as tl
from colossalai.kernel.triton.rotary_embedding_kernel import rotary_embedding_fwd
HAS_TRITON = True
except ImportError:
HAS_TRITON = False
print("please install triton from https://github.com/openai/triton")
TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4')
def torch_rotary_emb(x, cos, sin):
seq_len, h, dim = x.shape
x0 = x[:, :, 0:dim // 2]
x1 = x[:, :, dim // 2:dim]
cos = cos.view((seq_len, 1, dim // 2))
sin = sin.view((seq_len, 1, dim // 2))
o0 = x0 * cos - x1 * sin
o1 = x0 * sin + x1 * cos
return torch.cat((o0, o1), dim=-1)
@pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON,
reason="triton requires cuda version to be higher than 11.4")
def test_rotary_emb():
SEQ_LEN = 1
HEAD_NUM = 32
HEAD_DIM = 128
dtype = torch.half
# create data
x_shape = (SEQ_LEN, HEAD_NUM, HEAD_DIM)
x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device='cuda')
cos_shape = (SEQ_LEN, HEAD_DIM // 2)
cos = -1.2 + 0.5 * torch.randn(cos_shape, dtype=dtype, device='cuda')
sin = -2.0 + 0.5 * torch.randn(cos_shape, dtype=dtype, device='cuda')
# forward pass
y_torch = torch_rotary_emb(x, cos, sin)
rotary_embedding_fwd(x, cos, sin)
y_triton = x
# compare
assert torch.allclose(y_torch, y_triton, atol=1e-2, rtol=0)
if __name__ == "__main__":
test_rotary_emb()

View File

@@ -4,12 +4,11 @@ import torch
from torch import nn
import torch.nn.functional as F
from colossalai.kernel.triton.ops import self_attention_compute_using_triton
from colossalai.kernel.triton.qkv_matmul_kernel import qkv_gemm_4d_kernel
try:
import triton
import triton.language as tl
from colossalai.kernel.triton.self_attention_nofusion import self_attention_compute_using_triton
from colossalai.kernel.triton.qkv_matmul_kernel import qkv_gemm_4d_kernel
HAS_TRITON = True
except ImportError:
HAS_TRITON = False
@@ -17,7 +16,7 @@ except ImportError:
TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4')
@pytest.mark.skipif(not TRITON_CUDA_SUPPORT, reason="triton requires cuda version to be higher than 11.4")
@pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4")
def test_qkv_matmul():
qkv = torch.randn((4, 24, 64*3), device="cuda", dtype=torch.float16)
scale = 1.2
@@ -106,7 +105,7 @@ def self_attention_compute_using_torch(qkv,
return res.view(batches, -1, d_model), score_output, softmax_output
@pytest.mark.skipif(not TRITON_CUDA_SUPPORT, reason="triton requires cuda version to be higher than 11.4")
@pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4")
def test_self_atttention_test():
qkv = torch.randn((4, 24, 64*3), device="cuda", dtype=torch.float16)

View File

@@ -3,11 +3,19 @@ from packaging import version
import torch
from torch import nn
from colossalai.kernel.triton.ops import softmax
try:
import triton
import triton.language as tl
from colossalai.kernel.triton.softmax import softmax
HAS_TRITON = True
except ImportError:
HAS_TRITON = False
print("please install triton from https://github.com/openai/triton")
TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4')
@pytest.mark.skipif(not TRITON_CUDA_SUPPORT, reason="triton requires cuda version to be higher than 11.4")
@pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4")
def test_softmax_op():
data_samples = [
torch.randn((3, 4, 5, 32), device = "cuda", dtype = torch.float32),

View File

@@ -0,0 +1,72 @@
import math
import pytest
import torch
from packaging import version
try:
import triton
import triton.language as tl
from colossalai.kernel.triton.token_attention_kernel import token_attn_fwd_1
HAS_TRITON = True
except ImportError:
HAS_TRITON = False
print("please install triton from https://github.com/openai/triton")
TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4')
def torch_attn(xq, xk, bs, seqlen, num_head, head_dim):
xq = xq.view(bs, 1, num_head, head_dim)
xk = xk.view(bs, seqlen, num_head, head_dim)
keys = xk
xq = xq.transpose(1, 2)
keys = keys.transpose(1, 2)
scores = (torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(head_dim)).squeeze().transpose(0, 1).reshape(
num_head, -1)
return scores
def torch_attn_1(xq, xk, seqlen, num_head, head_dim):
xq = xq.view(1, num_head, head_dim)
xk = xk.view(seqlen, num_head, head_dim)
logics = torch.sum(xq * xk, dim=-1, keepdim=False)
logics = logics.transpose(0, 1) / math.sqrt(head_dim)
return logics
@pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON,
reason="triton requires cuda version to be higher than 11.4")
def test_attn_1():
import time
batch_size, seq_len, head_num, head_dim = 17, 1025, 12, 128
dtype = torch.float16
q = torch.empty((batch_size, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2)
k = torch.empty((batch_size * seq_len, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2)
attn_out = torch.empty((head_num, batch_size * seq_len), dtype=dtype, device="cuda")
b_loc = torch.zeros((batch_size, seq_len), dtype=torch.int32, device="cuda")
kv_cache_start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
kv_cache_seq_len = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
for i in range(batch_size):
kv_cache_start_loc[i] = i * seq_len
kv_cache_seq_len[i] = seq_len
b_loc[i] = i * seq_len + torch.arange(0, seq_len, dtype=torch.int32, device="cuda")
token_attn_fwd_1(q, k, attn_out, b_loc, kv_cache_start_loc, kv_cache_seq_len, seq_len)
torch_out = torch_attn(q, k, batch_size, seq_len, head_num, head_dim).squeeze()
o = attn_out.squeeze()
print("max ", torch.max(torch.abs(torch_out - o)))
print("mean ", torch.mean(torch.abs(torch_out - o)))
assert torch.allclose(torch_out, o, atol=1e-2, rtol=0)
if __name__ == "__main__":
test_attn_1()

View File

@@ -0,0 +1,61 @@
import math
import pytest
import torch
from packaging import version
try:
import triton
import triton.language as tl
from colossalai.kernel.triton.token_attention_kernel import token_attn_fwd_2
HAS_TRITON = True
except ImportError:
HAS_TRITON = False
print("please install triton from https://github.com/openai/triton")
TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4')
def torch_attn(V, P, bs, seqlen, num_head, head_dim):
V = V.view(bs, seqlen, num_head, head_dim).transpose(1, 2)
P = P.reshape(num_head, bs, 1, seqlen).transpose(0, 1)
attn_out = torch.matmul(P, V)
return attn_out
@pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON,
reason="triton requires cuda version to be higher than 11.4")
def test_token_attn_2():
import time
batch_size, seq_len, head_num, head_dim = 17, 1025, 12, 128
dtype = torch.float16
V = torch.empty((batch_size * seq_len, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.1, std=10)
Prob = torch.empty(
(head_num, batch_size * seq_len), dtype=dtype,
device="cuda").normal_(mean=0.4, std=0.2).reshape(head_num, batch_size,
seq_len).softmax(-1).reshape(head_num, batch_size * seq_len)
attn_out = torch.empty((batch_size, head_num, head_dim), dtype=dtype, device="cuda")
kv_cache_start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
kv_cache_seq_len = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
kv_cache_loc = torch.zeros((batch_size, seq_len), dtype=torch.int32, device="cuda")
for i in range(batch_size):
kv_cache_start_loc[i] = i * seq_len
kv_cache_seq_len[i] = seq_len
kv_cache_loc[i] = i * seq_len + torch.arange(0, seq_len, dtype=torch.int32, device="cuda")
token_attn_fwd_2(Prob, V, attn_out, kv_cache_loc, kv_cache_start_loc, kv_cache_seq_len, seq_len)
torch_out = torch_attn(V, Prob, batch_size, seq_len, head_num, head_dim).squeeze()
o = attn_out
print("max ", torch.max(torch.abs(torch_out - o)))
print("mean ", torch.mean(torch.abs(torch_out - o)))
assert torch.allclose(torch_out, o, atol=1e-2, rtol=0)
if __name__ == "__main__":
test_token_attn_2()

View File

@@ -0,0 +1,67 @@
import time
import pytest
import torch
from packaging import version
try:
import triton
import triton.language as tl
from colossalai.kernel.triton.token_attention_kernel import token_attention_fwd
HAS_TRITON = True
except ImportError:
HAS_TRITON = False
print("please install triton from https://github.com/openai/triton")
TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4')
def torch_att(xq, xk, xv, bs, seqlen, num_head, head_dim):
xq = xq.view(bs, 1, num_head, head_dim)
xk = xk.view(bs, seqlen, num_head, head_dim)
xv = xv.view(bs, seqlen, num_head, head_dim)
logics = torch.sum(xq * xk, dim=3, keepdim=False) * 1 / (head_dim**0.5)
prob = torch.softmax(logics, dim=1)
prob = prob.view(bs, seqlen, num_head, 1)
return torch.sum(prob * xv, dim=1, keepdim=False)
@pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON,
reason="triton requires cuda version to be higher than 11.4")
def test():
Z, head_num, seq_len, head_dim = 22, 112 // 8, 2048, 128
dtype = torch.float16
q = torch.empty((Z, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2)
k = torch.empty((Z * seq_len, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.4, std=0.2)
v = torch.empty((Z * seq_len, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.3, std=0.2)
o = torch.empty((Z, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.3, std=0.2)
alibi = torch.zeros((head_num,), dtype=torch.float32, device="cuda")
max_kv_cache_len = seq_len
kv_cache_start_loc = torch.zeros((Z,), dtype=torch.int32, device="cuda")
kv_cache_loc = torch.zeros((Z, seq_len), dtype=torch.int32, device="cuda")
kv_cache_seq_len = torch.ones((Z,), dtype=torch.int32, device="cuda")
kv_cache_seq_len[:] = seq_len
kv_cache_start_loc[0] = 0
kv_cache_start_loc[1] = seq_len
kv_cache_start_loc[2] = 2 * seq_len
kv_cache_start_loc[3] = 3 * seq_len
for i in range(Z):
kv_cache_loc[i, :] = torch.arange(i * seq_len, (i + 1) * seq_len, dtype=torch.int32, device="cuda")
token_attention_fwd(q, k, v, o, kv_cache_loc, kv_cache_start_loc, kv_cache_seq_len, max_kv_cache_len, alibi=alibi)
torch_out = torch_att(q, k, v, Z, seq_len, head_num, head_dim)
print("max ", torch.max(torch.abs(torch_out - o)))
print("mean ", torch.mean(torch.abs(torch_out - o)))
assert torch.allclose(torch_out, o, atol=1e-2, rtol=0)
if __name__ == "__main__":
test()

View File

@@ -0,0 +1,48 @@
import pytest
import torch
from packaging import version
try:
import triton
import triton.language as tl
from colossalai.kernel.triton.token_attention_kernel import token_attn_softmax_fwd
HAS_TRITON = True
except ImportError:
HAS_TRITON = False
print("please install triton from https://github.com/openai/triton")
TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4')
@pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON,
reason="triton requires cuda version to be higher than 11.4")
def test_softmax():
import torch
batch_size, seq_len, head_num, head_dim = 4, 1025, 12, 128
dtype = torch.float16
Logics = torch.empty((head_num, batch_size * seq_len), dtype=dtype, device="cuda").normal_(mean=0.1, std=10)
ProbOut = torch.empty((head_num, batch_size * seq_len), dtype=dtype, device="cuda").normal_(mean=0.4, std=0.2)
kv_cache_start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
kv_cache_seq_len = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
for i in range(batch_size):
kv_cache_start_loc[i] = i * seq_len
kv_cache_seq_len[i] = seq_len
token_attn_softmax_fwd(Logics, kv_cache_start_loc, kv_cache_seq_len, ProbOut, seq_len)
torch_out = Logics.reshape(head_num * batch_size, -1).softmax(-1).reshape(head_num, batch_size * seq_len)
o = ProbOut
print("max ", torch.max(torch.abs(torch_out - o)))
print("mean ", torch.mean(torch.abs(torch_out - o)))
assert torch.allclose(torch_out, o, atol=1e-2, rtol=0)
if __name__ == "__main__":
test_softmax()