Merge remote-tracking branch 'origin/main' into prefetch

Author: hxwang
Date: 2024-05-23 15:49:33 +00:00
238 changed files with 20778 additions and 9787 deletions

@@ -1,144 +0,0 @@
import pytest
import torch
from packaging import version
try:
import triton  # noqa: F401
HAS_TRITON = True
except ImportError:
HAS_TRITON = False
print("please install triton from https://github.com/openai/triton")
try:
from auto_gptq.modeling._utils import autogptq_post_init
from auto_gptq.utils.import_utils import dynamically_import_QuantLinear
from exllama_kernels import prepare_buffers, set_tuning_params
from colossalai.inference.quant.gptq import CaiQuantLinear
HAS_AUTO_GPTQ = True
except:
HAS_AUTO_GPTQ = False
print("please install AutoGPTQ from https://github.com/PanQiWei/AutoGPTQ")
import warnings
HAS_GPTQ_CUDA = False
try:
from colossalai.kernel.op_builder.gptq import GPTQBuilder
gptq_cuda = GPTQBuilder().load()
HAS_GPTQ_CUDA = True
except ImportError:
warnings.warn("CUDA gptq is not installed")
HAS_GPTQ_CUDA = False
TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")
max_inner_outer_dim = 1
max_input_len = 1
max_dq_buffer_size = 1
gptq_temp_dq_buffer = None
gptq_temp_state_buffer = None
def init_buffer(cai_linear, use_act_order=False):
global max_dq_buffer_size
global max_input_len
global max_inner_outer_dim
global gptq_temp_dq_buffer
global gptq_temp_state_buffer
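# qweight packs eight 4-bit values per int32 (assuming the 4-bit path used below), hence the 8x factor when sizing the dequantization buffer.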
max_dq_buffer_size = max(max_dq_buffer_size, cai_linear.qweight.numel() * 8)
if use_act_order:
max_inner_outer_dim = max(max_inner_outer_dim, cai_linear.infeatures, cai_linear.outfeatures)
if use_act_order:
max_input_len = 4096
# The temp_state buffer is required to reorder X in the act-order case.
# The temp_dq buffer is required to dequantize weights when using cuBLAS, typically for the prefill.
gptq_temp_state_buffer = torch.zeros(
(max_input_len, max_inner_outer_dim), dtype=torch.float16, device=torch.cuda.current_device()
)
gptq_temp_dq_buffer = torch.zeros((1, max_dq_buffer_size), dtype=torch.float16, device=torch.cuda.current_device())
gptq_cuda.prepare_buffers(torch.device(torch.cuda.current_device()), gptq_temp_state_buffer, gptq_temp_dq_buffer)
# Using the default from exllama repo here.
matmul_recons_thd = 8
matmul_fused_remap = False
matmul_no_half2 = False
gptq_cuda.set_tuning_params(matmul_recons_thd, matmul_fused_remap, matmul_no_half2)
@pytest.mark.skipif(
not TRITON_CUDA_SUPPORT or not HAS_TRITON or not HAS_AUTO_GPTQ,
reason="requires triton, CUDA version higher than 11.4, and auto-gptq to be installed",
)
def test_gptq_linear():
infeature = 1024
outfeature = 1024
group_size = 128
wbits = 4
inps = torch.ones(1, 1, infeature).to(torch.float16).to(torch.cuda.current_device())
batch_inps = torch.randn(1, 16, infeature).to(torch.float16).to(torch.cuda.current_device())
device = torch.device("cuda:0")
linear_class = dynamically_import_QuantLinear(use_triton=False, desc_act=False, group_size=group_size, bits=wbits)
linear = linear_class(
bits=4,
group_size=group_size,
infeatures=infeature,
outfeatures=outfeature,
bias=False,
)
torch.manual_seed(42)
linear.qweight = torch.randint(-100, 100, size=linear.qweight.shape, dtype=torch.int32)
linear.scales = linear.scales + 0.002
linear = linear.to(device)
cai_linear = CaiQuantLinear(wbits, group_size, infeature, outfeature, True)
cai_linear.qweight.data.copy_(linear.qweight)
cai_linear.scales = cai_linear.scales + 0.002
cai_linear = cai_linear.to(device)
linear = autogptq_post_init(linear, use_act_order=False)
max_inner_outer_dim = max(infeature, outfeature)
max_dq_buffer_size = linear.infeatures * linear.outfeatures
max_input_len = 2048
buffers = {
"temp_state": torch.zeros((max_input_len, max_inner_outer_dim), dtype=torch.float16, device=device),
"temp_dq": torch.zeros((1, max_dq_buffer_size), dtype=torch.float16, device=device),
}
prepare_buffers(device, buffers["temp_state"], buffers["temp_dq"])
# Using the default from exllama repo here.
matmul_recons_thd = 8
matmul_fused_remap = False
matmul_no_half2 = False
set_tuning_params(matmul_recons_thd, matmul_fused_remap, matmul_no_half2)
with torch.no_grad():
gptq_out = linear(inps)
batch_gptq_out = linear(batch_inps)
torch.cuda.synchronize()
cai_out = cai_linear(inps)
torch.cuda.synchronize()
batch_cai_out = cai_linear(batch_inps)
torch.cuda.synchronize()
assert torch.allclose(cai_out, gptq_out, rtol=1e-01, atol=1e-01)
assert torch.allclose(batch_cai_out, batch_gptq_out, rtol=1e-01, atol=1e-01)
if __name__ == "__main__":
test_gptq_linear()

tests/test_infer/_utils.py Normal file → Executable file

@@ -0,0 +1,80 @@
import asyncio
from dataclasses import dataclass
import pytest
from colossalai.inference.core.async_engine import AsyncInferenceEngine
@dataclass
class MockSequence:
request_id: int
class MockEngine:
def __init__(self):
self.step_calls = 0
self.add_request_calls = 0
self.abort_request_calls = 0
self.request_id = None
async def async_step(self):
self.step_calls += 1
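# Mimic the real engine: report one finished sequence while a request id is active, otherwise an idle step.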
return ([MockSequence(request_id=self.request_id)], True) if self.request_id else ([], False)
def add_single_request(self, **kwargs):
del kwargs
self.add_request_calls += 1
def generate(self, request_id):
self.request_id = request_id
def stop_generating(self):
self.request_id = None
def add_request(self, **kwargs):
del kwargs # Unused
self.add_request_calls += 1
def abort_request(self, request_id):
del request_id # Unused
self.abort_request_calls += 1
class MockAsyncInferenceEngine(AsyncInferenceEngine):
def _init_engine(self, *args, **kwargs):
return MockEngine()
@pytest.mark.asyncio
async def test_new_requests_event():
engine = MockAsyncInferenceEngine()
engine.start_background_loop()
await asyncio.sleep(0.01)
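# No requests have been added yet, so the background loop should not have stepped the engine.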
assert engine.engine.step_calls == 0
await engine.add_request(1, "", None)
await asyncio.sleep(0.01)
assert engine.engine.add_request_calls == 1
assert engine.engine.step_calls == 1
await engine.add_request(2, "", None)
engine.engine.generate(2)
await asyncio.sleep(0)
assert engine.engine.add_request_calls == 2
assert engine.engine.step_calls == 2
await asyncio.sleep(0)
assert engine.engine.step_calls == 3
engine.engine.stop_generating()
await asyncio.sleep(0)
assert engine.engine.step_calls == 4
await asyncio.sleep(0)
assert engine.engine.step_calls == 4
await engine.add_request(3, "", None)
await asyncio.sleep(0.01)
assert engine.engine.add_request_calls == 3
assert engine.engine.step_calls == 5
await asyncio.sleep(0.01)
assert engine.engine.add_request_calls == 3
assert engine.engine.step_calls == 5

@@ -0,0 +1,68 @@
import pytest
from colossalai.inference.core.async_engine import Tracer
from colossalai.inference.struct import Sequence
class SampleEvent:
def __init__(self):
self.flag = False
def set(self):
self.flag = True
def clear(self):
self.flag = False
def test_request_tracer():
tracker = Tracer()
tracker.new_requests_event = SampleEvent()
stream_1 = tracker.add_request(1)
assert tracker.new_requests_event.flag
new = tracker.get_new_requests()
assert not tracker.new_requests_event.flag
assert len(new) == 1
assert new[0]["request_id"] == 1
assert not stream_1.finished
stream_2 = tracker.add_request(2)
stream_3 = tracker.add_request(3)
assert tracker.new_requests_event.flag
new = tracker.get_new_requests()
assert not tracker.new_requests_event.flag
assert len(new) == 2
assert new[0]["request_id"] == 2
assert new[1]["request_id"] == 3
assert not stream_2.finished
assert not stream_3.finished
# request_ids must be unique
with pytest.raises(KeyError):
tracker.add_request(1)
assert not tracker.new_requests_event.flag
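# Aborted requests, whether already fetched (id 1) or still pending (id 4), must not show up again in get_new_requests.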
tracker.abort_request(1)
new = tracker.get_new_requests()
assert not new
stream_4 = tracker.add_request(4)
tracker.abort_request(4)
assert tracker.new_requests_event.flag
new = tracker.get_new_requests()
assert not new
assert stream_4.finished
stream_5 = tracker.add_request(5)
assert tracker.new_requests_event.flag
tracker.process_finished_request(Sequence(2, "output", [], 4, [], 0, 0))
new = tracker.get_new_requests()
assert not tracker.new_requests_event.flag
assert len(new) == 1
assert new[0]["request_id"] == 5
assert stream_2.finished
assert not stream_5.finished
if __name__ == "__main__":
test_request_tracer()

@@ -0,0 +1,144 @@
import torch
from transformers.models.llama import LlamaConfig
from colossalai.inference.batch_bucket import BatchBucket
from colossalai.inference.config import InferenceConfig
from colossalai.inference.kv_cache import KVCacheManager
from colossalai.inference.struct import Sequence
from colossalai.logging import get_dist_logger
from colossalai.testing import parameterize
logger = get_dist_logger(__name__)
@parameterize(
"test_config",
[
{
"hidden_size": 128,
"num_attention_heads": 4,
"num_layers": 2,
"block_size": 4,
"max_batch_size": 4,
"max_input_len": 32,
"max_output_len": 8,
"dtype": torch.float16,
"tp_size": 1,
}
],
)
def test_bucket(test_config):
hidden_size = test_config.pop("hidden_size")
num_heads = test_config.pop("num_attention_heads")
num_layers = test_config.pop("num_layers")
model_config = LlamaConfig(
hidden_size=hidden_size,
num_hidden_layers=num_layers,
num_attention_heads=num_heads,
)
inference_config = InferenceConfig(**test_config)
# For testing only: avoid creating multiple cache managers on the same device in real usage.
cache_manager = KVCacheManager(inference_config, model_config)
cache_manager_copy = KVCacheManager(inference_config, model_config)
seq_lens = [19, 20, 27]
seq1 = Sequence(
request_id=0,
prompt="", # Dummy for testing usage
input_token_id=list(range(seq_lens[0])),
block_size=4,
sample_params=None,
eos_token_id=2,
pad_token_id=2,
max_output_len=10,
)
seq2 = Sequence(
request_id=1,
prompt="", # Dummy for testing usage
input_token_id=list(range(seq_lens[1])),
block_size=4,
sample_params=None,
eos_token_id=2,
pad_token_id=2,
max_output_len=10,
)
seq3 = Sequence(
request_id=2,
prompt="", # Dummy for testing usage
input_token_id=list(range(seq_lens[2])),
block_size=4,
sample_params=None,
eos_token_id=2,
pad_token_id=2,
max_output_len=10,
)
block_size = test_config["block_size"]
max_batch_size = test_config["max_batch_size"]
max_length = test_config["max_input_len"] + test_config["max_output_len"]
assert max_batch_size >= 2, "max_batch_size should be greater than 1"
bb = BatchBucket(
num_heads, cache_manager.get_head_size(), max_batch_size, max_length, block_size, kv_max_split_num=2
)
bb_copy = BatchBucket(
num_heads, cache_manager.get_head_size(), max_batch_size, max_length, block_size, kv_max_split_num=2
)
block_tables = bb.add_seqs([seq1, seq2])
logger.debug(f"bb information: {bb}")
assert block_tables.shape == (2, cache_manager.max_blocks_per_sequence)
assert torch.all(block_tables < 0), "Initialized block_tables should be negative values"
cache_manager.allocate_context_from_block_tables(block_tables, bb.seq_lengths[: bb.current_batch_size])
bb_copy.add_seqs(
[seq1, seq2], alloc_block_tables_fn=cache_manager_copy.allocate_context_from_block_tables
) # This is just for testing usage. Don't add the same sequence to different buckets.
assert bb.seq_lengths.tolist() == [seq1.sentence_len, seq2.sentence_len] + [0] * (
max_batch_size - bb.current_batch_size
)
assert torch.equal(bb.block_tables, bb_copy.block_tables)
bb.append_batch_tokens(torch.tensor([99, 99]))
assert bb.seq_lengths.tolist() == [seq1.sentence_len, seq2.sentence_len] + [0] * (
max_batch_size - bb.current_batch_size
)
cache_manager.allocate_tokens_from_block_tables(bb.block_tables, bb.seq_lengths, bsz=bb.current_batch_size)
assert bb.seq_lengths.tolist() == [seq1.sentence_len, seq2.sentence_len] + [0] * (
max_batch_size - bb.current_batch_size
)
bb.append_batch_tokens(torch.tensor([99, 99]))
cache_manager.allocate_tokens_from_block_tables(bb.block_tables, bb.seq_lengths, bsz=bb.current_batch_size)
assert bb.seq_lengths.tolist() == [seq1.sentence_len, seq2.sentence_len] + [0] * (
max_batch_size - bb.current_batch_size
)
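# Popping a sequence should release its blocks and leave the bucket compact, with the remaining sequences packed at the front.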
bb.pop_seq_update_batch(0, free_block_table_fn=cache_manager.free_block_table)
assert bb.seq_lengths.tolist() == [bb.seqs_li[0].sentence_len] + [0] * (max_batch_size - bb.current_batch_size)
assert bb.is_compact
bb2 = BatchBucket(
num_heads, cache_manager.get_head_size(), max_batch_size, max_length, block_size, kv_max_split_num=2
)
block_tables = bb2.add_seqs([seq3])
cache_manager.allocate_context_from_block_tables(block_tables, bb2.seq_lengths[: bb2.current_batch_size])
unmerged_ids = bb.merge(bb2)
assert not unmerged_ids
assert bb.is_compact
assert bb2.is_compact
assert bb.current_batch_size == 2
assert bb2.current_batch_size == 0
bb.clear(cache_manager.free_block_tables)
assert bb.current_batch_size == 0
assert bb.is_compact
assert bb.seq_lengths.tolist() == [0] * max_batch_size
assert torch.all(bb.block_tables < 0)
if __name__ == "__main__":
test_bucket()

@@ -0,0 +1,46 @@
import pytest
import colossalai
from colossalai.inference.config import InferenceConfig
from colossalai.inference.struct import RequestStatus, Sequence
from colossalai.testing import rerun_if_address_is_in_use, spawn
def check_config_and_inference():
config = InferenceConfig()
assert config.max_batch_size == 8
sequence = Sequence(
request_id=1,
prompt="abc",
input_token_id=[1, 2, 3],
block_size=16,
sample_params=None,
eos_token_id=2,
pad_token_id=2,
max_output_len=256,
)
sequence.mark_running()
assert sequence.status == RequestStatus.RUNNING
sequence.recycle()
assert sequence.status == RequestStatus.RECYCLED
assert sequence.sentence_len == 3
assert sequence.input_len == 3
assert sequence.output_len == 0
assert not sequence.check_finish()
def run_dist(rank, world_size, port):
colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost")
check_config_and_inference()
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_config_and_inference():
spawn(run_dist, 1)
if __name__ == "__main__":
test_config_and_inference()

@@ -0,0 +1,103 @@
import random
import numpy as np
import pytest
import torch
from transformers import AutoTokenizer, GenerationConfig, LlamaForCausalLM
import colossalai
from colossalai.inference.config import _DEFAULT_PROMPT_TEMPLATES, InferenceConfig
from colossalai.inference.core.engine import InferenceEngine
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
def setup_seed(seed):
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
random.seed(seed)
def generate_inputs(num_sequences, min_length, max_length):
sequences = []
for _ in range(num_sequences):
length = torch.randint(low=min_length, high=max_length + 1, size=(1,)).item()
# generate sequences of random length
sequence = torch.randint(10, 30000, size=(length,))
sequences.append(sequence)
return sequences
@parameterize(
"test_config",
[
{
"max_batch_size": 8,
"max_output_len": 512,
"max_input_len": 64,
"do_sample": False,
}
],
)
def check_inference_engine(test_config, use_engine=False, prompt_template=None):
setup_seed(20)
max_batch_size = test_config["max_batch_size"]
max_input_len = test_config["max_input_len"]
max_output_len = test_config["max_output_len"]
do_sample = test_config["do_sample"]
top_p = 0.5
top_k = 50
tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
model = LlamaForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0").cuda().half()
model = model.eval()
inputs_token_ids = generate_inputs(10 * max_batch_size, min_length=10, max_length=max_input_len)
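# Ten times the batch size ensures requests queue up and are scheduled across multiple batches, exercising continuous batching.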
if use_engine:
inference_config = InferenceConfig(
max_batch_size=max_batch_size, max_output_len=max_output_len, prompt_template=prompt_template
)
inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)
assert inference_engine.generation_config.max_new_tokens == max_output_len
inference_engine.add_request(prompts_token_ids=inputs_token_ids)
assert inference_engine.request_handler._has_waiting()
generation_config = GenerationConfig(do_sample=do_sample, top_p=top_p, top_k=top_k)
outputs = inference_engine.generate(generation_config=generation_config)
else:
# The HF generate path below expects text prompts; decode the token-id prompts back to strings first.
inputs = tokenizer.batch_decode(inputs_token_ids, skip_special_tokens=True)
if prompt_template:
# apply prompt template
inputs = [_DEFAULT_PROMPT_TEMPLATES[prompt_template].format(input_text=input_text) for input_text in inputs]
tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id = tokenizer.eos_token_id
inputs = tokenizer.batch_encode_plus(inputs, padding=True, return_tensors="pt")["input_ids"]
inputs = inputs.cuda()
generation_config = GenerationConfig(
do_sample=do_sample,
top_p=top_p,
top_k=top_k,
pad_token_id=tokenizer.pad_token_id,
max_new_tokens=max_output_len,
)
outputs = model.generate(inputs, generation_config=generation_config)
outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
assert len(outputs) == 10 * max_batch_size
@parameterize("prompt_template", [None, "llama"])
def check_continuous_batching(prompt_template):
check_inference_engine(use_engine=True, prompt_template=prompt_template)
def run_dist(rank, world_size, port):
colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost")
check_continuous_batching()
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_continuous_batching():
spawn(run_dist, 1)
if __name__ == "__main__":
test_continuous_batching()

@@ -0,0 +1,96 @@
import random
import numpy as np
import pytest
import torch
from transformers import AutoTokenizer, GenerationConfig, LlamaConfig, LlamaForCausalLM
import colossalai
from colossalai.inference.config import InferenceConfig
from colossalai.inference.core.engine import InferenceEngine
from colossalai.testing import rerun_if_address_is_in_use, spawn
def setup_seed(seed):
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
random.seed(seed)
def check_inference_engine(use_cuda_graph=False, batch_size=32):
setup_seed(20)
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
model = (
LlamaForCausalLM(
LlamaConfig(
vocab_size=50000, hidden_size=512, intermediate_size=1536, num_attention_heads=4, num_hidden_layers=16
)
)
.cuda()
.half()
)
model = model.eval()
prompts_token_ids = []
for i in range(batch_size):
prompts_token_ids.append(
np.random.randint(low=0, high=100, size=random.randint(1, max(1024 // batch_size, 32))).tolist()
)
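# The prompts above get random lengths (1 to max(1024 // batch_size, 32) tokens), so the CUDA-graph path is exercised with non-uniform inputs.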
input_len = 1024
output_len = 128
do_sample = False
top_p = 0.5
top_k = 50
inference_config = InferenceConfig(
max_batch_size=batch_size,
max_input_len=input_len,
max_output_len=output_len,
use_cuda_kernel=False,
use_cuda_graph=use_cuda_graph,
block_size=16,
)
inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)
assert inference_engine.generation_config.max_new_tokens == output_len
generation_config = GenerationConfig(do_sample=do_sample, top_p=top_p, top_k=top_k)
outputs = inference_engine.generate(prompts_token_ids=prompts_token_ids, generation_config=generation_config)
return outputs
def check_output_consistency(batch_size):
cuda_graph_output = check_inference_engine(use_cuda_graph=True, batch_size=batch_size)
naive_model_output = check_inference_engine(use_cuda_graph=False, batch_size=batch_size)
for s1, s2 in zip(cuda_graph_output, naive_model_output):
assert s1 == s2, f"\nCUDA Graph Output: {s1}\nOrigin Output: {s2}"
def run_dist(rank, world_size, port):
colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost")
check_output_consistency(32)
check_output_consistency(64)
check_output_consistency(128)
@pytest.mark.largedist
@rerun_if_address_is_in_use()
def test_cuda_graph_infer():
spawn(run_dist, 1)
if __name__ == "__main__":
test_cuda_graph_infer()

@@ -0,0 +1,74 @@
import pytest
import torch
from transformers import AutoTokenizer, LlamaConfig, LlamaForCausalLM
from colossalai.inference.modeling.models.glide_llama import GlideLlamaConfig, GlideLlamaForCausalLM
from colossalai.inference.spec.drafter import Drafter
from colossalai.utils import get_current_device
NUM_LAYERS = 1
MAX_LEN = 100
SPEC_NUM = 5
@pytest.fixture(scope="module")
def tokenizer():
return AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
@pytest.mark.parametrize("spec_num", [SPEC_NUM])
def test_drafter(tokenizer, spec_num: int):
torch.manual_seed(123)
device = get_current_device()
toy_config = LlamaConfig(num_hidden_layers=NUM_LAYERS)
toy_config.pad_token_id = tokenizer.eos_token_id
drafter_model = LlamaForCausalLM(toy_config)
drafter_model = drafter_model.eval().cuda()
drafter = Drafter(drafter_model, tokenizer, device=device)
input_ids = torch.randint(low=5, high=1000, size=(1, 6)).to(device)
out = drafter.speculate(input_ids, spec_num)
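# The drafter's KV cache is expected to cover the prompt plus spec_num - 1 drafted tokens, presumably because the newest drafted token has not been fed back through the model yet.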
past_kv_length = input_ids.size(1) + spec_num - 1
assert out.speculated_length == spec_num
assert out.next_tokens.shape == (spec_num,)
assert out.logits.shape == (spec_num, len(tokenizer))
assert out.past_key_values[0][0].size(2) == past_kv_length
reject_num = max(0, spec_num - 1)
trimmed_past_key_values = drafter.trim_kv_cache(out.past_key_values, reject_num)
assert trimmed_past_key_values[0][0].size(2) == past_kv_length - reject_num
def test_spec_dec(tokenizer):
spec_num = SPEC_NUM
device = get_current_device()
tokenizer.pad_token = tokenizer.eos_token
# Dummy config for Glide Model
glide_config = GlideLlamaConfig(
intermediate_size=8192,
large_hidden_size=4096,
large_num_attention_heads=32,
num_hidden_layers=NUM_LAYERS,
)
drafter_model = GlideLlamaForCausalLM(glide_config)
assert hasattr(drafter_model, "model")
assert hasattr(drafter_model.model, "layers")
for _, layer in enumerate(drafter_model.model.layers):
assert hasattr(layer, "cross_attn")
# Init the Drafter by providing the sharded drafter model
drafter = Drafter(drafter_model, tokenizer, device=device, dtype=torch.float16)
input_ids = torch.randint(low=5, high=1000, size=(1, 6)).to(device)
out = drafter.speculate(input_ids, spec_num, past_key_values=None)
if __name__ == "__main__":
dummy_tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
test_drafter(dummy_tokenizer, spec_num=SPEC_NUM)
test_spec_dec(dummy_tokenizer)

@@ -1,121 +0,0 @@
import importlib.util
import pytest
import torch
import torch.distributed as dist
import transformers
from packaging import version
import colossalai
from colossalai.inference import InferenceEngine
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.5")
HAS_LIGHTLLM_KERNEL = True
if importlib.util.find_spec("lightllm") is None:
HAS_LIGHTLLM_KERNEL = False
def data_gen():
input_ids = torch.tensor([[15496, 11, 616, 3290, 318, 13779, 318, 13779]], dtype=torch.int64)
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64)
return dict(input_ids=input_ids, attention_mask=attention_mask)
inputs = data_gen()
for k, v in inputs.items():
if torch.is_tensor(v) or "Tensor" in v.__class__.__name__:
new_shape = [1] * v.dim()
new_shape[0] = 16
inputs[k] = v.to("cuda").repeat(*new_shape)
def pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size):
model = transformers.BloomForCausalLM(
transformers.BloomConfig(vocab_size=20000, hidden_size=512, n_head=4, n_layer=4)
)
engine = InferenceEngine(
tp_size=tp_size,
pp_size=pp_size,
model=model,
max_output_len=max_output_len,
micro_batch_size=micro_batch_size,
)
output = engine.generate(inputs)
if dist.get_rank() == 0:
assert len(output[0]) == max_output_len, f"{len(output)}, {max_output_len}"
@parameterize("tp_size", [1])
@parameterize("pp_size", [2])
@parameterize("max_output_len", [4])
@parameterize("micro_batch_size", [1])
@clear_cache_before_run()
def run_pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size):
pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size)
torch.cuda.empty_cache()
@parameterize("tp_size", [2])
@parameterize("pp_size", [2])
@parameterize("max_output_len", [4])
@parameterize("micro_batch_size", [1])
@clear_cache_before_run()
def run_tp_pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size):
pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size)
torch.cuda.empty_cache()
@parameterize("tp_size", [2])
@parameterize("pp_size", [1])
@parameterize("max_output_len", [2])
@parameterize("micro_batch_size", [1])
@clear_cache_before_run()
def run_tp_inference_test(tp_size, pp_size, max_output_len, micro_batch_size):
pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size)
torch.cuda.empty_cache()
@parameterize("tp_size", [1])
@parameterize("pp_size", [1])
@parameterize("max_output_len", [2])
@parameterize("micro_batch_size", [1])
@clear_cache_before_run()
def run_single_inference_test(tp_size, pp_size, max_output_len, micro_batch_size):
pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size)
torch.cuda.empty_cache()
def check_tp_pp_inference(rank, world_size, port):
colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
run_tp_pipeline_inference_test()
def check_tp_or_pp_inference(rank, world_size, port):
colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
run_tp_inference_test()
run_pipeline_inference_test()
def check_single_inference(rank, world_size, port):
colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
run_single_inference_test()
@pytest.mark.skipif(
not CUDA_SUPPORT or not HAS_LIGHTLLM_KERNEL,
reason="requires CUDA version higher than 11.5 and the lightllm kernel",
)
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_pipeline_inference():
spawn(check_tp_pp_inference, nprocs=4)
spawn(check_tp_or_pp_inference, nprocs=2)
spawn(check_single_inference, nprocs=1)
if __name__ == "__main__":
test_pipeline_inference()

@@ -1,129 +0,0 @@
import importlib.util
import pytest
import torch
import torch.distributed as dist
from packaging import version
import colossalai
from colossalai.inference import InferenceEngine
from colossalai.shardformer.modeling.chatglm2_6b.configuration_chatglm import ChatGLMConfig
from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.5")
HAS_LIGHTLLM_KERNEL = True
if importlib.util.find_spec("lightllm") is None:
HAS_LIGHTLLM_KERNEL = False
def data_gen():
input_ids = torch.tensor([[15496, 11, 616, 3290, 318, 13779, 318, 13779]], dtype=torch.int64)
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64)
return dict(input_ids=input_ids, attention_mask=attention_mask)
inputs = data_gen()
for k, v in inputs.items():
if torch.is_tensor(v) or "Tensor" in v.__class__.__name__:
new_shape = [1] * v.dim()
new_shape[0] = 16
inputs[k] = v.to("cuda").repeat(*new_shape)
def pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size):
chatglm_config = ChatGLMConfig(
num_layers=2,
vocab_size=20000,
use_cache=True,
multi_query_attention=True,
multi_query_group_num=2,
num_attention_heads=8,
hidden_size=1024,
)
model = ChatGLMForConditionalGeneration(chatglm_config)
engine = InferenceEngine(
tp_size=tp_size,
pp_size=pp_size,
model=model,
max_output_len=max_output_len,
micro_batch_size=micro_batch_size,
)
output = engine.generate(inputs)
if dist.get_rank() == 0:
assert len(output[0]) == max_output_len, f"{len(output)}, {max_output_len}"
@parameterize("tp_size", [1])
@parameterize("pp_size", [2])
@parameterize("max_output_len", [4])
@parameterize("micro_batch_size", [1])
@clear_cache_before_run()
def run_pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size):
pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size)
torch.cuda.empty_cache()
@parameterize("tp_size", [2])
@parameterize("pp_size", [2])
@parameterize("max_output_len", [4])
@parameterize("micro_batch_size", [1])
@clear_cache_before_run()
def run_tp_pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size):
pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size)
torch.cuda.empty_cache()
@parameterize("tp_size", [2])
@parameterize("pp_size", [1])
@parameterize("max_output_len", [2])
@parameterize("micro_batch_size", [1])
@clear_cache_before_run()
def run_tp_inference_test(tp_size, pp_size, max_output_len, micro_batch_size):
pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size)
torch.cuda.empty_cache()
@parameterize("tp_size", [1])
@parameterize("pp_size", [1])
@parameterize("max_output_len", [2])
@parameterize("micro_batch_size", [1])
@clear_cache_before_run()
def run_single_inference_test(tp_size, pp_size, max_output_len, micro_batch_size):
pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size)
torch.cuda.empty_cache()
def check_tp_pp_inference(rank, world_size, port):
colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
run_tp_pipeline_inference_test()
def check_tp_or_pp_inference(rank, world_size, port):
colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
run_tp_inference_test()
run_pipeline_inference_test()
def check_single_inference(rank, world_size, port):
colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
run_single_inference_test()
@pytest.mark.skipif(
not CUDA_SUPPORT or not HAS_LIGHTLLM_KERNEL,
reason="requires CUDA version higher than 11.5 and the lightllm kernel",
)
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_pipeline_inference():
spawn(check_tp_pp_inference, nprocs=4)
spawn(check_tp_or_pp_inference, nprocs=2)
spawn(check_single_inference, nprocs=1)
if __name__ == "__main__":
test_pipeline_inference()

@@ -1,126 +0,0 @@
import importlib.util
import pytest
import torch
import torch.distributed as dist
import transformers
from packaging import version
import colossalai
from colossalai.inference import InferenceEngine
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.5")
HAS_LIGHTLLM_KERNEL = True
if importlib.util.find_spec("lightllm") is None:
HAS_LIGHTLLM_KERNEL = False
def data_gen():
input_ids = torch.tensor([[15496, 11, 616, 3290, 318, 13779, 318, 13779]], dtype=torch.int64)
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64)
return dict(input_ids=input_ids, attention_mask=attention_mask)
inputs = data_gen()
for k, v in inputs.items():
if torch.is_tensor(v) or "Tensor" in v.__class__.__name__:
new_shape = [1] * v.dim()
new_shape[0] = 16
inputs[k] = v.to("cuda").repeat(*new_shape)
def pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size):
model = transformers.LlamaForCausalLM(
transformers.LlamaConfig(
vocab_size=20000, hidden_size=512, intermediate_size=1536, num_attention_heads=4, num_hidden_layers=4
)
)
engine = InferenceEngine(
tp_size=tp_size,
pp_size=pp_size,
model=model,
max_output_len=max_output_len,
micro_batch_size=micro_batch_size,
)
output = engine.generate(inputs)
if dist.get_rank() == 0:
assert len(output[0]) == max_output_len, f"{len(output)}, {max_output_len}"
@parameterize("tp_size", [1])
@parameterize("pp_size", [2])
@parameterize("max_output_len", [4])
@parameterize("micro_batch_size", [1])
@clear_cache_before_run()
def run_pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size):
pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size)
torch.cuda.empty_cache()
@parameterize("tp_size", [2])
@parameterize("pp_size", [2])
@parameterize("max_output_len", [4])
@parameterize("micro_batch_size", [1])
@clear_cache_before_run()
def run_tp_pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size):
pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size)
torch.cuda.empty_cache()
@parameterize("tp_size", [2])
@parameterize("pp_size", [1])
@parameterize("max_output_len", [2])
@parameterize("micro_batch_size", [1])
@clear_cache_before_run()
def run_tp_inference_test(tp_size, pp_size, max_output_len, micro_batch_size):
pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size)
torch.cuda.empty_cache()
@parameterize("tp_size", [1])
@parameterize("pp_size", [1])
@parameterize("max_output_len", [2])
@parameterize("micro_batch_size", [1])
@clear_cache_before_run()
def run_single_inference_test(tp_size, pp_size, max_output_len, micro_batch_size):
pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size)
torch.cuda.empty_cache()
def check_tp_pp_inference(rank, world_size, port):
colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
run_tp_pipeline_inference_test()
def check_tp_or_pp_inference(rank, world_size, port):
colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
run_tp_inference_test()
run_pipeline_inference_test()
def check_single_inference(rank, world_size, port):
colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
run_single_inference_test()
@pytest.mark.skipif(
not CUDA_SUPPORT or not HAS_LIGHTLLM_KERNEL,
reason="requires CUDA version higher than 11.5 and the lightllm kernel",
)
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_pipeline_inference():
spawn(check_tp_pp_inference, nprocs=4)
spawn(check_tp_or_pp_inference, nprocs=2)
spawn(check_single_inference, nprocs=1)
if __name__ == "__main__":
test_pipeline_inference()

@@ -0,0 +1,208 @@
import random
import numpy as np
import pytest
import torch
import torch.distributed as dist
from torch.multiprocessing import Manager
from transformers import AutoTokenizer, GenerationConfig, LlamaConfig, LlamaForCausalLM
import colossalai
from colossalai.inference.config import _DEFAULT_PROMPT_TEMPLATES, InferenceConfig
from colossalai.inference.core.engine import InferenceEngine
from colossalai.inference.modeling.models.glide_llama import GlideLlamaConfig, GlideLlamaForCausalLM
from colossalai.inference.modeling.policy import NoPaddingLlamaModelInferPolicy
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
def setup_seed(seed):
torch.manual_seed(seed)
torch.random.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
random.seed(seed)
def check_inference_engine(use_engine=False, prompt_template=None, do_sample=True, policy=None):
setup_seed(20)
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
model = LlamaForCausalLM(
LlamaConfig(
vocab_size=50000,
hidden_size=512,
intermediate_size=1536,
num_attention_heads=4,
num_key_value_heads=2,
num_hidden_layers=16,
)
).cuda()
model = model.eval()
inputs = [
"介绍一下今天的北京,比如故宫,天安门,长城或者其他的一些景点,",
"介绍一下武汉,",
]
output_len = 38
top_p = 0.5
top_k = 50
if use_engine:
inference_config = InferenceConfig(
max_output_len=output_len,
prompt_template=prompt_template,
dtype="fp32",
use_cuda_kernel=True,
tp_size=dist.get_world_size(),
)
inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True, model_policy=policy)
assert inference_engine.generation_config.max_new_tokens == output_len
inference_engine.add_request(prompts=inputs)
assert inference_engine.request_handler._has_waiting()
generation_config = GenerationConfig(
max_new_tokens=output_len, do_sample=do_sample, dtype="fp32", top_p=top_p, top_k=top_k
)
outputs = inference_engine.generate(generation_config=generation_config)
else:
if prompt_template:
# apply prompt template
inputs = [_DEFAULT_PROMPT_TEMPLATES[prompt_template].format(input_text=input_text) for input_text in inputs]
tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id = tokenizer.eos_token_id
inputs = tokenizer.batch_encode_plus(inputs, padding=True, return_tensors="pt")["input_ids"]
inputs = inputs.cuda()
generation_config = GenerationConfig(
do_sample=do_sample,
dtype="fp32",
top_p=top_p,
top_k=top_k,
pad_token_id=tokenizer.pad_token_id,
max_new_tokens=output_len,
)
outputs = model.generate(inputs, generation_config=generation_config)
outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
return outputs
def run_engine(world_size, **kwargs):
manager = Manager()
result_list = manager.list([-1] * world_size) # Create a shared list
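# Each spawned rank stores its generated outputs in the shared list; rank 0's entry is returned for cross-run comparison.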
spawn(run_dist, world_size, func_to_run=check_inference_engine, ret=result_list, **kwargs)
return result_list[0]
def check_spec_dec(num_layers, max_length):
torch.manual_seed(123)
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
# Dummy configs for testing
toy_config = LlamaConfig(num_hidden_layers=num_layers)
toy_config.pad_token_id = tokenizer.eos_token_id
drafter_model = LlamaForCausalLM(toy_config)
drafter_model = drafter_model.eval().cuda()
large_config = LlamaConfig(
hidden_size=4096,
intermediate_size=11008,
num_attention_heads=32,
num_hidden_layers=8,
num_key_value_heads=32,
max_position_embeddings=2048,
)
large_config.pad_token_id = tokenizer.eos_token_id
main_model = LlamaForCausalLM(large_config)
inference_config = InferenceConfig(
dtype="fp16",
micro_batch_size=1,
max_batch_size=1,
max_input_len=128,
max_output_len=128,
prefill_ratio=1.2,
block_size=16,
)
engine = InferenceEngine(main_model, tokenizer, inference_config)
engine.enable_spec_dec(drafter_model, n_spec_tokens=5)
dummy_inputs = torch.randint(low=5, high=1000, size=(1, 10), dtype=torch.long, device="cuda")
generation_config = GenerationConfig(
pad_token_id=tokenizer.eos_token_id,
max_length=max_length,
eos_token_id=tokenizer.eos_token_id,
)
out, out_token_ids = engine.generate(
prompts_token_ids=dummy_inputs, generation_config=generation_config, return_token_ids=True
)
engine.disable_spec_dec()
engine.clear_spec_dec()
assert not engine.use_spec_dec
assert engine.drafter is None and engine.drafter_model is None
max_new_tokens = max_length - dummy_inputs.size(1)
assert len(out) == 1
assert len(out_token_ids) == 1 and len(out_token_ids[0]) == max_new_tokens
# test GLIDE model
glide_config = GlideLlamaConfig(
intermediate_size=8192,
large_hidden_size=4096,
large_num_attention_heads=32,
num_hidden_layers=num_layers,
)
glide_model = GlideLlamaForCausalLM(glide_config)
engine.enable_spec_dec(glide_model, use_glide_drafter=True)
out, out_token_ids = engine.generate(
prompts_token_ids=dummy_inputs, generation_config=generation_config, return_token_ids=True
)
engine.clear_spec_dec()
assert len(out) == 1
assert len(out_token_ids) == 1 and len(out_token_ids[0]) == max_new_tokens
def run_dist(rank, world_size, port, func_to_run, ret=None, **kwargs):
colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost")
if ret:
ret[rank] = func_to_run(**kwargs)
else:
func_to_run(**kwargs)
@pytest.mark.largedist
@parameterize("prompt_template", [None, "llama"])
@parameterize("do_sample", [False])
@rerun_if_address_is_in_use()
def test_tp_engine(prompt_template, do_sample):
kwargs1 = {
"use_engine": True,
"prompt_template": prompt_template,
"do_sample": do_sample,
"policy": NoPaddingLlamaModelInferPolicy(),
}
kwargs2 = {"use_engine": False, "prompt_template": prompt_template, "do_sample": do_sample, "policy": None}
colossal_tp_1_output = run_engine(1, **kwargs1)
colossal_tp_2_output = run_engine(2, **kwargs1)
transformer_tp_1_output = run_engine(1, **kwargs2)
for s1, s2, s3 in zip(colossal_tp_1_output, colossal_tp_2_output, transformer_tp_1_output):
assert s1 == s3, f"\nColossalAI TP=1 Output: {s1}\nTransformers Output: {s3}"
assert s1 == s2, f"\nColossalAI TP=1 Output: {s1}\nColossalAI TP=2 Output: {s2}"
@pytest.mark.largedist
@parameterize("num_layers", [1])
@parameterize("max_length", [64])
@rerun_if_address_is_in_use()
def test_spec_dec(num_layers, max_length):
spawn(run_dist, 1, func_to_run=check_spec_dec, num_layers=num_layers, max_length=max_length)
if __name__ == "__main__":
test_tp_engine()
test_spec_dec()

@@ -0,0 +1,57 @@
import random
import pytest
import torch
from colossalai.kernel.kernel_loader import InferenceOpsLoader
from colossalai.utils import get_current_device
inference_ops = InferenceOpsLoader().load()
DTYPES = [torch.half, torch.bfloat16, torch.float]
NUM_TOKENS = [42] # Arbitrary values for testing
NUM_LAYERS = [1] # Arbitrary values for testing
NUM_HEADS = [8] # Arbitrary values for testing
HEAD_SIZES = [64, 80, 96, 112, 128, 256]
BLOCK_SIZES = [8, 16, 32]
@pytest.mark.skipif(True, reason="FP8 conversion still needs improvement, so its test is skipped for now.")
@pytest.mark.parametrize("num_heads", [8])
@pytest.mark.parametrize("head_size", [64, 80, 96, 112, 128, 256])
@pytest.mark.parametrize("block_size", [8, 16, 32])
@pytest.mark.parametrize("num_blocks", [1024, 10000])
@pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16, torch.float])
@pytest.mark.parametrize("seed", [0])
@torch.inference_mode()
def test_fp8_conversion(
num_heads: int,
head_size: int,
block_size: int,
num_blocks: int,
dtype: torch.dtype,
seed: int,
) -> None:
random.seed(seed)
torch.random.manual_seed(seed)
torch.cuda.manual_seed(seed)
device = get_current_device()
low = -224.0
high = 224.0
shape = (num_blocks, num_heads, head_size, block_size)
cache = torch.empty(shape, dtype=dtype, device=device)
cache.uniform_(low, high)
cache_fp8 = torch.empty_like(cache, dtype=torch.uint8)
inference_ops.convert_fp8(cache, cache_fp8)
converted_cache = torch.empty_like(cache)
inference_ops.convert_fp8(cache_fp8, converted_cache)
assert torch.allclose(cache, converted_cache, atol=0.001, rtol=0.1)
if __name__ == "__main__":
test_fp8_conversion(8, 64, 8, 1024, torch.half, 0)

@@ -0,0 +1,334 @@
from itertools import product
import numpy as np
import pytest
import torch
from colossalai.inference.modeling.models.nopadding_baichuan import get_alibi_slopes
from colossalai.kernel.kernel_loader import InferenceOpsLoader
from colossalai.utils import get_current_device
from tests.test_infer.test_kernels.triton.test_context_attn_unpad import generate_alibi_mask
inference_ops = InferenceOpsLoader().load()
from tests.test_infer.test_kernels.triton.kernel_utils import (
convert_kv_unpad_to_padded,
create_attention_mask,
generate_caches_and_block_tables_v3,
generate_caches_and_block_tables_vllm,
torch_attn_ref,
)
q_len = 1
PARTITION_SIZE = 512
def prepare_data(
BATCH_SIZE: int,
HEAD_SIZE: int,
NUM_ATTN_HEADS: int,
NUM_KV_HEADS: int,
MAX_SEQ_LEN: int,
dtype=torch.float16,
device="cuda",
):
# Use the provided maximum sequence length for each sequence when testing with the same context length;
# otherwise, generate random context lengths.
# returns
# q [BATCH_SIZE, NUM_ATTN_HEADS, HEAD_SIZE]
# k_unpad/v_unpad [num_tokens, NUM_KV_HEADS, HEAD_SIZE]
kv_lengths = torch.randint(low=1, high=MAX_SEQ_LEN, size=(BATCH_SIZE,), dtype=torch.int32, device=device)
num_tokens = torch.sum(kv_lengths).item()
q_size = (BATCH_SIZE, q_len, NUM_ATTN_HEADS, HEAD_SIZE)
q = torch.empty(size=q_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5).transpose(1, 2)
kv_size = (num_tokens, 2 * NUM_KV_HEADS, HEAD_SIZE)
kv_unpad = torch.empty(size=kv_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5)
k_unpad, v_unpad = torch.split(kv_unpad, [NUM_KV_HEADS, NUM_KV_HEADS], dim=-2)
return q, k_unpad, v_unpad, kv_lengths
def numpy_allclose(x, y, rtol, atol):
x_numpy = x.detach().cpu().numpy()
y_numpy = y.detach().cpu().numpy()
np.testing.assert_allclose(x_numpy, y_numpy, rtol=rtol, atol=atol)
@pytest.mark.parametrize("BATCH_SIZE", [1, 4, 7, 32])
@pytest.mark.parametrize("BLOCK_SIZE", [8, 16, 32])
@pytest.mark.parametrize("MAX_NUM_BLOCKS_PER_SEQ", [1, 8, 32, 256, 512])
@pytest.mark.parametrize("HEAD_SIZE", [64, 128])
@pytest.mark.parametrize("NUM_ATTN_HEADS", [16])
@pytest.mark.parametrize("KV_GROUP_NUM", [1, 2, 16])
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
@pytest.mark.parametrize("use_alibi_slopes", [True, False])
def test_flash_decoding_attention(
BATCH_SIZE, BLOCK_SIZE, MAX_NUM_BLOCKS_PER_SEQ, HEAD_SIZE, NUM_ATTN_HEADS, KV_GROUP_NUM, dtype, use_alibi_slopes
):
torch.manual_seed(123)
torch.cuda.empty_cache()
torch.cuda.synchronize()
torch.cuda.reset_peak_memory_stats()
NUM_KV_HEADS = NUM_ATTN_HEADS // KV_GROUP_NUM
assert isinstance(NUM_KV_HEADS, int) and NUM_KV_HEADS > 0, "Invalid number of kv heads."
MAX_SEQ_LEN = BLOCK_SIZE * MAX_NUM_BLOCKS_PER_SEQ
device = get_current_device()
try:
if use_alibi_slopes:
alibi_slopes = get_alibi_slopes(NUM_ATTN_HEADS, device)
else:
alibi_slopes = None
q, k_unpad, v_unpad, kv_seq_lengths = prepare_data(
BATCH_SIZE, HEAD_SIZE, NUM_ATTN_HEADS, NUM_KV_HEADS, MAX_SEQ_LEN, dtype, device
)
k_cache, v_cache, block_tables = generate_caches_and_block_tables_v3(
k_unpad, v_unpad, kv_seq_lengths, BATCH_SIZE, MAX_NUM_BLOCKS_PER_SEQ, BLOCK_SIZE, dtype, device
)
block_tables = block_tables.to(device=device)
max_seq_len_across_batch = kv_seq_lengths.max().item()
kv_max_split_num = (max_seq_len_across_batch + BLOCK_SIZE - 1) // BLOCK_SIZE
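# kv_max_split_num is the ceiling of the longest KV length divided by the block size, i.e. the number of partitions the flash-decoding kernel may split a sequence into.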
output = torch.empty((BATCH_SIZE, NUM_ATTN_HEADS, HEAD_SIZE), dtype=dtype, device=device)
sm_scale = 1.0 / (HEAD_SIZE**0.5)
k_torch = convert_kv_unpad_to_padded(k_unpad, kv_seq_lengths, BATCH_SIZE, max_seq_len_across_batch)
v_torch = convert_kv_unpad_to_padded(v_unpad, kv_seq_lengths, BATCH_SIZE, max_seq_len_across_batch)
torch_padding_mask = create_attention_mask(kv_seq_lengths, BATCH_SIZE, q_len, max_seq_len_across_batch, device)
if use_alibi_slopes:
alibi_mask = generate_alibi_mask(alibi_slopes, NUM_ATTN_HEADS, max_seq_len_across_batch, device)
torch_padding_mask = torch_padding_mask + alibi_mask
if len(torch_padding_mask.size()) == 4:
torch_padding_mask = torch_padding_mask[:, :, -1:, :]
else:
torch_padding_mask = torch_padding_mask[:, -1:, :]
mid_output = torch.empty(
size=(BATCH_SIZE, NUM_ATTN_HEADS, kv_max_split_num, HEAD_SIZE), dtype=torch.float32, device=device
)
exp_sums = torch.empty(size=(BATCH_SIZE, NUM_ATTN_HEADS, kv_max_split_num), dtype=torch.float32, device=device)
max_logits = torch.empty(
size=(BATCH_SIZE, NUM_ATTN_HEADS, kv_max_split_num), dtype=torch.float32, device=device
)
if dtype == torch.float16:
rtol = 1e-3
atol = 1e-3
high_precision_q = q.to(torch.float32)
high_precision_k_torch = k_torch.to(torch.float32)
high_precision_v_torch = v_torch.to(torch.float32)
out_ref = torch_attn_ref(
high_precision_q,
high_precision_k_torch,
high_precision_v_torch,
torch_padding_mask,
BATCH_SIZE,
q_len,
max_seq_len_across_batch,
NUM_ATTN_HEADS,
NUM_KV_HEADS,
HEAD_SIZE,
).to(torch.float16)
else:
rtol = 1e-5
atol = 1e-7
out_ref = torch_attn_ref(
q,
k_torch,
v_torch,
torch_padding_mask,
BATCH_SIZE,
q_len,
max_seq_len_across_batch,
NUM_ATTN_HEADS,
NUM_KV_HEADS,
HEAD_SIZE,
)
except torch.cuda.OutOfMemoryError:
pytest.skip("Required GPU memory is larger than capacity.")
inference_ops.flash_decoding_attention(
output,
q.squeeze(2),
k_cache,
v_cache,
kv_seq_lengths,
block_tables,
BLOCK_SIZE,
max_seq_len_across_batch,
mid_output,
exp_sums,
max_logits,
alibi_slopes,
sm_scale,
)
# The alibi may introduce relatively large errors
if use_alibi_slopes:
rtol = 1e0
try:
numpy_allclose(out_ref, output, rtol=rtol, atol=atol)
except AssertionError:
if MAX_NUM_BLOCKS_PER_SEQ >= 256:
pytest.skip("Long sequence length introduce precision error.")
else:
raise
try:
from vllm._C import ops as vllm_ops # noqa
HAS_VLLM = True
except ImportError:
HAS_VLLM = False
print("The subsequent test requires vllm. Please refer to https://github.com/vllm-project/vllm")
@pytest.mark.skipif(not HAS_VLLM, reason="requires vllm")
@pytest.mark.parametrize("BATCH_SIZE", [1, 4, 7, 32])
@pytest.mark.parametrize("BLOCK_SIZE", [8, 16, 32])
@pytest.mark.parametrize("MAX_NUM_BLOCKS_PER_SEQ", [1, 8, 32])
@pytest.mark.parametrize("HEAD_SIZE", [64, 128])
@pytest.mark.parametrize("NUM_ATTN_HEADS", [16])
@pytest.mark.parametrize("KV_GROUP_NUM", [1, 2, 16])
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
@pytest.mark.parametrize("use_alibi_slopes", [True, False])
def test_vllm_flash_decoding_attention(
BATCH_SIZE, BLOCK_SIZE, MAX_NUM_BLOCKS_PER_SEQ, HEAD_SIZE, NUM_ATTN_HEADS, KV_GROUP_NUM, dtype, use_alibi_slopes
):
torch.manual_seed(123)
torch.cuda.empty_cache()
torch.cuda.synchronize()
torch.cuda.reset_peak_memory_stats()
NUM_KV_HEADS = NUM_ATTN_HEADS // KV_GROUP_NUM
assert isinstance(NUM_KV_HEADS, int) and NUM_KV_HEADS > 0, "Invalid number of kv heads."
MAX_SEQ_LEN = BLOCK_SIZE * MAX_NUM_BLOCKS_PER_SEQ
device = get_current_device()
q, k_unpad, v_unpad, kv_seq_lengths = prepare_data(
BATCH_SIZE, HEAD_SIZE, NUM_ATTN_HEADS, NUM_KV_HEADS, MAX_SEQ_LEN, dtype, device
)
k_cache, v_cache, block_tables = generate_caches_and_block_tables_vllm(
k_unpad, v_unpad, kv_seq_lengths, BATCH_SIZE, MAX_NUM_BLOCKS_PER_SEQ, BLOCK_SIZE, dtype, device
)
block_tables = block_tables.to(device=device)
max_seq_len_across_batch = kv_seq_lengths.max().item()
output = torch.empty((BATCH_SIZE, NUM_ATTN_HEADS, HEAD_SIZE), dtype=dtype, device=device)
sm_scale = 1.0 / (HEAD_SIZE**0.5)
kv_scale = 1.0
k_torch = convert_kv_unpad_to_padded(k_unpad, kv_seq_lengths, BATCH_SIZE, max_seq_len_across_batch)
v_torch = convert_kv_unpad_to_padded(v_unpad, kv_seq_lengths, BATCH_SIZE, max_seq_len_across_batch)
torch_padding_mask = create_attention_mask(kv_seq_lengths, BATCH_SIZE, q_len, max_seq_len_across_batch, device)
if use_alibi_slopes:
alibi_slopes = get_alibi_slopes(NUM_ATTN_HEADS, device)
alibi_mask = generate_alibi_mask(alibi_slopes, NUM_ATTN_HEADS, max_seq_len_across_batch, device)
torch_padding_mask = torch_padding_mask + alibi_mask
if len(torch_padding_mask.size()) == 4:
torch_padding_mask = torch_padding_mask[:, :, -1:, :]
else:
torch_padding_mask = torch_padding_mask[:, -1:, :]
else:
alibi_slopes = None
if dtype == torch.float16:
rtol = 1e-3
atol = 1e-3
high_precision_q = q.to(torch.float32)
high_precision_k_torch = k_torch.to(torch.float32)
high_precision_v_torch = v_torch.to(torch.float32)
out_ref = torch_attn_ref(
high_precision_q,
high_precision_k_torch,
high_precision_v_torch,
torch_padding_mask,
BATCH_SIZE,
q_len,
max_seq_len_across_batch,
NUM_ATTN_HEADS,
NUM_KV_HEADS,
HEAD_SIZE,
).to(torch.float16)
else:
rtol = 1e-5
atol = 1e-7
out_ref = torch_attn_ref(
q,
k_torch,
v_torch,
torch_padding_mask,
BATCH_SIZE,
q_len,
max_seq_len_across_batch,
NUM_ATTN_HEADS,
NUM_KV_HEADS,
HEAD_SIZE,
)
vllm_ops.paged_attention_v1(
output,
q.squeeze(2),
k_cache,
v_cache,
NUM_KV_HEADS,
sm_scale,
block_tables,
kv_seq_lengths,
BLOCK_SIZE,
max_seq_len_across_batch,
alibi_slopes,
"auto",
kv_scale,
)
# The alibi may introduce relatively large errors
if use_alibi_slopes:
rtol = 1e0
numpy_allclose(out_ref, output, rtol=rtol, atol=atol)
if __name__ == "__main__":
BATCH_SIZE = [1, 4, 7, 32]
BLOCK_SIZE = [8, 16, 32]
MAX_NUM_BLOCKS_PER_SEQ = [1, 8, 32]
HEAD_SIZE = [64, 128]
NUM_ATTN_HEADS = [16]
KV_GROUP_NUM = [1, 2, 16]
DTYPE = [torch.float16, torch.float32]
test_combinations = list(
product(BATCH_SIZE, BLOCK_SIZE, MAX_NUM_BLOCKS_PER_SEQ, HEAD_SIZE, NUM_ATTN_HEADS, KV_GROUP_NUM, DTYPE)
)
for (
batch_size,
block_size,
max_num_blocks_per_seq,
head_size,
num_attn_heads,
kv_group_num,
dtype,
) in test_combinations:
test_flash_decoding_attention(
batch_size, block_size, max_num_blocks_per_seq, head_size, num_attn_heads, kv_group_num, dtype, True
)

@@ -0,0 +1,53 @@
import numpy as np
import pytest
import torch
from colossalai.kernel.kernel_loader import InferenceOpsLoader
from tests.test_infer.test_kernels.triton.test_xine_copy import get_cos_sin
inference_ops = InferenceOpsLoader().load()
def numpy_equal(x, y):
x_numpy = x.detach().cpu().numpy()
y_numpy = y.detach().cpu().numpy()
np.testing.assert_equal(x_numpy, y_numpy)
@pytest.mark.parametrize("BATCH_SIZE", [4])
@pytest.mark.parametrize("MAX_SEQ_LEN", [64])
@pytest.mark.parametrize("HEAD_DIM", [64])
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
def test_get_cos_and_sin(BATCH_SIZE, MAX_SEQ_LEN, HEAD_DIM, dtype):
MAX_TOTAL_TOKENS = BATCH_SIZE * MAX_SEQ_LEN
cos_cache = torch.randn((MAX_TOTAL_TOKENS, HEAD_DIM), dtype=dtype, device="cuda")
sin_cache = torch.randn((MAX_TOTAL_TOKENS, HEAD_DIM), dtype=dtype, device="cuda")
lengths = torch.randint(2, MAX_SEQ_LEN, (BATCH_SIZE,), device="cuda").to(torch.int32)
max_seq_len_in_batch = lengths.max()
# prefill
cos_ref, sin_ref = get_cos_sin(lengths, cos_cache, sin_cache, is_prompts=True, dtype=dtype)
cos = torch.zeros_like(cos_ref)
sin = torch.zeros_like(sin_ref)
inference_ops.get_cos_and_sin(cos_cache, sin_cache, cos, sin, lengths, max_seq_len_in_batch, True)
numpy_equal(cos, cos_ref)
numpy_equal(sin, sin_ref)
# decoding
ncos_ref, nsin_ref = get_cos_sin(lengths, cos_cache, sin_cache, is_prompts=False, dtype=dtype)
cos = torch.zeros_like(ncos_ref)
sin = torch.zeros_like(nsin_ref)
inference_ops.get_cos_and_sin(cos_cache, sin_cache, cos, sin, lengths, max_seq_len_in_batch, False)
numpy_equal(cos, ncos_ref)
numpy_equal(sin, nsin_ref)
if __name__ == "__main__":
test_get_cos_and_sin(16, 4096, 256, torch.float16)

@@ -0,0 +1,157 @@
import pytest
import torch
import torch.nn.functional as F
from colossalai.kernel.kernel_loader import InferenceOpsLoader
from colossalai.utils import get_current_device
from tests.test_infer.test_kernels.triton.kernel_utils import (
generate_caches_and_block_tables_v3,
mock_alloc_single_token,
)
inference_ops = InferenceOpsLoader().load()
HEAD_DIM = 72
def prepare_data(
bsz,
num_kv_heads,
block_size,
max_num_blocks_per_seq,
context_lengths,
device="cuda",
dtype=torch.float16,
):
num_tokens = torch.sum(context_lengths).item()
max_seq_len_in_batch = context_lengths.max()
cu_seqlens = F.pad(torch.cumsum(context_lengths, dim=0, dtype=torch.int32), (1, 0))
kv_size = (num_tokens, num_kv_heads, HEAD_DIM)
key = torch.empty(size=kv_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5)
value = torch.empty(size=kv_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5)
k_cache_ref, v_cache_ref, block_tables = generate_caches_and_block_tables_v3(
key, value, context_lengths, bsz, max_num_blocks_per_seq, block_size, dtype, device
)
block_tables = block_tables.to(device=device)
k_cache = torch.zeros_like(k_cache_ref)
v_cache = torch.zeros_like(v_cache_ref)
return key, value, k_cache, v_cache, cu_seqlens, block_tables, max_seq_len_in_batch, k_cache_ref, v_cache_ref
def run_decode_copy_kv_to_caches(
bsz: int,
block_size: int,
max_num_blocks_per_seq: int,
num_kv_heads: int,
same_context_len: bool,
):
torch.manual_seed(123)
torch.cuda.empty_cache()
torch.cuda.synchronize()
torch.cuda.reset_peak_memory_stats()
n = 1
max_seq_len = block_size * max_num_blocks_per_seq
dtype = torch.float32
device = get_current_device()
assert max_seq_len > n, "max_seq_len must be greater than n"
past_kv_seq_lengths = (
torch.tensor([max_seq_len - n for _ in range(bsz)], dtype=torch.int32, device=device)
if same_context_len
else torch.randint(low=1, high=max_seq_len - n, size=(bsz,), dtype=torch.int32, device=device)
)
key, value, k_cache, v_cache, _, block_tables, _, _, _ = prepare_data(
bsz, num_kv_heads, block_size, max_num_blocks_per_seq, past_kv_seq_lengths, device, dtype
)
new_k = torch.randn((bsz, num_kv_heads, HEAD_DIM), dtype=dtype, device=device)
new_v = torch.randn((bsz, num_kv_heads, HEAD_DIM), dtype=dtype, device=device)
# mock allocating blocks for the new k/v and update block tables
for _ in range(n):
mock_alloc_single_token(block_tables, past_kv_seq_lengths, block_size)
past_kv_seq_lengths += 1
inference_ops.decode_kv_cache_memcpy(new_k, new_v, k_cache, v_cache, past_kv_seq_lengths, block_tables)
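# Verify the copy: find the block id and in-block offset where the newly appended token's KV should land, then compare against new_k / new_v.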
past_kv_seq_len = past_kv_seq_lengths - 1
target_block_ids = block_tables[range(0, block_tables.size(0)), past_kv_seq_len // block_size]
offsets_in_block = past_kv_seq_len % block_size
k_target = k_cache[target_block_ids, :, :, offsets_in_block, :]
k_source = new_k.squeeze()
v_target = v_cache[target_block_ids, :, offsets_in_block, :]
k_target = k_target.reshape(v_target.shape)
v_source = new_v.squeeze()
assert k_target.shape == k_source.shape
assert torch.equal(k_target, k_source)
assert v_target.shape == v_source.shape
assert torch.equal(v_target, v_source)
def run_context_copy_kv_to_cache(
bsz: int,
block_size: int,
max_num_blocks_per_seq: int,
num_kv_heads: int,
same_context_len: bool,
):
torch.manual_seed(123)
assert isinstance(num_kv_heads, int) and num_kv_heads > 0, "Invalid number of kv heads."
max_seq_len = max_num_blocks_per_seq * block_size
dtype = torch.float16
device = get_current_device()
if same_context_len:
context_lengths = torch.tensor([max_seq_len for _ in range(bsz)], dtype=torch.int32, device=device)
else:
context_lengths = torch.randint(low=1, high=max_seq_len, size=(bsz,), dtype=torch.int32, device=device)
(
key,
value,
k_cache,
v_cache,
cu_seqlens,
block_tables,
max_seq_len_in_batch,
k_cache_ref,
v_cache_ref,
) = prepare_data(bsz, num_kv_heads, block_size, max_num_blocks_per_seq, context_lengths, device, dtype)
inference_ops.context_kv_cache_memcpy(
key, value, k_cache, v_cache, context_lengths, cu_seqlens, block_tables, max_seq_len_in_batch
)
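    # The kernel writes the full prompt K/V into the blocked caches; the result must match the reference caches from prepare_data.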
assert torch.equal(k_cache, k_cache_ref)
assert torch.equal(v_cache, v_cache_ref)
@pytest.mark.parametrize("bsz", [4, 7, 32])
@pytest.mark.parametrize("block_size", [16, 32, 64])
@pytest.mark.parametrize("max_num_blocks_per_seq", [8, 32])
@pytest.mark.parametrize("num_kv_heads", [16])
@pytest.mark.parametrize("same_context_len", [True, False])
def test_kv_cache_memcopy(
bsz: int,
block_size: int,
max_num_blocks_per_seq: int,
num_kv_heads: int,
same_context_len: bool,
):
run_context_copy_kv_to_cache(bsz, block_size, max_num_blocks_per_seq, num_kv_heads, same_context_len)
run_decode_copy_kv_to_caches(bsz, block_size, max_num_blocks_per_seq, num_kv_heads, same_context_len)
if __name__ == "__main__":
test_kv_cache_memcopy(4, 32, 8, 16, True)

View File

@@ -0,0 +1,51 @@
import pytest
import torch
from transformers.models.llama.modeling_llama import LlamaRMSNorm
from colossalai.kernel.kernel_loader import InferenceOpsLoader
from colossalai.utils import get_current_device
inference_ops = InferenceOpsLoader().load()
@pytest.mark.parametrize("M", [2, 4, 8, 16])
@pytest.mark.parametrize("N", [64, 128, 512, 5120])
def test_rms_layernorm(M: int, N: int):
torch.manual_seed(123)
torch.cuda.empty_cache()
torch.cuda.synchronize()
torch.cuda.reset_peak_memory_stats()
device = get_current_device()
dtype = torch.float16
eps = 1e-5
x_shape = (M, N)
w_shape = (x_shape[-1],)
weight = torch.ones(w_shape, dtype=dtype, device=device)
residual = torch.rand(x_shape, dtype=dtype, device=device)
residual_copy = residual.clone()
rms_norm = LlamaRMSNorm(hidden_size=N, eps=eps).cuda()
x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device="cuda")
x_copy = x.clone()
y_cuda = torch.empty_like(x)
inference_ops.rms_layernorm(y_cuda, x, weight, eps)
y_llama = rms_norm.forward(x).to(dtype)
assert y_cuda.shape == y_llama.shape
assert torch.allclose(y_cuda, y_llama, atol=1e-5, rtol=1e-3)
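    # The fused kernel updates its inputs in place: x holds the normalized output and residual holds x + residual afterwards.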
inference_ops.fused_add_rms_layernorm(x, residual, weight, eps)
y_cuda = x
x = x_copy + residual_copy
y_llama = rms_norm.forward(x).to(dtype)
assert y_cuda.shape == y_llama.shape
assert torch.allclose(y_cuda, y_llama, atol=1e-5, rtol=1e-3)
assert torch.allclose(x, residual, atol=1e-5, rtol=1e-3)
if __name__ == "__main__":
test_rms_layernorm(16, 5120)

View File

@@ -0,0 +1,130 @@
import numpy as np
import pytest
import torch
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb
from colossalai.kernel.kernel_loader import InferenceOpsLoader
inference_ops = InferenceOpsLoader().load()
from tests.test_infer.test_kernels.triton.kernel_utils import mock_alloc_block_table_and_kvcache_v3
from tests.test_infer.test_kernels.triton.test_rotary_embdding_unpad import torch_rotary_emb
def numpy_allclose(x, y, rtol, atol):
x_numpy = x.detach().cpu().numpy()
y_numpy = y.detach().cpu().numpy()
np.testing.assert_allclose(x_numpy, y_numpy, rtol=rtol, atol=atol)
@pytest.mark.parametrize("BATCH_SIZE", [4])
@pytest.mark.parametrize("SEQ_LEN", [64])
@pytest.mark.parametrize("H", [32])
@pytest.mark.parametrize("K_H", [16, 32])
@pytest.mark.parametrize("D", [64])
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
def test_rotary_emb(BATCH_SIZE, SEQ_LEN, H, K_H, D, dtype):
torch.manual_seed(10)
TOTAL_TOKENS = BATCH_SIZE * SEQ_LEN
    # sanity check: the crafted torch_rotary_emb reference should match the Transformers implementation
x0 = torch.randn(TOTAL_TOKENS, SEQ_LEN, D, dtype=dtype)
x1 = torch.randn(TOTAL_TOKENS, SEQ_LEN, D, dtype=dtype)
emb = LlamaRotaryEmbedding(D)
cos, sin = emb(x0, TOTAL_TOKENS)
cos_2 = cos[:, : D // 2]
sin_2 = sin[:, : D // 2]
position_ids = torch.arange(TOTAL_TOKENS)
embd_x0, _ = apply_rotary_pos_emb(x0, x1, cos, sin, position_ids)
embd_stimulated_x = torch_rotary_emb(x0, cos_2, sin_2)
assert torch.allclose(embd_x0, embd_stimulated_x)
# create data
block_size = 32
max_blocks_per_sequence = (TOTAL_TOKENS + block_size - 1) // block_size
q_shape = (TOTAL_TOKENS, H, D)
q = -2.3 + 0.5 * torch.randn(q_shape, dtype=dtype, device="cuda")
k_shape = (TOTAL_TOKENS, K_H, D)
k = -2.3 + 0.5 * torch.randn(k_shape, dtype=dtype, device="cuda")
cos_shape = (TOTAL_TOKENS, D // 2)
cos = -1.2 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
sin = -2.0 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
x = 16 // torch.tensor([], dtype=dtype).element_size()
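    # Pack the head dim so that each innermost slice is 16 bytes wide:
    # the new K-cache layout is [num_blocks, num_kv_heads, head_dim // x, block_size, x].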
k_cache_shape = (BATCH_SIZE * max_blocks_per_sequence, K_H, D // x, block_size, x)
v_cache_shape = (BATCH_SIZE * max_blocks_per_sequence, K_H, block_size, D)
k_cache = torch.zeros(size=k_cache_shape, dtype=dtype, device="cuda")
v = torch.randn_like(k)
v_cache = torch.zeros(size=v_cache_shape, dtype=dtype, device="cuda")
past_kv_seq_lengths = torch.tensor([SEQ_LEN - 1 for _ in range(BATCH_SIZE)], dtype=torch.int32, device="cuda")
block_tables = mock_alloc_block_table_and_kvcache_v3(
k, v, k_cache, v_cache, past_kv_seq_lengths, BATCH_SIZE, max_blocks_per_sequence, block_size
)
new_k = torch.randn((BATCH_SIZE, K_H, D), dtype=dtype, device="cuda")
new_q = torch.randn((BATCH_SIZE, H, D), dtype=dtype, device="cuda")
new_v = torch.randn_like(new_k)
kv_seq_lengths = past_kv_seq_lengths + 1
block_tables = block_tables.to(device="cuda")
new_q_copy = new_q.clone()
new_k_copy = new_k.clone()
if dtype == torch.float16:
rtol = 1e-3
atol = 1e-3
new_q_fp16 = new_q.clone()
new_k_fp16 = new_k.clone()
high_precision_cos = cos[:BATCH_SIZE].to(torch.float32)
high_precision_sin = sin[:BATCH_SIZE].to(torch.float32)
high_precision_q = new_q.to(torch.float32)
high_precision_k = new_k.to(torch.float32)
q_ref = torch_rotary_emb(high_precision_q, high_precision_cos, high_precision_sin).to(torch.float16)
k_ref = torch_rotary_emb(high_precision_k, high_precision_cos, high_precision_sin).to(torch.float16)
else:
rtol = 1e-5
atol = 1e-7
q_ref = torch_rotary_emb(new_q, cos[:BATCH_SIZE], sin[:BATCH_SIZE])
k_ref = torch_rotary_emb(new_k, cos[:BATCH_SIZE], sin[:BATCH_SIZE])
inference_ops.rotary_embedding_and_cache_copy(
new_q, new_k, new_v, cos, sin, k_cache, v_cache, kv_seq_lengths, block_tables, True
)
inference_ops.rotary_embedding(new_q_copy, new_k_copy, cos, sin, True)
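    # Locate where the fused rotary+cache-copy op wrote the new K/V in the blocked caches and read it back for comparison.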
past_kv_seq_len = kv_seq_lengths - 1
target_block_ids = block_tables[range(0, block_tables.size(0)), past_kv_seq_len // block_size]
offsets_in_block = past_kv_seq_len % block_size
k_target = k_cache[target_block_ids, :, :, offsets_in_block, :].squeeze()
k_source = new_k_copy.squeeze()
v_target = v_cache[target_block_ids, :, offsets_in_block, :].squeeze()
k_target = k_target.reshape(v_target.shape)
v_source = new_v.squeeze()
numpy_allclose(new_q, q_ref, rtol=rtol, atol=atol)
numpy_allclose(k_target, k_ref, rtol=rtol, atol=atol)
numpy_allclose(new_q_copy, q_ref, rtol=rtol, atol=atol)
numpy_allclose(new_k_copy, k_ref, rtol=rtol, atol=atol)
assert k_target.shape == k_source.shape
numpy_allclose(k_target, k_source, rtol=rtol, atol=atol)
assert v_target.shape == v_source.shape
assert torch.equal(v_target, v_source)
if dtype == torch.float16:
        # After testing, the CUDA fp16 high-precision path turned out to be more accurate than torch fp16,
        # so the tolerance here is relaxed for the test to pass.
rtol = 1e-3
atol = 1e-1
inference_ops.rotary_embedding(new_q_fp16, new_k_fp16, cos, sin, False)
numpy_allclose(new_q_copy, new_q_fp16, rtol=rtol, atol=atol)
numpy_allclose(new_k_copy, new_k_fp16, rtol=rtol, atol=atol)
if __name__ == "__main__":
test_rotary_emb(16, 64, 32, 16, 128, torch.float16)

View File

@@ -0,0 +1,33 @@
import pytest
import torch
from colossalai.kernel.kernel_loader import InferenceOpsLoader
from colossalai.utils import get_current_device
inference_ops = InferenceOpsLoader().load()
@pytest.mark.parametrize("SHAPE_X", [2])
@pytest.mark.parametrize("SHAPE_Y", [64])
@pytest.mark.parametrize("SHAPE_Z", [11008])
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16])
def test_silu_and_mul(SHAPE_X, SHAPE_Y, SHAPE_Z, dtype):
torch.manual_seed(5)
device = get_current_device()
ref_input = torch.randn(SHAPE_X, SHAPE_Y, SHAPE_Z, dtype=dtype, device=device)
origin_input = ref_input.clone()
act_out = torch.nn.functional.silu(ref_input[0], inplace=True)
ref_out = act_out * ref_input[1]
origin_out = inference_ops.silu_and_mul(origin_input)
if dtype == torch.float32:
assert torch.allclose(origin_out, ref_out, atol=1e-5, rtol=1e-5)
else:
assert torch.allclose(origin_out, ref_out, atol=1e-3, rtol=1e-3)
if __name__ == "__main__":
test_silu_and_mul(2, 64, 11008, torch.float32)
test_silu_and_mul(2, 64, 11008, torch.float16)

View File

@@ -0,0 +1,348 @@
from typing import Tuple
import torch
from torch.nn import functional as F
# This function is adapted from src/transformers/models/llama/modeling_llama.py
# in huggingface transformers repository
# https://github.com/huggingface/transformers/blob/3b7675b2b844b02d4821b827871a21ad16dd446c/src/transformers/models/llama/modeling_llama.py#L273
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep).
The hidden states go from (bsz, num_key_value_heads, seq_len, head_dim) to (bsz, num_attention_heads, seq_len, head_dim)
"""
if n_rep == 1:
return hidden_states
bsz, num_key_value_heads, seq_len, head_dim = hidden_states.shape
hidden_states = hidden_states[:, :, None, :, :].expand(bsz, num_key_value_heads, n_rep, seq_len, head_dim)
return hidden_states.reshape(bsz, num_key_value_heads * n_rep, seq_len, head_dim)
def create_attention_mask(kv_lengths: torch.Tensor, bsz: int, q_len: int, kv_len: int, device="cuda"):
assert q_len <= kv_len
causal_mask = torch.full((q_len, q_len), fill_value=float("-inf"), device=device).triu(diagonal=1)
padding_mask = torch.zeros((bsz, 1, q_len, kv_len), dtype=torch.float32, device=device)
for i in range(bsz):
cur_seq_len = kv_lengths[i].item()
assert cur_seq_len <= kv_len
padding_mask[i, :, :, : kv_len - cur_seq_len] = float("-inf")
padding_mask[:, :, -q_len:, -q_len:] += causal_mask
return padding_mask
# Attention calculation adapted from HuggingFace transformers repository
# src/transformers/models/llama/modeling_llama.py
# https://github.com/huggingface/transformers/blob/633215ba58fe5114d8c8d32e415a04600e010701/src/transformers/models/llama/modeling_llama.py#L350
def torch_attn_ref(
q: torch.Tensor, # [bsz, num_heads, q_len, head_dim]
k: torch.Tensor, # [bsz, num_heads, kv_len, head_dim]
v: torch.Tensor, # [bsz, num_heads, kv_len, head_dim]
attention_mask: torch.Tensor, # [bsz, 1, q_len, kv_len]
bsz: int,
q_len: int,
kv_len: int,
num_heads: int,
num_kv_heads: int,
head_dim: int,
) -> torch.Tensor:
assert q.shape[-1] == k.shape[-1] == v.shape[-1] == head_dim
# repeat kv for GQA and MQA
# k/v won't change if kv_group_num is 1
assert num_heads % num_kv_heads == 0, "Number of heads is not multiple of kv heads"
kv_group_num = num_heads // num_kv_heads
k = repeat_kv(k, kv_group_num)
v = repeat_kv(v, kv_group_num)
qk = torch.matmul(q, k.transpose(2, 3))
attn_scores = qk / (head_dim**0.5)
assert attn_scores.shape == (bsz, num_heads, q_len, kv_len), "Invalid shape of attention scores"
if attention_mask is not None:
attn_scores = attn_scores + attention_mask
attn_weights = F.softmax(attn_scores.to(dtype=torch.float32), dim=-1).to(dtype=q.dtype)
out = torch.matmul(attn_weights, v)
if out.size() != (bsz, num_heads, q_len, head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, num_heads, q_len, head_dim)}, but is" f" {out.size()}"
)
out = out.transpose(1, 2).contiguous()
out = out.view(-1, out.size(-2), out.size(-1))
# out [bsz * q_len, num_heads, head_dim]
return out
def mock_alloc_block_table_and_kvcache(
k: torch.Tensor,
v: torch.Tensor,
k_cache: torch.Tensor,
v_cache: torch.Tensor,
context_lengths: torch.Tensor,
num_seqs: int,
max_num_blocks_per_seq: int,
block_size: int,
) -> torch.Tensor:
"""Allocate block tables based on provided context lengths; and copy KV to blocked KV Cache."""
block_id = 0
block_tables = torch.full(size=(num_seqs, max_num_blocks_per_seq), fill_value=-1, dtype=torch.int32)
num_tokens_processed = 0
for i, seq_len in enumerate(context_lengths.tolist()):
right_bound = (seq_len + block_size - 1) // block_size # open bound
block_tables[i, :right_bound] = torch.arange(block_id, block_id + right_bound, dtype=torch.int32)
# Manually fill kv caches by copying from k and v
for i in range(right_bound):
if i == right_bound - 1:
allocated_locs = seq_len % block_size or block_size
else:
allocated_locs = block_size
k_block = k[num_tokens_processed : num_tokens_processed + allocated_locs, :, :].permute(1, 2, 0)
v_block = v[num_tokens_processed : num_tokens_processed + allocated_locs, :, :].permute(1, 2, 0)
k_cache[block_id, :, :, :allocated_locs] = k_block
v_cache[block_id, :, :, :allocated_locs] = v_block
num_tokens_processed += allocated_locs
block_id += 1
return block_tables
def mock_alloc_block_table_and_kvcache_v2(
k: torch.Tensor,
v: torch.Tensor,
k_cache: torch.Tensor,
v_cache: torch.Tensor,
context_lengths: torch.Tensor,
num_seqs: int,
max_num_blocks_per_seq: int,
block_size: int,
) -> torch.Tensor:
"""Allocate block tables based on provided context lengths; and copy KV to blocked KV Cache."""
block_id = 0
block_tables = torch.full(size=(num_seqs, max_num_blocks_per_seq), fill_value=-1, dtype=torch.int32)
num_tokens_processed = 0
for i, seq_len in enumerate(context_lengths.tolist()):
right_bound = (seq_len + block_size - 1) // block_size # open bound
block_tables[i, :right_bound] = torch.arange(block_id, block_id + right_bound, dtype=torch.int32)
# Manually fill kv caches by copying from k and v
for i in range(right_bound):
if i == right_bound - 1:
allocated_locs = seq_len % block_size or block_size
else:
allocated_locs = block_size
k_block = k[num_tokens_processed : num_tokens_processed + allocated_locs, :, :].permute(1, 0, 2)
v_block = v[num_tokens_processed : num_tokens_processed + allocated_locs, :, :].permute(1, 0, 2)
k_cache[block_id, :, :allocated_locs, :] = k_block
v_cache[block_id, :, :allocated_locs, :] = v_block
num_tokens_processed += allocated_locs
block_id += 1
return block_tables
def mock_alloc_block_table_and_kvcache_v3(
k: torch.Tensor,
v: torch.Tensor,
k_cache: torch.Tensor,
v_cache: torch.Tensor,
context_lengths: torch.Tensor,
num_seqs: int,
max_num_blocks_per_seq: int,
block_size: int,
) -> torch.Tensor:
"""Allocate block tables based on provided context lengths; and copy KV to blocked KV Cache."""
block_id = 0
block_tables = torch.full(size=(num_seqs, max_num_blocks_per_seq), fill_value=-1, dtype=torch.int32)
num_tokens_processed = 0
_, num_kv_heads, head_dim = k.shape
x = 16 // torch.tensor([], dtype=k.dtype).element_size()
for i, seq_len in enumerate(context_lengths.tolist()):
right_bound = (seq_len + block_size - 1) // block_size # open bound
block_tables[i, :right_bound] = torch.arange(block_id, block_id + right_bound, dtype=torch.int32)
# Manually fill kv caches by copying from k and v
for i in range(right_bound):
if i == right_bound - 1:
allocated_locs = seq_len % block_size or block_size
else:
allocated_locs = block_size
# [block_size, num_kv_heads, head_dim/x, x]->[num_kv_heads, head_dim/x, block_size,x]
k_block = (
k[num_tokens_processed : num_tokens_processed + allocated_locs, :, :]
.reshape(allocated_locs, num_kv_heads, head_dim // x, x)
.permute(1, 2, 0, 3)
)
v_block = v[num_tokens_processed : num_tokens_processed + allocated_locs, :, :].permute(1, 0, 2)
k_cache[block_id, :, :, :allocated_locs, :] = k_block
v_cache[block_id, :, :allocated_locs, :] = v_block
num_tokens_processed += allocated_locs
block_id += 1
return block_tables
def mock_alloc_block_table_and_kvcache_vllm(
k: torch.Tensor,
v: torch.Tensor,
k_cache: torch.Tensor,
v_cache: torch.Tensor,
context_lengths: torch.Tensor,
num_seqs: int,
max_num_blocks_per_seq: int,
block_size: int,
) -> torch.Tensor:
"""Allocate block tables based on provided context lengths; and copy KV to blocked KV Cache."""
block_id = 0
block_tables = torch.full(size=(num_seqs, max_num_blocks_per_seq), fill_value=-1, dtype=torch.int32)
num_tokens_processed = 0
_, num_kv_heads, head_dim = k.shape
x = 16 // torch.tensor([], dtype=k.dtype).element_size()
for i, seq_len in enumerate(context_lengths.tolist()):
right_bound = (seq_len + block_size - 1) // block_size # open bound
block_tables[i, :right_bound] = torch.arange(block_id, block_id + right_bound, dtype=torch.int32)
# Manually fill kv caches by copying from k and v
for i in range(right_bound):
if i == right_bound - 1:
allocated_locs = seq_len % block_size or block_size
else:
allocated_locs = block_size
# [block_size, num_kv_heads, head_dim/x, x]->[num_kv_heads, head_dim/x, block_size,x]
k_block = (
k[num_tokens_processed : num_tokens_processed + allocated_locs, :, :]
.reshape(allocated_locs, num_kv_heads, head_dim // x, x)
.permute(1, 2, 0, 3)
)
# [block_size, num_kv_heads, head_dim]->[num_kv_heads, head_dim, block_size]
v_block = v[num_tokens_processed : num_tokens_processed + allocated_locs, :, :].permute(1, 2, 0)
k_cache[block_id, :, :, :allocated_locs, :] = k_block
v_cache[block_id, :, :, :allocated_locs] = v_block
num_tokens_processed += allocated_locs
block_id += 1
return block_tables
def mock_alloc_single_token(block_tables: torch.Tensor, context_lengths: torch.Tensor, block_size: int) -> None:
    # Allocate one more token on the block table for each sequence.
    # It does not modify the provided context_lengths.
    # Treat max_block_id as the last physical block allocated.
    # NOTE: this assumes all blocks preceding it have already been allocated.
max_block_id = torch.max(block_tables).item()
    # for each sequence, the index into its block table of the cache block that will receive one more token
alloc_local_block_indices = context_lengths // block_size
# offsets of the token to be allocated on the target block (for each seq)
alloc_block_offsets = context_lengths % block_size
require_new_block = alloc_block_offsets == 0
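    # Sequences whose current length is an exact multiple of block_size need a fresh block;
    # e.g. with block_size 16 and context_length 32, the offset is 0 and a new block is appended
    # at local index 2. New block ids are assigned sequentially after the current max_block_id.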
new_block_ids = torch.arange(
max_block_id + 1,
max_block_id + 1 + require_new_block.sum(),
dtype=block_tables.dtype,
device=block_tables.device,
)
if new_block_ids.numel():
new_block_alloc_local_indices = alloc_local_block_indices[require_new_block]
block_tables[require_new_block, new_block_alloc_local_indices] = new_block_ids
def generate_caches_and_block_tables(
k_unpad, v_unpad, kv_lengths, bsz, max_num_blocks_per_seq, block_size, dtype=torch.float16, device="cuda"
) -> Tuple[torch.Tensor, ...]:
    # Mock generation of k/v blocked caches and block tables from provided kv unpad and seq lengths
# k_unpad/v_unpad [num_total_tokens, num_kv_heads, head_dim]
_, num_kv_heads, head_dim = k_unpad.shape
cache_shape = (bsz * max_num_blocks_per_seq, num_kv_heads, head_dim, block_size)
k_cache = torch.zeros(size=cache_shape, dtype=dtype, device=device)
v_cache = torch.zeros(size=cache_shape, dtype=dtype, device=device)
# Mock allocation on block tables as well as blocked kv caches
block_tables = mock_alloc_block_table_and_kvcache(
k_unpad, v_unpad, k_cache, v_cache, kv_lengths, bsz, max_num_blocks_per_seq, block_size
)
return k_cache, v_cache, block_tables
def generate_caches_and_block_tables_v2(
k_unpad, v_unpad, kv_lengths, bsz, max_num_blocks_per_seq, block_size, dtype=torch.float16, device="cuda"
) -> Tuple[torch.Tensor, ...]:
    # Mock generation of k/v blocked caches and block tables from provided kv unpad and seq lengths
# k_unpad/v_unpad [num_total_tokens, num_kv_heads, head_dim]
_, num_kv_heads, head_dim = k_unpad.shape
cache_shape = (bsz * max_num_blocks_per_seq, num_kv_heads, block_size, head_dim)
k_cache = torch.zeros(size=cache_shape, dtype=dtype, device=device)
v_cache = torch.zeros(size=cache_shape, dtype=dtype, device=device)
# Mock allocation on block tables as well as blocked kv caches
block_tables = mock_alloc_block_table_and_kvcache_v2(
k_unpad, v_unpad, k_cache, v_cache, kv_lengths, bsz, max_num_blocks_per_seq, block_size
)
return k_cache, v_cache, block_tables
def generate_caches_and_block_tables_v3(
k_unpad, v_unpad, kv_lengths, bsz, max_num_blocks_per_seq, block_size, dtype=torch.float16, device="cuda"
) -> Tuple[torch.Tensor, ...]:
    # Mock generation of k/v blocked caches and block tables from provided kv unpad and seq lengths
# k_unpad/v_unpad [num_total_tokens, num_kv_heads, head_dim]
_, num_kv_heads, head_dim = k_unpad.shape
x = 16 // torch.tensor([], dtype=dtype).element_size()
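    # Packed K-cache layout: the head dim is split into chunks of x elements so that each innermost slice is 16 bytes wide.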
k_cache_shape = (bsz * max_num_blocks_per_seq, num_kv_heads, head_dim // x, block_size, x)
v_cache_shape = (bsz * max_num_blocks_per_seq, num_kv_heads, block_size, head_dim)
k_cache = torch.zeros(size=k_cache_shape, dtype=dtype, device=device)
v_cache = torch.zeros(size=v_cache_shape, dtype=dtype, device=device)
# Mock allocation on block tables as well as blocked kv caches
block_tables = mock_alloc_block_table_and_kvcache_v3(
k_unpad, v_unpad, k_cache, v_cache, kv_lengths, bsz, max_num_blocks_per_seq, block_size
)
return k_cache, v_cache, block_tables
def generate_caches_and_block_tables_vllm(
k_unpad, v_unpad, kv_lengths, bsz, max_num_blocks_per_seq, block_size, dtype=torch.float16, device="cuda"
) -> Tuple[torch.Tensor, ...]:
    # Mock generation of k/v blocked caches and block tables from provided kv unpad and seq lengths
# k_unpad/v_unpad [num_total_tokens, num_kv_heads, head_dim]
_, num_kv_heads, head_dim = k_unpad.shape
x = 16 // torch.tensor([], dtype=dtype).element_size()
k_cache_shape = (bsz * max_num_blocks_per_seq, num_kv_heads, head_dim // x, block_size, x)
v_cache_shape = (bsz * max_num_blocks_per_seq, num_kv_heads, head_dim, block_size)
k_cache = torch.zeros(size=k_cache_shape, dtype=dtype, device=device)
v_cache = torch.zeros(size=v_cache_shape, dtype=dtype, device=device)
# Mock allocation on block tables as well as blocked kv caches
block_tables = mock_alloc_block_table_and_kvcache_vllm(
k_unpad, v_unpad, k_cache, v_cache, kv_lengths, bsz, max_num_blocks_per_seq, block_size
)
return k_cache, v_cache, block_tables
def convert_kv_unpad_to_padded(
k_unpad: torch.Tensor, kv_seq_lengths: torch.Tensor, bsz: int, max_seq_len: int
) -> torch.Tensor:
# Rebuild (batched) k/v with padding to be used by torch attention
# input k_unpad/v_unpad [num_total_tokens, num_kv_heads, head_dim]
# returns k/v padded [bsz, num_kv_heads, max_seq_len, head_dim]
_, num_kv_heads, head_dim = k_unpad.shape
k_torch = torch.zeros((bsz, max_seq_len, num_kv_heads, head_dim), dtype=k_unpad.dtype, device=k_unpad.device)
prev_len_sum = 0
for i, seq_len in enumerate(kv_seq_lengths.tolist()):
# left-side padding
k_torch[i, -seq_len:, :, :] = k_unpad[prev_len_sum : prev_len_sum + seq_len]
prev_len_sum += seq_len
k_torch = k_torch.transpose(1, 2)
return k_torch

View File

@@ -0,0 +1,179 @@
import pytest
import torch
from packaging import version
from colossalai.inference.modeling.models.nopadding_baichuan import get_alibi_slopes
from colossalai.kernel.triton import context_attention_unpadded
from colossalai.utils import get_current_device
from tests.test_infer.test_kernels.triton.kernel_utils import (
generate_caches_and_block_tables_v2,
generate_caches_and_block_tables_v3,
torch_attn_ref,
)
try:
import triton # noqa
HAS_TRITON = True
except ImportError:
HAS_TRITON = False
print("please install triton from https://github.com/openai/triton")
TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")
HEAD_DIM = 32
def _fill_with_neg_inf(t):
return t.float().fill_(float("-inf")).type_as(t)
# alibi mask calculation adapted from https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat/blob/main/modeling_baichuan.py
def generate_alibi_mask(slopes, num_heads, max_seq_len, device):
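    # Build the ALiBi bias: each head adds a linear penalty proportional to the distance between
    # key and query positions, scaled by that head's slope, on top of a causal (-inf upper-triangular) mask.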
token_position = torch.arange(max_seq_len, device=device) - max_seq_len + 1
token_position = token_position.unsqueeze(0).unsqueeze(0).expand(num_heads, -1, -1)
diag = torch.diag(token_position[0])
token_position = token_position - diag.unsqueeze(0).unsqueeze(0).transpose(-1, -2)
alibi = slopes.unsqueeze(1).unsqueeze(1) * token_position
alibi = alibi.view(num_heads, 1, max_seq_len)
alibi_mask = torch.triu(_fill_with_neg_inf(torch.zeros([max_seq_len, max_seq_len], device=device)), 1)
alibi_mask = alibi_mask.unsqueeze(0) + alibi
return alibi_mask
def torch_attn_unpad(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
context_lengths: torch.Tensor,
num_heads: int,
num_kv_heads: int,
slopes: torch.Tensor = None,
):
    # Process sequences one by one and concatenate the outputs.
# q,k,v [num_tokens(sum(context_lengths)), num_heads, head_dim]
assert context_lengths.dim() == 1, "context_lengths should be a 1D tensor"
_, num_heads, head_dim = q.shape
out_torch = []
start_idx = 0
for seq_i in range(len(context_lengths)):
end_idx = start_idx + context_lengths[seq_i].item()
seq_len = end_idx - start_idx
mask = torch.tril(torch.ones(1, 1, seq_len, seq_len), diagonal=0).to(device=q.device)
mask[mask == 0.0] = float("-inf")
if slopes is not None:
alibi_mask = generate_alibi_mask(slopes, num_heads, seq_len, q.device)
mask = mask + alibi_mask
torch_attn_ref_out = torch_attn_ref(
q[start_idx:end_idx].unsqueeze(0).transpose(1, 2),
k[start_idx:end_idx].unsqueeze(0).transpose(1, 2),
v[start_idx:end_idx].unsqueeze(0).transpose(1, 2),
mask,
1, # set bsz as 1 as we're processing sequence one by one
seq_len,
seq_len,
num_heads,
num_kv_heads,
head_dim,
)
out_torch.append(torch_attn_ref_out.squeeze(0))
start_idx = end_idx
return torch.cat(out_torch, dim=0)
@pytest.mark.skipif(not (HAS_TRITON and TRITON_CUDA_SUPPORT), reason="requires triton")
@pytest.mark.parametrize("bsz", [7, 32])
@pytest.mark.parametrize("block_size", [16, 32])
@pytest.mark.parametrize("max_num_blocks_per_seq", [8, 16])
@pytest.mark.parametrize("num_attn_heads", [16])
@pytest.mark.parametrize("kv_group_num", [1, 4])
@pytest.mark.parametrize("same_context_len", [True, False])
@pytest.mark.parametrize("use_alibi_slopes", [True, False])
@pytest.mark.parametrize("use_new_kcache_layout", [True, False])
def test_context_attention(
bsz: int,
block_size: int,
max_num_blocks_per_seq: int,
num_attn_heads: int,
kv_group_num: int,
same_context_len: bool,
use_alibi_slopes: bool,
use_new_kcache_layout: bool,
):
if use_new_kcache_layout and use_alibi_slopes:
# TODO(yuanheng-zhao): Since the alibi kernel is pretty similar to the original one,
# the code (alibi kernel) will be refactored later to avoid code duplication, when
# the whole triton flow with new k cache layout has been supported and tested.
# And tests for the alibi kernel using new kcache layout will be added then.
return
torch.manual_seed(123)
# It's necessary to clear cache here.
torch.cuda.empty_cache()
torch.cuda.synchronize()
torch.cuda.reset_peak_memory_stats()
num_kv_heads = num_attn_heads // kv_group_num
assert isinstance(num_kv_heads, int) and num_kv_heads > 0, "Invalid number of kv heads."
max_seq_len = max_num_blocks_per_seq * block_size
dtype = torch.float16
device = get_current_device()
alibi_slopes = None
if use_alibi_slopes:
alibi_slopes = get_alibi_slopes(num_attn_heads, device)
if same_context_len:
context_lengths = torch.tensor([max_seq_len for _ in range(bsz)], dtype=torch.int32, device=device)
else:
context_lengths = torch.randint(low=1, high=max_seq_len, size=(bsz,), dtype=torch.int32, device=device)
num_tokens = torch.sum(context_lengths).item()
qkv_size = (num_tokens, num_attn_heads + 2 * num_kv_heads, HEAD_DIM)
qkv_unpad = torch.empty(size=qkv_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5)
q_unpad, k_unpad, v_unpad = torch.split(qkv_unpad, [num_attn_heads, num_kv_heads, num_kv_heads], dim=-2)
q_unpad = q_unpad.contiguous()
if use_new_kcache_layout:
k_cache_ref, v_cache_ref, block_tables = generate_caches_and_block_tables_v3(
k_unpad, v_unpad, context_lengths, bsz, max_num_blocks_per_seq, block_size, dtype, device
)
else:
k_cache_ref, v_cache_ref, block_tables = generate_caches_and_block_tables_v2(
k_unpad, v_unpad, context_lengths, bsz, max_num_blocks_per_seq, block_size, dtype, device
)
block_tables = block_tables.to(device=device)
k_cache_triton = torch.zeros_like(k_cache_ref)
v_cache_triton = torch.zeros_like(v_cache_ref)
_, num_heads, head_dim = q_unpad.shape
out_triton = context_attention_unpadded(
q_unpad,
k_unpad,
v_unpad,
k_cache_triton,
v_cache_triton,
context_lengths,
block_tables,
block_size,
alibi_slopes=alibi_slopes,
use_new_kcache_layout=use_new_kcache_layout,
)
out_triton = out_triton.view(-1, num_heads, head_dim)
out_torch = torch_attn_unpad(q_unpad, k_unpad, v_unpad, context_lengths, num_attn_heads, num_kv_heads, alibi_slopes)
assert out_torch.shape == out_triton.shape
assert torch.allclose(out_torch, out_triton, atol=1e-3)
assert torch.equal(k_cache_ref, k_cache_triton)
assert torch.equal(v_cache_ref, v_cache_triton)
if __name__ == "__main__":
test_context_attention(4, 32, 8, 16, 1, True, True, True)

View File

@@ -0,0 +1,197 @@
import numpy as np
import pytest
import torch
from packaging import version
from colossalai.inference.modeling.models.nopadding_baichuan import get_alibi_slopes
from colossalai.kernel.triton import flash_decoding_attention
from colossalai.utils import get_current_device
from tests.test_infer.test_kernels.triton.kernel_utils import (
convert_kv_unpad_to_padded,
create_attention_mask,
generate_caches_and_block_tables_v2,
generate_caches_and_block_tables_v3,
torch_attn_ref,
)
from tests.test_infer.test_kernels.triton.test_context_attn_unpad import generate_alibi_mask
try:
import triton # noqa
HAS_TRITON = True
except ImportError:
HAS_TRITON = False
print("please install triton from https://github.com/openai/triton")
TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")
HEAD_DIM = 128
def numpy_allclose(x, y, rtol, atol):
x_numpy = x.detach().cpu().numpy()
y_numpy = y.detach().cpu().numpy()
np.testing.assert_allclose(x_numpy, y_numpy, rtol=rtol, atol=atol)
def prepare_data(
bsz: int,
num_attn_heads: int,
num_kv_heads: int,
head_dim: int,
same_context_len: bool,
q_len: int,
max_kv_seq_len: int,
dtype=torch.float16,
device="cuda",
):
    # Use the provided maximum sequence length for each sequence when testing with the same context length,
# otherwise generate random context lengths.
# returns
# q [bsz, num_attn_heads, q_len, head_dim]
# k_unpad/v_unpad [num_tokens, num_kv_heads, head_dim]
kv_lengths = (
torch.tensor([max_kv_seq_len for _ in range(bsz)], dtype=torch.int32, device=device)
if same_context_len
else torch.randint(low=1, high=max_kv_seq_len, size=(bsz,), dtype=torch.int32, device=device)
)
num_tokens = torch.sum(kv_lengths).item()
q_size = (bsz, q_len, num_attn_heads, head_dim)
q = torch.empty(size=q_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5).transpose(1, 2)
kv_size = (num_tokens, 2 * num_kv_heads, head_dim)
kv_unpad = torch.empty(size=kv_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5)
k_unpad, v_unpad = torch.split(kv_unpad, [num_kv_heads, num_kv_heads], dim=-2)
return q, k_unpad, v_unpad, kv_lengths
@pytest.mark.skipif(not (HAS_TRITON and TRITON_CUDA_SUPPORT), reason="requires triton")
@pytest.mark.parametrize("bsz", [7, 16])
@pytest.mark.parametrize("block_size", [16, 32])
@pytest.mark.parametrize("max_num_blocks_per_seq", [8, 16])
@pytest.mark.parametrize("num_attn_heads", [16])
@pytest.mark.parametrize("kv_group_num", [1, 4])
@pytest.mark.parametrize("same_context_len", [True, False])
@pytest.mark.parametrize("q_len", [1, 5])
@pytest.mark.parametrize("use_alibi_slopes", [True, False])
@pytest.mark.parametrize("use_new_kcache_layout", [True, False])
def test_flash_decoding(
bsz: int,
block_size: int,
max_num_blocks_per_seq: int,
num_attn_heads: int,
kv_group_num: int,
same_context_len: bool,
q_len: int,
use_alibi_slopes: bool,
use_new_kcache_layout: bool,
):
if use_new_kcache_layout and use_alibi_slopes:
# TODO(yuanheng-zhao): Since the alibi kernel is pretty similar to the original one,
# the code (alibi kernel) will be refactored later to avoid code duplication, when
# the whole triton flow with new k cache layout has been supported and tested.
# And tests for the alibi kernel using new kcache layout will be added then.
pytest.skip("Alibi kernel does not support new kcache layout yet.")
torch.manual_seed(123)
torch.cuda.empty_cache()
torch.cuda.synchronize()
torch.cuda.reset_peak_memory_stats()
num_kv_heads = num_attn_heads // kv_group_num
assert isinstance(num_kv_heads, int) and num_kv_heads > 0, "Invalid number of kv heads."
max_seq_len = block_size * max_num_blocks_per_seq
dtype = torch.float16
device = get_current_device()
if use_alibi_slopes:
alibi_slopes = get_alibi_slopes(num_attn_heads, device)
# Currently, alibi flash decoding does not support q_len>1.
q_len = 1
else:
alibi_slopes = None
q, k_unpad, v_unpad, kv_lengths = prepare_data(
bsz, num_attn_heads, num_kv_heads, HEAD_DIM, same_context_len, q_len, max_seq_len, dtype, device
)
# The maximum sequence length in the batch (if context lengths randomly generated)
max_kv_len_in_b = kv_lengths.max().item()
k_torch = convert_kv_unpad_to_padded(k_unpad, kv_lengths, bsz, max_kv_len_in_b)
v_torch = convert_kv_unpad_to_padded(v_unpad, kv_lengths, bsz, max_kv_len_in_b)
attention_mask = create_attention_mask(kv_lengths, bsz, q_len, max_kv_len_in_b, q.device)
if use_alibi_slopes:
alibi_mask = generate_alibi_mask(alibi_slopes, num_attn_heads, max_kv_len_in_b, q.device)
attention_mask = attention_mask + alibi_mask
if q_len == 1:
if len(attention_mask.size()) == 4:
attention_mask = attention_mask[:, :, -1:, :]
else:
attention_mask = attention_mask[:, -1:, :]
out_torch = torch_attn_ref(
q, k_torch, v_torch, attention_mask, bsz, q_len, max_kv_len_in_b, num_attn_heads, num_kv_heads, HEAD_DIM
)
if use_new_kcache_layout:
k_cache, v_cache, block_tables = generate_caches_and_block_tables_v3(
k_unpad, v_unpad, kv_lengths, bsz, max_num_blocks_per_seq, block_size, dtype, device
)
else:
k_cache, v_cache, block_tables = generate_caches_and_block_tables_v2(
k_unpad, v_unpad, kv_lengths, bsz, max_num_blocks_per_seq, block_size, dtype, device
)
block_tables = block_tables.to(device=device)
    # Each KV split is at most one cache block long, so the max number of splits is the block count of the longest sequence
kv_max_split_num = (max_kv_len_in_b + block_size - 1) // block_size
output = torch.empty((bsz * q_len, num_attn_heads, HEAD_DIM), dtype=q.dtype, device=q.device)
mid_output = torch.empty(
size=(bsz * q_len, num_attn_heads, kv_max_split_num, HEAD_DIM), dtype=torch.float32, device=q.device
)
mid_output_lse = torch.empty(
size=(bsz * q_len, num_attn_heads, kv_max_split_num), dtype=torch.float32, device=q.device
)
sm_scale = 1.0 / (HEAD_DIM**0.5)
# Here we use different methods to hide the q_len dimension,
# refer to attention forward function in modeling.
if q_len > 1:
q = q.transpose(1, 2).contiguous() # [bsz, q_len, num_heads, head_dim]
q = q.view(-1, q.size(-2), q.size(-1)) # [bsz * q_len, num_heads, head_dim]
else:
q = q.squeeze(2)
assert q.shape == (bsz * q_len, num_attn_heads, HEAD_DIM)
out_triton = flash_decoding_attention(
q,
k_cache,
v_cache,
kv_lengths,
block_tables,
block_size,
max_kv_len_in_b,
output,
mid_output,
mid_output_lse,
alibi_slopes=alibi_slopes,
sm_scale=sm_scale,
kv_group_num=kv_group_num,
q_len=q_len,
use_new_kcache_layout=use_new_kcache_layout,
) # [bsz * q_len, num_heads, head_dim]
assert out_torch.shape == out_triton.shape
rtol = 1e-4
    # For larger shapes some output elements are very small, which inflates the relative error, so rtol is relaxed.
if bsz >= 16 and use_alibi_slopes:
rtol = 100
numpy_allclose(out_torch, out_triton, atol=1e-3, rtol=rtol)
if __name__ == "__main__":
test_flash_decoding(16, 32, 32, 16, 1, True, 1, use_alibi_slopes=False, use_new_kcache_layout=True)

View File

@@ -0,0 +1,50 @@
from copy import deepcopy
import pytest
import torch
from packaging import version
from colossalai.kernel.triton.fused_rotary_embedding import fused_rotary_embedding
from colossalai.kernel.triton.no_pad_rotary_embedding import rotary_embedding
from colossalai.kernel.triton.rotary_cache_copy import get_xine_cache
try:
import triton # noqa
HAS_TRITON = True
except ImportError:
HAS_TRITON = False
print("please install triton from https://github.com/openai/triton")
TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")
@pytest.mark.skipif(not (HAS_TRITON and TRITON_CUDA_SUPPORT), reason="requires triton")
def test_fused_rotary_emb():
num_tokens = 20
num_kv_heads = 32
head_dim = 64
dtype = torch.float32
q_shape = (num_tokens, num_kv_heads, head_dim)
q = -2.3 + 0.5 * torch.randn(q_shape, dtype=dtype, device="cuda")
q_copy = deepcopy(q)
k_shape = (num_tokens, num_kv_heads, head_dim)
k = -2.3 + 0.5 * torch.randn(k_shape, dtype=dtype, device="cuda")
k_copy = deepcopy(k)
cos_shape = (1024, head_dim)
lengths = torch.tensor([3, 4, 6, 7], device="cuda")
cos_cache = -1.2 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
sin_cache = -2.0 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
cos, sin = get_xine_cache(lengths, cos_cache[:, : head_dim // 2], sin_cache[:, : head_dim // 2])
rotary_embedding(q, k, cos, sin)
fused_rotary_embedding(q_copy, k_copy, cos_cache, sin_cache, lengths)
    assert torch.allclose(q, q_copy)
    assert torch.allclose(k, k_copy)
if __name__ == "__main__":
test_fused_rotary_emb()

View File

@@ -0,0 +1,168 @@
import pytest
import torch
from packaging import version
from colossalai.kernel.triton import copy_k_to_blocked_cache, copy_kv_to_blocked_cache
from colossalai.utils import get_current_device
from tests.test_infer.test_kernels.triton.kernel_utils import (
generate_caches_and_block_tables_v2,
generate_caches_and_block_tables_v3,
mock_alloc_single_token,
)
try:
import triton # noqa
HAS_TRITON = True
except ImportError:
HAS_TRITON = False
print("please install triton from https://github.com/openai/triton")
TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")
HEAD_DIM = 32
def prepare_data(
bsz,
num_kv_heads,
head_dim,
block_size,
max_num_blocks_per_seq,
same_context_len,
max_seq_len,
n=1,
device="cuda",
dtype=torch.float16,
use_new_kcache_layout=False,
):
assert max_seq_len > n, "max_seq_len must be greater than n"
past_kv_seq_lengths = (
torch.tensor([max_seq_len - n for _ in range(bsz)], dtype=torch.int32, device=device)
if same_context_len
else torch.randint(low=1, high=max_seq_len - n, size=(bsz,), dtype=torch.int32, device=device)
)
num_tokens = torch.sum(past_kv_seq_lengths).item()
kv_size = (num_tokens, 2 * num_kv_heads, head_dim)
kv_unpad = torch.empty(size=kv_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5)
k_unpad, v_unpad = torch.split(kv_unpad, [num_kv_heads, num_kv_heads], dim=-2)
if use_new_kcache_layout:
k_cache, v_cache, block_tables = generate_caches_and_block_tables_v3(
k_unpad, v_unpad, past_kv_seq_lengths, bsz, max_num_blocks_per_seq, block_size, dtype=dtype, device=device
)
else:
k_cache, v_cache, block_tables = generate_caches_and_block_tables_v2(
k_unpad, v_unpad, past_kv_seq_lengths, bsz, max_num_blocks_per_seq, block_size, dtype=dtype, device=device
)
block_tables = block_tables.to(device=device)
new_k = torch.randn((bsz, n, num_kv_heads, head_dim), dtype=dtype, device=device)
new_v = torch.randn((bsz, n, num_kv_heads, head_dim), dtype=dtype, device=device)
# mock allocating blocks for the new k/v and update block tables
for _ in range(n):
mock_alloc_single_token(block_tables, past_kv_seq_lengths, block_size)
past_kv_seq_lengths += 1
return new_k, new_v, k_cache, v_cache, past_kv_seq_lengths, block_tables
@pytest.mark.skipif(not (HAS_TRITON and TRITON_CUDA_SUPPORT), reason="requires triton")
@pytest.mark.parametrize("bsz", [7, 32])
@pytest.mark.parametrize("block_size", [16, 32, 64])
@pytest.mark.parametrize("max_num_blocks_per_seq", [16])
@pytest.mark.parametrize("num_kv_heads", [16])
@pytest.mark.parametrize("same_context_len", [True, False])
@pytest.mark.parametrize("n_tokens", [1, 5])
@pytest.mark.parametrize("use_new_kcache_layout", [True, False])
def test_copy_kv_to_caches(
bsz: int,
block_size: int,
max_num_blocks_per_seq: int,
num_kv_heads: int,
same_context_len: bool,
n_tokens: int,
use_new_kcache_layout: bool,
):
torch.manual_seed(123)
torch.cuda.empty_cache()
torch.cuda.synchronize()
torch.cuda.reset_peak_memory_stats()
max_seq_len = block_size * max_num_blocks_per_seq
dtype = torch.float16
device = get_current_device()
new_k, new_v, k_cache, v_cache, kv_seq_lengths, block_tables = prepare_data(
bsz,
num_kv_heads,
HEAD_DIM,
block_size,
max_num_blocks_per_seq,
same_context_len,
max_seq_len,
n_tokens,
device=device,
dtype=dtype,
use_new_kcache_layout=use_new_kcache_layout,
)
k_source = new_k.view(-1, new_k.size(-2), new_k.size(-1))
v_source = new_v.view(-1, new_v.size(-2), new_v.size(-1))
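    # Keep a pristine copy of the K cache so the fused copy_kv_to_blocked_cache path below starts from the same state.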
k_cache_copy = k_cache.detach().clone()
past_kv_seq_lengths = kv_seq_lengths - n_tokens
target_block_ids = block_tables[range(0, block_tables.size(0)), past_kv_seq_lengths // block_size]
offsets_in_block = past_kv_seq_lengths % block_size
# Copy k (or v) to k (or v) cache
copy_k_to_blocked_cache(
new_k, k_cache, kv_seq_lengths, block_tables, n=n_tokens, use_new_kcache_layout=use_new_kcache_layout
)
# Reshape target k from k cache to compare if matching with original tensor
# Mainly to handle cases of n_tokens > 1
k_target = []
for i in range(bsz):
block_table = block_tables[i]
curr_kv_len = past_kv_seq_lengths[i].item()
offset = offsets_in_block[i].item()
tokens_left = n_tokens
while tokens_left > 0:
tokens_to_fill = min(block_size - offset, tokens_left)
curr_block_id = block_table[curr_kv_len // block_size]
if use_new_kcache_layout:
k_target.append(k_cache[curr_block_id, :, :, offset : offset + tokens_to_fill, :])
else:
k_target.append(k_cache[curr_block_id, :, offset : offset + tokens_to_fill, :])
curr_kv_len += tokens_to_fill
tokens_left -= tokens_to_fill
offset = 0
if use_new_kcache_layout:
k_target = torch.concat(k_target, dim=2).permute(2, 0, 1, 3).contiguous()
k_target = k_target.reshape(bsz * n_tokens, num_kv_heads, HEAD_DIM)
else:
k_target = torch.concat(k_target, dim=1).transpose(0, 1).contiguous() # [bsz * n, num_kv_heads, head_dim]
assert k_target.shape == k_source.shape
assert torch.equal(k_target, k_source)
if n_tokens == 1:
# Copy k and v to k/v caches
k_cache = k_cache_copy
copy_kv_to_blocked_cache(
new_k, new_v, k_cache, v_cache, kv_seq_lengths, block_tables, use_new_kcache_layout=use_new_kcache_layout
)
if use_new_kcache_layout:
k_target = k_cache[target_block_ids, :, :, offsets_in_block, :]
k_target = k_target.contiguous().reshape(bsz * n_tokens, num_kv_heads, HEAD_DIM)
else:
k_target = k_cache[target_block_ids, :, offsets_in_block, :]
assert k_target.shape == k_source.shape
assert torch.equal(k_target, k_source)
v_target = v_cache[target_block_ids, :, offsets_in_block, :]
assert v_target.shape == v_source.shape
assert torch.equal(v_target, v_source)
if __name__ == "__main__":
test_copy_kv_to_caches(4, 32, 8, 16, True, n_tokens=1)

View File

@@ -0,0 +1,55 @@
import pytest
import torch
from packaging import version
from transformers.models.llama.modeling_llama import LlamaRMSNorm
from colossalai.kernel.triton import rms_layernorm
from colossalai.testing.utils import parameterize
try:
import triton # noqa
HAS_TRITON = True
except ImportError:
HAS_TRITON = False
print("please install triton from https://github.com/openai/triton")
TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")
@pytest.mark.skipif(
not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4"
)
@parameterize("M", [2, 4, 8, 16])
@parameterize("N", [64, 128])
def test_layer_norm(M, N):
dtype = torch.float16
eps = 1e-5
x_shape = (M, N)
w_shape = (x_shape[-1],)
weight = torch.ones(w_shape, dtype=dtype, device="cuda")
residual = torch.rand(x_shape, dtype=dtype, device="cuda")
residual_copy = residual.clone()
rms_norm = LlamaRMSNorm(hidden_size=N, eps=eps).cuda()
x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device="cuda")
x_copy = x.clone()
y_triton, _ = rms_layernorm(x, weight, eps=eps)
y_llama = rms_norm.forward(x).to(dtype)
assert y_triton.shape == y_llama.shape
assert torch.allclose(y_triton, y_llama, atol=1e-5, rtol=1e-3)
y_triton, residual = rms_layernorm(x, weight, eps=eps, residual=residual)
x = x_copy + residual_copy
y_llama = rms_norm.forward(x).to(dtype)
assert y_triton.shape == y_llama.shape
assert torch.allclose(y_triton, y_llama, atol=1e-5, rtol=1e-3)
assert torch.allclose(x, residual, atol=1e-5, rtol=1e-3)
if __name__ == "__main__":
test_layer_norm()

View File

@@ -0,0 +1,100 @@
import pytest
import torch
from packaging import version
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb
from colossalai.kernel.triton import decoding_fused_rotary_embedding
from tests.test_infer.test_kernels.triton.kernel_utils import (
mock_alloc_block_table_and_kvcache_v2,
mock_alloc_block_table_and_kvcache_v3,
)
try:
import triton # noqa
HAS_TRITON = True
except ImportError:
HAS_TRITON = False
print("please install triton from https://github.com/openai/triton")
TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")
def torch_rotary_emb(x, cos, sin):
seq_len, h, dim = x.shape
x0 = x[:, :, 0 : dim // 2]
x1 = x[:, :, dim // 2 : dim]
cos = cos.view((seq_len, 1, dim // 2))
sin = sin.view((seq_len, 1, dim // 2))
o0 = x0 * cos - x1 * sin
o1 = x0 * sin + x1 * cos
return torch.cat((o0, o1), dim=-1)
@pytest.mark.skipif(
not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4"
)
@pytest.mark.parametrize("BATCH_SIZE", [4])
@pytest.mark.parametrize("SEQ_LEN", [64])
@pytest.mark.parametrize("H", [32])
@pytest.mark.parametrize("D", [64])
@pytest.mark.parametrize("dtype", [torch.float32])
@pytest.mark.parametrize("use_new_kcache_layout", [True, False])
def test_rotary_emb(BATCH_SIZE, SEQ_LEN, H, D, dtype, use_new_kcache_layout):
TOTAL_TOKENS = BATCH_SIZE * SEQ_LEN
    # sanity check: the crafted torch_rotary_emb reference should match the Transformers implementation
x0 = torch.randn(TOTAL_TOKENS, SEQ_LEN, D)
x1 = torch.randn(TOTAL_TOKENS, SEQ_LEN, D)
emb = LlamaRotaryEmbedding(D)
cos, sin = emb(x0, TOTAL_TOKENS)
cos_2 = cos[:, :32]
sin_2 = sin[:, :32]
position_ids = torch.arange(TOTAL_TOKENS)
embd_x0, _ = apply_rotary_pos_emb(x0, x1, cos, sin, position_ids)
embd_stimulated_x = torch_rotary_emb(x0, cos_2, sin_2)
assert torch.allclose(embd_x0, embd_stimulated_x)
# create data
block_size = 32
max_num_blocks_per_seq = 4
q_shape = (TOTAL_TOKENS, H, D)
q = -2.3 + 0.5 * torch.randn(q_shape, dtype=dtype, device="cuda")
k_shape = (TOTAL_TOKENS, H, D)
k = -2.3 + 0.5 * torch.randn(k_shape, dtype=dtype, device="cuda")
v = torch.randn_like(k)
new_k = torch.randn((BATCH_SIZE, H, D), dtype=dtype, device="cuda")
new_q = torch.randn_like(new_k)
new_v = torch.randn_like(new_k)
cos_shape = (TOTAL_TOKENS, D // 2)
cos = -1.2 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
sin = -2.0 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
past_kv_seq_lengths = torch.tensor([SEQ_LEN - 1 for _ in range(BATCH_SIZE)], dtype=torch.int32, device="cuda")
v_cache_shape = (BATCH_SIZE * max_num_blocks_per_seq, H, block_size, D)
v_cache = torch.zeros(size=v_cache_shape, dtype=dtype, device="cuda")
if use_new_kcache_layout:
x = 16 // torch.tensor([], dtype=dtype).element_size()
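        # Packed K-cache layout: [num_blocks, H, D // x, block_size, x], with each innermost slice 16 bytes wide.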
kcache_shape = (BATCH_SIZE * max_num_blocks_per_seq, H, D // x, block_size, x)
k_cache = torch.zeros(size=kcache_shape, dtype=dtype, device="cuda")
block_tables = mock_alloc_block_table_and_kvcache_v3(
k, v, k_cache, v_cache, past_kv_seq_lengths, BATCH_SIZE, max_num_blocks_per_seq, block_size
)
else:
k_cache = torch.zeros_like(v_cache)
block_tables = mock_alloc_block_table_and_kvcache_v2(
k, v, k_cache, v_cache, past_kv_seq_lengths, BATCH_SIZE, max_num_blocks_per_seq, block_size
)
kv_seq_lengths = past_kv_seq_lengths + 1
block_tables = block_tables.to(device="cuda")
q_ref = torch_rotary_emb(new_q, cos[:BATCH_SIZE], sin[:BATCH_SIZE])
decoding_fused_rotary_embedding(
new_q, new_k, new_v, cos, sin, k_cache, v_cache, block_tables, kv_seq_lengths, use_new_kcache_layout
)
assert torch.allclose(new_q, q_ref, atol=1e-4, rtol=1e-4)
if __name__ == "__main__":
test_rotary_emb(4, 64, 32, 64, torch.float32, use_new_kcache_layout=True)

View File

@@ -0,0 +1,66 @@
import pytest
import torch
from packaging import version
from colossalai.kernel.triton import get_xine_cache
try:
import triton # noqa
HAS_TRITON = True
except ImportError:
HAS_TRITON = False
print("please install triton from https://github.com/openai/triton")
TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")
@torch.no_grad()
def get_cos_sin(lengths, cos_cache, sin_cache, is_prompts, dtype):
"""
Get cos and sin for the cache, and return nopad format.
Args:
        lengths: shape (num_seqs,), stores the length of each sequence.
        cos_cache: shape (max_rotary_position (e.g. 2048), head_dim), cos cache constructed in the model.
        sin_cache: shape (max_rotary_position (e.g. 2048), head_dim), sin cache constructed in the model.
is_prompts: bool, mark if in prefill mode.
dtype: The data type of this inference process.
"""
if is_prompts:
index_arrays = [torch.arange(length) for length in lengths]
else:
index_arrays = [(length - 1).view(-1) for length in lengths]
indices = torch.cat(index_arrays, dim=-1)
cos_output = cos_cache[indices].to(dtype=dtype)
sin_output = sin_cache[indices].to(dtype=dtype)
return (cos_output, sin_output)
@pytest.mark.skipif(
not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4"
)
@pytest.mark.parametrize("BATCH_SIZE", [4])
@pytest.mark.parametrize("MAX_SEQ_LEN", [64])
@pytest.mark.parametrize("HEAD_DIM", [64])
@pytest.mark.parametrize("dtype", [torch.float32])
def test_get_xine_cache(BATCH_SIZE, MAX_SEQ_LEN, HEAD_DIM, dtype):
MAX_TOTAL_TOKENS = BATCH_SIZE * MAX_SEQ_LEN
cos_cache = torch.randn((MAX_TOTAL_TOKENS, HEAD_DIM), dtype=dtype, device="cuda")
sin_cache = torch.randn((MAX_TOTAL_TOKENS, HEAD_DIM), dtype=dtype, device="cuda")
lengths = torch.randint(2, MAX_SEQ_LEN, (BATCH_SIZE,), device="cuda")
# prefill
cos_ref, sin_ref = get_cos_sin(lengths, cos_cache, sin_cache, is_prompts=True, dtype=dtype)
cos, sin = get_xine_cache(lengths, cos_cache, sin_cache, is_prompts=True)
assert torch.allclose(cos, cos_ref)
assert torch.allclose(sin, sin_ref)
# decoding
ncos_ref, nsin_ref = get_cos_sin(lengths, cos_cache, sin_cache, is_prompts=False, dtype=dtype)
cos, sin = get_xine_cache(lengths, cos_cache, sin_cache, is_prompts=False)
assert torch.allclose(cos, ncos_ref)
assert torch.allclose(sin, nsin_ref)
if __name__ == "__main__":
test_get_xine_cache(4, 64, 256, torch.float32)

tests/test_infer/test_kvcache_manager.py Normal file → Executable file
View File

@@ -1,66 +1,179 @@
import os
import random
import pytest
import torch
from packaging import version
from transformers.models.llama import LlamaConfig
from colossalai.inference.kv_cache import MemoryManager
import colossalai
from colossalai.inference.config import InferenceConfig
from colossalai.inference.kv_cache import CacheBlock, KVCacheManager
from colossalai.logging import disable_existing_loggers
from colossalai.testing import rerun_if_address_is_in_use, spawn
BATCH_SIZE = 4
INPUT_LEN = 16
OUTPUT_LEN = 8
LAYER_NUM = 4
HEAD_NUM = 32
HEAD_DIM = 128
CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.5")
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
def create_cache_manager(rank, world_size, port, batch_size, input_len, output_len, layer_num, head_num, head_dim):
os.environ["RANK"] = str(rank)
os.environ["LOCAL_RANK"] = str(rank)
os.environ["WORLD_SIZE"] = str(world_size)
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = str(port)
@parameterize(
"test_config",
[
{
"elem_size": 2,
"block_size": 4,
}
],
)
def test_logical_blocks(test_config):
block = CacheBlock(block_id=0, block_size=test_config["block_size"], elem_size=test_config["elem_size"])
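    # A logical cache block tracks its reference count and how many token slots are allocated within a fixed-size block.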
assert block.is_empty()
assert block.available_space == test_config["block_size"]
assert not block.has_ref()
block.add_ref()
assert block.ref_count == 1
assert block.has_ref()
block.remove_ref()
assert block.ref_count == 0
block.allocate(1)
assert block.allocated_size == 1
block.allocate(test_config["block_size"] - 1)
assert block.available_space < 1
@parameterize(
"test_config",
[
{
"hidden_size": 512,
"num_attention_heads": 16,
"num_layers": 2,
"block_size": 8,
"max_batch_size": 10,
"max_input_len": 32,
"max_output_len": 32,
"dtype": torch.float32,
"beam_width": 1,
"tp_size": 1,
},
{
"hidden_size": 128,
"num_attention_heads": 4,
"num_layers": 3,
"block_size": 4,
"max_batch_size": 4,
"max_input_len": 64,
"max_output_len": 32,
"dtype": torch.float16,
"beam_width": 3,
"tp_size": 1,
},
],
)
def check_cache_manager(test_config):
disable_existing_loggers()
size = batch_size * (input_len + output_len)
kvcache_manager = MemoryManager(size, torch.float16, head_num // world_size, head_dim, layer_num, rank)
key_buffers = kvcache_manager.key_buffer
value_buffers = kvcache_manager.value_buffer
assert len(key_buffers) == len(value_buffers) == layer_num
assert key_buffers[0].shape == value_buffers[0].shape
# required size exceeds the maximum allocated size
invalid_locs = kvcache_manager.alloc_contiguous(size + 1)
assert invalid_locs is None
# for prefill stage, allocation via alloc and alloc_contiguous should be the same
total_token_prefill = batch_size * input_len
prefill_locs = kvcache_manager.alloc(total_token_prefill)
kvcache_manager.free_all()
prefill_locs_contiguous = kvcache_manager.alloc_contiguous(total_token_prefill)[0]
assert torch.equal(prefill_locs, prefill_locs_contiguous)
assert torch.sum(kvcache_manager.mem_state).item() == size - total_token_prefill
kvcache_manager.alloc_contiguous(batch_size)
assert torch.all(kvcache_manager.mem_state[: total_token_prefill + batch_size] == False)
assert test_config["max_batch_size"] > 1
hidden_size = test_config.pop("hidden_size")
num_layers = test_config.pop("num_layers")
num_attention_heads = test_config.pop("num_attention_heads")
head_size = hidden_size // num_attention_heads
block_size = test_config["block_size"]
max_batch_size = test_config["max_batch_size"]
max_input_length = test_config["max_input_len"]
max_output_length = test_config["max_output_len"]
inference_config = InferenceConfig(**test_config)
model_config = LlamaConfig(
hidden_size=hidden_size,
num_hidden_layers=num_layers,
num_attention_heads=num_attention_heads,
)
cache_manager = KVCacheManager(inference_config, model_config)
num_blocks = cache_manager.total_num_blocks
assert num_blocks > 0
assert len(cache_manager._cache_blocks) == num_blocks
key_caches = cache_manager._kv_caches[0] # key caches for all the blocks in all the layers
assert len(key_caches) == num_layers
expected_kv_shape = (num_blocks, num_attention_heads, block_size, head_size)
assert key_caches[0].shape == expected_kv_shape
k_cache_block0, v_cache_block0 = cache_manager.get_physical_cache(0, 0)
expected_kv_block_shape = expected_kv_shape[1:]
assert k_cache_block0.shape == expected_kv_block_shape
assert v_cache_block0.shape == expected_kv_block_shape
max_blocks_per_seq = cache_manager.get_max_blocks_per_sequence()
block_tables = torch.tensor(
[[-1 for _ in range(max_blocks_per_seq)] for _ in range(test_config["max_batch_size"])], dtype=torch.int32
)
context_lengths = [random.randint(1, max_input_length) for _ in range(max_batch_size)]
cnt_blocks_used = 0
# Mock Prefill
for req_i in range(max_batch_size):
cur_seq_len = context_lengths[req_i]
cur_block_table = block_tables[req_i]
cache_manager.allocate_context_from_block_table(cur_block_table, cur_seq_len)
last_allocated_idx = (cur_seq_len - 1) // block_size
assert torch.all(cur_block_table[: last_allocated_idx + 1] >= 0)
cnt_blocks_used += torch.sum(cur_block_table >= 0).item()
assert cache_manager.num_available_blocks == num_blocks - cnt_blocks_used
# Mock Decoding
for req_i in range(max_batch_size):
context_length = context_lengths[req_i]
cur_output_length = random.randint(1, max_output_length)
cur_block_table = block_tables[req_i]
for _ in range(cur_output_length):
cache_manager.allocate_token_from_block_table(cur_block_table, context_length)
context_length += 1
context_length -= 1
last_allocated_idx = context_length // block_size
space_allocated_on_last_block = context_length % block_size + 1
assert space_allocated_on_last_block > 0
block_id = cur_block_table[last_allocated_idx]
block: CacheBlock = cache_manager._cache_blocks[block_id]
assert block.allocated_size == space_allocated_on_last_block
# Randomly select a request and clear its cache
req_i = random.randint(0, max_batch_size - 1)
context_length = context_lengths[req_i]
blocks_used_by_req = torch.sum(block_tables[req_i] >= 0).item()
prev_available_blocks = cache_manager.num_available_blocks
cache_manager.free_block_table(block_tables[req_i])
assert cache_manager.num_available_blocks == blocks_used_by_req + prev_available_blocks
k_ptr_block0_layer0, _ = cache_manager.get_block_kv_ptrs(0, 0)
k_ptr_block1_layer0, _ = cache_manager.get_block_kv_ptrs(1, 0)
elem_size = torch.tensor([], dtype=test_config["dtype"]).element_size()
expected_stride = block_size * num_attention_heads * head_size * elem_size
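# key caches of consecutive blocks within a layer are expected to be laid out contiguously, so adjacent block pointers differ by exactly one block's key cache size in bytes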
assert k_ptr_block1_layer0 - k_ptr_block0_layer0 == expected_stride
cache_manager.clear_all()
assert cache_manager.num_available_blocks == num_blocks
for cache_block in cache_manager._cache_blocks:
assert cache_block.available_space == block_size
# Mock batch operations (Prefill/Decoding updates)
context_lengths = torch.tensor([max_input_length, max_input_length - 1])
block_tables = torch.tensor(
[[-1 for _ in range(cache_manager.max_blocks_per_sequence)] for _ in range(2)], dtype=torch.int32
)
cache_manager.allocate_context_from_block_tables(block_tables, context_lengths)
cache_manager.allocate_tokens_from_block_tables(block_tables, context_lengths)
cache_manager.free_block_tables(block_tables)
for cache_block in cache_manager._cache_blocks:
assert cache_block.available_space == block_size
def run_dist(rank, world_size, port):
colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost")
check_cache_manager()
@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5")
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_cache_manager_dist():
spawn(
create_cache_manager,
4,
batch_size=BATCH_SIZE,
input_len=INPUT_LEN,
output_len=OUTPUT_LEN,
layer_num=LAYER_NUM,
head_num=HEAD_NUM,
head_dim=HEAD_DIM,
)
def test_cache_manager():
spawn(run_dist, 1)
if __name__ == "__main__":
test_cache_manager_dist()
test_logical_blocks()
test_cache_manager()

View File

@@ -0,0 +1,145 @@
import pytest
import torch
from transformers.cache_utils import DynamicCache
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb
from colossalai.inference.modeling.layers.attention import PagedAttention, convert_kvcache, copy_to_cache
@pytest.mark.skip(reason="This test is not used in the current version.")
def test_copy_to_cache():
key = torch.ones((2, 11, 3, 3))
key[0, 9, :, :] = 0
key[1, -2:, :, :] = 0
cache = torch.zeros(8, 3, 8, 3)
block_tables = torch.tensor([[0, 1], [2, 3]])
lengths = torch.tensor([9, 8])
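# seq 0 has 9 valid tokens (the 9th spills into block 1), seq 1 has 8 valid tokens, so block 3 should stay empty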
cache = copy_to_cache(key, cache=cache, lengths=lengths, block_tables=block_tables, type="prefill")
assert cache[1, 0, 0, 0] == 1
assert cache[3, 0, 0, 0] == 0
decoding_key = torch.ones((2, 1, 3, 3))
cache = copy_to_cache(decoding_key, cache=cache, lengths=lengths + 1, block_tables=block_tables, type="decoding")
assert cache[1, 0, 0, 1] == 1
assert cache[3, 0, 0, 0] == 1
@pytest.mark.skip(reason="This test is not used in the current version.")
def test_convert_kvcache():
cache = torch.ones(8, 3, 8, 3)
key = torch.ones(2, 1, 3, 3) + 1
lengths = torch.tensor([10, 9])
block_tables = torch.tensor([[0, 1], [2, 3]])
copy_to_cache(key, cache=cache, lengths=lengths, block_tables=block_tables, type="decoding")
converted_cache = convert_kvcache(cache=cache, lengths=lengths, block_tables=block_tables)
assert converted_cache.shape == (2, 10, 3, 3)
@pytest.mark.skip(reason="This test is not used in the current version.")
def test_context_attention():
"""
test config: head_num = 4, head_size = 4
"""
attn = PagedAttention()
q = k = v = torch.randn(8, 4, 4)
k_cache = torch.empty(8, 4, 8, 4)
v_cache = torch.empty(8, 4, 8, 4)
context_lengths = torch.tensor(
[
8,
]
)
block_tables = torch.tensor([[0, 1]])
attn.nopad_context_forward(q, k, v, k_cache, v_cache, context_lengths, block_tables)
# test padded q/k/v
pad_q = pad_k = pad_v = q.unsqueeze(0)
attn.pad_context_forward(pad_q, pad_k, pad_v, k_cache, v_cache, context_lengths, block_tables)
config = LlamaConfig(num_attention_heads=4, num_key_value_heads=None, hidden_size=16)
transformer_attn = LlamaAttention(config)
transformer_attn.training = False
# test accuracy with LlamaAttention
hidden_states = torch.randn(1, 8, 16)
proj_q = transformer_attn.q_proj(hidden_states).view(1, 8, 4, 4).transpose(1, 2)
proj_k = transformer_attn.k_proj(hidden_states).view(1, 8, 4, 4).transpose(1, 2)
proj_v = transformer_attn.v_proj(hidden_states).view(1, 8, 4, 4).transpose(1, 2)
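# apply Llama's rotary position embedding to the projected q/k before comparing against pad_context_forward (the rotary_emb(x, seq_len) signature of older transformers versions is assumed here)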
position_ids = torch.arange(0, 8, dtype=torch.long, device=proj_q.device)
position_ids = position_ids.unsqueeze(0)
cos, sin = transformer_attn.rotary_emb(proj_v, 8)
proj_q, proj_k = apply_rotary_pos_emb(proj_q, proj_k, cos, sin, position_ids)
pad_attn_output = attn.pad_context_forward(
proj_q.transpose(1, 2),
proj_k.transpose(1, 2),
proj_v.transpose(1, 2),
k_cache,
v_cache,
context_lengths,
block_tables,
)
pad_attn_output = transformer_attn.o_proj(pad_attn_output)
attn_mask = AttentionMaskConverter._make_causal_mask(
hidden_states.shape[:2], q.dtype, q.device, past_key_values_length=0
)
attn_mask += PagedAttention.generate_padding_mask(context_lengths, 8)
attn_output, _, _ = transformer_attn.forward(hidden_states, attention_mask=attn_mask)
assert torch.allclose(pad_attn_output, attn_output, atol=1e-3, rtol=1e-3)
@pytest.mark.skip(reason="This test is not used in the current version.")
def test_decoding_attention():
# test the pipeline of decoding attention
attn = PagedAttention()
q = k = v = torch.randn(2, 1, 4, 8)
k_cache = torch.empty(8, 4, 8, 8)
v_cache = torch.empty(8, 4, 8, 8)
past_kv = torch.randn(2, 8, 4, 8)
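# pre-fill both caches with 8 past tokens per sequence, then decode a single new token on top of them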
context_lengths = torch.tensor([8, 8])
lengths = context_lengths + 1
block_tables = torch.tensor([[0, 1], [2, 3]])
copy_to_cache(past_kv, k_cache, lengths=context_lengths, block_tables=block_tables)
copy_to_cache(past_kv, v_cache, lengths=context_lengths, block_tables=block_tables)
attn.pad_decoding_forward(q, k, v, k_cache, v_cache, lengths=lengths, block_tables=block_tables)
# test decoding accuracy, past_kv is reused
config = LlamaConfig(num_attention_heads=4, num_key_value_heads=None, hidden_size=32)
transformer_attn = LlamaAttention(config)
transformer_attn.layer_idx = 0
transformer_attn.training = False
hidden_states = torch.randn(2, 1, 32)
proj_q = transformer_attn.q_proj(hidden_states).view(2, 1, 4, 8).transpose(1, 2)
proj_k = transformer_attn.k_proj(hidden_states).view(2, 1, 4, 8).transpose(1, 2)
proj_v = transformer_attn.v_proj(hidden_states).view(2, 1, 4, 8).transpose(1, 2)
cos, sin = transformer_attn.rotary_emb(proj_v, 16)
position_ids = lengths - 1
position_ids = position_ids.unsqueeze(1) # NOTE: this may be wrong
proj_q, proj_k = apply_rotary_pos_emb(proj_q, proj_k, cos, sin, position_ids, unsqueeze_dim=2)
llama_past_kv = DynamicCache()
llama_past_kv.update(key_states=past_kv.transpose(1, 2), value_states=past_kv.transpose(1, 2), layer_idx=0)
# past_key_value shape in Llama: bsz, num_heads, seq_len, head_dim
pad_attn_output = attn.pad_decoding_forward(
proj_q.transpose(1, 2), proj_k.transpose(1, 2), proj_v.transpose(1, 2), k_cache, v_cache, lengths, block_tables
)
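# reference mask for LlamaAttention: a causal mask over the 8 cached positions plus a padding mask for the current total length of 9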
attn_mask = AttentionMaskConverter._make_causal_mask(q.shape[:2], q.dtype, q.device, past_key_values_length=8)
attn_mask = attn_mask + PagedAttention.generate_padding_mask(lengths, 9).unsqueeze(1).unsqueeze(2)
pad_attn_output = transformer_attn.o_proj(pad_attn_output)
position_ids = context_lengths.unsqueeze(1)
attn_output, _, _ = transformer_attn.forward(
hidden_states, past_key_value=llama_past_kv, position_ids=position_ids, attention_mask=attn_mask
)
assert torch.allclose(pad_attn_output, attn_output, atol=1e-3, rtol=1e-2)
if __name__ == "__main__":
test_copy_to_cache()
test_convert_kvcache()
test_context_attention()
test_decoding_attention()

View File

@@ -0,0 +1,138 @@
import os
import random
import numpy as np
import pytest
import torch
import torch.distributed as dist
from torch.multiprocessing import Manager
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
import colossalai
from colossalai.inference.config import _DEFAULT_PROMPT_TEMPLATES, InferenceConfig
from colossalai.inference.core.engine import InferenceEngine
from colossalai.inference.modeling.policy import NoPaddingBaichuanModelInferPolicy
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
BAICHUAN_MODEL_NAME_OR_PATH = "baichuan-inc/Baichuan2-13B-Base"
def setup_seed(seed):
torch.manual_seed(seed)
torch.random.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
random.seed(seed)
def check_inference_engine(use_engine=False, do_sample=False, use_cuda_kernel=False, prompt_template=None, policy=None):
setup_seed(20)
tokenizer = AutoTokenizer.from_pretrained(BAICHUAN_MODEL_NAME_OR_PATH, use_fast=False, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(BAICHUAN_MODEL_NAME_OR_PATH, trust_remote_code=True).half().cuda()
model = model.eval()
inputs = [
"介绍一下今天的北京,比如故宫,天安门,长城或者其他的一些景点,",
]
output_len = 38
if do_sample:
top_p = 0.5
top_k = 50
else:
top_p = None
top_k = None
if use_engine:
inference_config = InferenceConfig(
max_output_len=output_len,
prompt_template=prompt_template,
use_cuda_kernel=use_cuda_kernel,
tp_size=dist.get_world_size(),
)
inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True, model_policy=policy)
assert inference_engine.generation_config.max_new_tokens == output_len
inference_engine.add_request(prompts=inputs)
assert inference_engine.request_handler._has_waiting()
generation_config = GenerationConfig(do_sample=do_sample, top_p=top_p, top_k=top_k)
outputs = inference_engine.generate(generation_config=generation_config)
else:
if prompt_template:
# apply prompt template
inputs = [_DEFAULT_PROMPT_TEMPLATES[prompt_template].format(input_text=input_text) for input_text in inputs]
tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id = tokenizer.eos_token_id
inputs = tokenizer.batch_encode_plus(inputs, padding=True, return_tensors="pt")["input_ids"]
inputs = inputs.cuda()
generation_config = GenerationConfig(
do_sample=do_sample,
top_p=top_p,
top_k=top_k,
pad_token_id=tokenizer.pad_token_id,
max_new_tokens=output_len,
)
outputs = model.generate(inputs, generation_config=generation_config)
outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
return outputs
def run_engine(world_size, **kwargs):
manager = Manager()
result_list = manager.list([-1] * world_size) # Create a shared list
spawn(run_dist, world_size, func_to_run=check_inference_engine, ret=result_list, **kwargs)
return result_list[0]
def run_dist(rank, world_size, port, func_to_run, ret=None, **kwargs):
colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost")
if ret:
ret[rank] = func_to_run(**kwargs)
else:
func_to_run(**kwargs)
# NOTE(caidi) If do_sample is set to True or use_cuda_kernel is set to False, the inference results will differ from those of transformers.
@parameterize("prompt_template", [None, "baichuan"])
@parameterize("do_sample", [False])
@parameterize("use_cuda_kernel", [True])
def check_tp_engine(prompt_template, do_sample, use_cuda_kernel):
kwargs1 = {
"use_engine": True,
"prompt_template": prompt_template,
"do_sample": do_sample,
"policy": NoPaddingBaichuanModelInferPolicy(),
"use_cuda_kernel": use_cuda_kernel,
}
kwargs2 = {
"use_engine": False,
"prompt_template": prompt_template,
"do_sample": do_sample,
"policy": None,
"use_cuda_kernel": use_cuda_kernel,
}
colossal_tp_1_output = run_engine(1, **kwargs1)
colossal_tp_2_output = run_engine(2, **kwargs1)
transformer_tp_1_output = run_engine(1, **kwargs2)
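# with greedy decoding, the TP=1 and TP=2 engine outputs should match the plain transformers baseline token for token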
for s1, s2, s3 in zip(colossal_tp_1_output, colossal_tp_2_output, transformer_tp_1_output):
assert s1 == s3, f"\nColossalAI TP=1 Output: {s1}\nTransformers Output: {s3}"
assert s1 == s2, f"\nColossalAI TP=1 Output: {s1}\nColossalAI TP=2 Output: {s2}"
@pytest.mark.skipif(
not os.path.exists(BAICHUAN_MODEL_NAME_OR_PATH),
reason="There is no local model address included, please replace this address with a valid one.",
)
@pytest.mark.largedist
@rerun_if_address_is_in_use()
def test_inference_engine():
check_tp_engine()
if __name__ == "__main__":
test_inference_engine()

View File

@@ -0,0 +1,105 @@
import pytest
from transformers.models.llama import LlamaConfig
import colossalai
from colossalai.inference.config import InferenceConfig
from colossalai.inference.core.request_handler import RequestHandler, RunningList
from colossalai.inference.struct import RequestStatus, Sequence
from colossalai.testing import rerun_if_address_is_in_use, spawn
def check_running_list():
"""
Test the RunningList Structure.
"""
running_list = RunningList(prefill_ratio=1.2)
seq1 = Sequence(
request_id=1,
prompt="abc",
input_token_id=[1, 2, 3],
block_size=16,
eos_token_id=0,
pad_token_id=0,
sample_params=None,
)
seq2 = Sequence(
request_id=2,
prompt="abc",
input_token_id=[1, 2, 3],
block_size=16,
eos_token_id=0,
pad_token_id=0,
sample_params=None,
)
running_list.append(seq1)
running_list.append(seq2)
assert running_list.ready_for_prefill()
assert len(running_list.decoding) == 0
assert len(running_list.prefill) > 0 and running_list.prefill[0] == seq1
seq = running_list.find_seq(seq1.request_id)
assert seq == seq1
running_list.mark_prefill_running()
for seq in running_list.prefill:
assert seq.status == RequestStatus.RUNNING
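# once prefill finishes, both sequences are moved from the prefill queue to the decoding queue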
running_list.move_prefill_to_decoding([seq1.request_id, seq2.request_id])
assert len(running_list.prefill) == 0
assert len(running_list.decoding) > 0 and running_list.decoding[0] == seq1
running_list.remove(seq1)
running_list.remove(seq2)
assert running_list.is_empty()
def check_request_handler():
"""
Test main function of RequestHandler
"""
inference_config = InferenceConfig(
max_input_len=10,
max_output_len=10,
block_size=8,
)
model_config = LlamaConfig(
hidden_size=32,
num_hidden_layers=2,
num_attention_heads=4,
)
request_handler = RequestHandler(inference_config, model_config)
seq1 = Sequence(
request_id=1,
prompt="abc",
input_token_id=[1, 2, 3, 4, 5],
block_size=16,
eos_token_id=0,
pad_token_id=0,
sample_params=None,
)
request_handler.add_sequence(seq1)
# the priority should be 1
assert request_handler.waiting_list[1][0] == seq1
assert request_handler._has_waiting()
request_handler.abort_sequence(seq1.request_id)
assert not request_handler._has_waiting()
seq1.status = RequestStatus.WAITING
request_handler.add_sequence(seq1)
request_handler.schedule()
def run_dist(rank, world_size, port):
colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost")
check_running_list()
check_request_handler()
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_running_list_and_request_handler():
spawn(run_dist, 1)
if __name__ == "__main__":
test_running_list_and_request_handler()

View File

@@ -0,0 +1,105 @@
import random
import numpy as np
import pytest
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
from colossalai.inference.config import _DEFAULT_PROMPT_TEMPLATES, InferenceConfig
from colossalai.inference.core.rpc_engine import RPCInferenceEngine
from colossalai.inference.modeling.policy import NoPaddingLlamaModelInferPolicy
from colossalai.testing import parameterize, rerun_if_address_is_in_use
def setup_seed(seed):
torch.manual_seed(seed)
torch.random.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
random.seed(seed)
def check_inference_engine(tp_size, use_engine=False, prompt_template=None, do_sample=True, policy=None):
setup_seed(20)
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
model = "meta-llama/Llama-2-7b-hf" # remote mode path
inputs = [
"介绍一下今天的北京,比如故宫,天安门,长城或者其他的一些景点,",
"介绍一下武汉,",
]
output_len = 38
top_p = 0.5
top_k = 50
if use_engine:
inference_config = InferenceConfig(
max_output_len=output_len,
prompt_template=prompt_template,
dtype="fp32",
use_cuda_kernel=True,
tp_size=tp_size,
)
inference_engine = RPCInferenceEngine(model, tokenizer, inference_config, verbose=True, model_policy=policy)
assert inference_engine.generation_config.max_new_tokens == output_len
inference_engine.add_request(prompts=inputs)
assert inference_engine.request_handler._has_waiting()
generation_config = GenerationConfig(
max_new_tokens=output_len, do_sample=do_sample, dtype="fp32", top_p=top_p, top_k=top_k
)
outputs = inference_engine.generate(generation_config=generation_config)
else:
if prompt_template:
# apply prompt template
inputs = [_DEFAULT_PROMPT_TEMPLATES[prompt_template].format(input_text=input_text) for input_text in inputs]
model = AutoModelForCausalLM.from_pretrained(model).cuda()
tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id = tokenizer.eos_token_id
inputs = tokenizer.batch_encode_plus(inputs, padding=True, return_tensors="pt")["input_ids"]
inputs = inputs.cuda()
generation_config = GenerationConfig(
do_sample=do_sample,
dtype="fp32",
top_p=top_p,
top_k=top_k,
pad_token_id=tokenizer.pad_token_id,
max_new_tokens=output_len,
)
outputs = model.generate(inputs, generation_config=generation_config)
outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
return outputs
def run_engine(tp_size, **kwargs):
return check_inference_engine(tp_size=tp_size, **kwargs)
@pytest.mark.largedist
@parameterize("prompt_template", [None, "llama"])
@parameterize("do_sample", [False])
@rerun_if_address_is_in_use()
def test_tp_engine(prompt_template, do_sample):
if torch.multiprocessing.get_start_method(allow_none=True) is None:
torch.multiprocessing.set_start_method("spawn")
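# the RPC engine launches worker subprocesses, which require the 'spawn' start method (CUDA contexts cannot be re-initialized in forked processes)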
kwargs1 = {
"use_engine": True,
"prompt_template": prompt_template,
"do_sample": do_sample,
"policy": NoPaddingLlamaModelInferPolicy(),
}
kwargs2 = {"use_engine": False, "prompt_template": prompt_template, "do_sample": do_sample, "policy": None}
colossal_tp_1_output = run_engine(1, **kwargs1)
colossal_tp_2_output = run_engine(2, **kwargs1)
transformer_tp_1_output = run_engine(1, **kwargs2)
for s1, s2, s3 in zip(colossal_tp_1_output, colossal_tp_2_output, transformer_tp_1_output):
assert s1 == s3, f"\nColossalAI TP=1 Output: {s1}\nTransformers Output: {s3}"
assert s1 == s2, f"\nColossalAI TP=1 Output: {s1}\nColossalAI TP=2 Output: {s2}"
if __name__ == "__main__":
torch.multiprocessing.set_start_method("spawn")  # the 'fork' start method does not work for these subprocesses; 'spawn' is required
test_tp_engine()