[Hotfix] Fix accuracy and align attention method api with Triton kernel (#5229)

* fix accuracy

* alignment in attention

* fix attention

* fix

* fix bugs

* fix bugs

* fix bugs
Author: Jianghai
Date: 2024-01-08 15:56:00 +08:00
Committed by: FrankLeeeee
parent fa4fbdbffb
commit e545a871b8
6 changed files with 168 additions and 107 deletions
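The diffs below align the `PagedAttention` test path with the Triton-kernel-style API: the layer is now constructed without head counts, rotary embeddings are applied to the projected q/k before the paged call, and `pad_context_forward` takes q/k/v in `[bsz, seq_len, num_heads, head_dim]` layout. A minimal sketch of that call pattern, pieced together from the updated context-attention test below (the `context_lengths`/`block_tables` values and the `[num_blocks, num_heads, head_dim, block_size]` cache layout are assumptions read off the test tensors, not taken from the kernel itself):

```python
import torch
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb

from colossalai.inference.modeling.layers.attention import PagedAttention

attn = PagedAttention()  # no head_num / head_size constructor arguments anymore
transformer_attn = LlamaAttention(LlamaConfig(num_attention_heads=4, num_key_value_heads=None, hidden_size=16))

hidden_states = torch.randn(1, 8, 16)  # [bsz, seq_len, hidden_size]
# Project and move to [bsz, num_heads, seq_len, head_dim] for the rotary embedding.
proj_q = transformer_attn.q_proj(hidden_states).view(1, 8, 4, 4).transpose(1, 2)
proj_k = transformer_attn.k_proj(hidden_states).view(1, 8, 4, 4).transpose(1, 2)
proj_v = transformer_attn.v_proj(hidden_states).view(1, 8, 4, 4).transpose(1, 2)
position_ids = torch.arange(0, 8, dtype=torch.long).unsqueeze(0)
cos, sin = transformer_attn.rotary_emb(proj_v, 8)
proj_q, proj_k = apply_rotary_pos_emb(proj_q, proj_k, cos, sin, position_ids)

# The paged forward expects [bsz, seq_len, num_heads, head_dim], hence the transposes back.
k_cache = torch.empty(8, 4, 4, 8)      # assumed layout: [num_blocks, num_heads, head_dim, block_size]
v_cache = torch.empty(8, 4, 4, 8)
context_lengths = torch.tensor([8])    # assumed: one sequence of 8 prompt tokens
block_tables = torch.tensor([[0, 1]])  # assumed: blocks allocated to that sequence
pad_attn_output = attn.pad_context_forward(
    proj_q.transpose(1, 2),
    proj_k.transpose(1, 2),
    proj_v.transpose(1, 2),
    k_cache,
    v_cache,
    context_lengths,
    block_tables,
)
pad_attn_output = transformer_attn.o_proj(pad_attn_output)
```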


@@ -3,7 +3,7 @@ import pytest
import colossalai
from colossalai.inference.config import InferenceConfig
from colossalai.inference.struct import BatchInfo, Sequence
-from colossalai.testing import spawn
+from colossalai.testing import rerun_if_address_is_in_use, spawn
def check_config_and_inference():
@@ -74,6 +74,7 @@ def run_dist(rank, world_size, port):
@pytest.mark.dist
+@rerun_if_address_is_in_use()
def test_config_and_inference():
spawn(run_dist, 1)
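Every distributed test entry point touched by this commit gets the same fix-up: `rerun_if_address_is_in_use` is imported from `colossalai.testing` and stacked on top of `@pytest.mark.dist`, so a run that fails only because the previous process group left the rendezvous port bound is retried instead of reported as a failure. The pattern, shown here with an illustrative `run_dist` body (the `colossalai.launch` arguments are the usual single-process test setup, not copied from this diff):

```python
import pytest

import colossalai
from colossalai.testing import rerun_if_address_is_in_use, spawn


def run_dist(rank, world_size, port):
    # Illustrative body: bring up a one-process group on the given port,
    # then run the actual checks (check_config_and_inference() in the real test).
    colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host="localhost")


@pytest.mark.dist
@rerun_if_address_is_in_use()  # retry instead of failing when the port is still occupied
def test_config_and_inference():
    spawn(run_dist, 1)
```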


@@ -11,7 +11,6 @@ from colossalai.inference.config import InferenceConfig
from colossalai.inference.core.engine import InferenceEngine
from colossalai.testing import rerun_if_address_is_in_use, spawn
def setup_seed(seed):
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)


@@ -8,7 +8,7 @@ import colossalai
from colossalai.inference.config import InferenceConfig
from colossalai.inference.kv_cache import CacheBlock, KVCacheManager
from colossalai.logging import disable_existing_loggers
-from colossalai.testing import parameterize, spawn
+from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
@parameterize(
@@ -155,6 +155,7 @@ def run_dist(rank, world_size, port):
@pytest.mark.dist
+@rerun_if_address_is_in_use()
def test_cache_manager():
spawn(run_dist, 1)


@@ -3,15 +3,15 @@ import torch
from transformers.cache_utils import DynamicCache
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
from transformers.models.llama.configuration_llama import LlamaConfig
-from transformers.models.llama.modeling_llama import LlamaAttention
+from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb
import colossalai
from colossalai.inference.modeling.layers.attention import PagedAttention, convert_kvcache, copy_to_cache
-from colossalai.testing import spawn
+from colossalai.testing import rerun_if_address_is_in_use, spawn
def test_copy_to_cache():
-key = torch.ones((2, 10, 3, 3))
+key = torch.ones((2, 11, 3, 3))
key[0, 9, :, :] = 0
key[1, -2:, :, :] = 0
cache = torch.zeros(8, 3, 3, 8)
@@ -32,7 +32,8 @@ def test_convert_kvcache():
key = torch.ones(2, 1, 3, 3) + 1
lengths = torch.tensor([10, 9])
block_tables = torch.tensor([[0, 1], [2, 3]])
-converted_cache = convert_kvcache(key, cache=cache, lengths=lengths, block_tables=block_tables)
+copy_to_cache(key, cache=cache, lengths=lengths, block_tables=block_tables, type="decoding")
+converted_cache = convert_kvcache(cache=cache, lengths=lengths, block_tables=block_tables)
assert converted_cache.shape == (2, 10, 3, 3)
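`convert_kvcache` no longer takes the new key tensor as its first argument: the decoding-step key is written into the paged cache with `copy_to_cache(..., type="decoding")` first, and `convert_kvcache` then rebuilds the contiguous `[bsz, max_len, num_heads, head_dim]` view from the cache, lengths, and block tables alone. A shape-level sketch of the updated flow (the cache initialization is assumed; lengths, block tables, and the expected output shape mirror the test above):

```python
import torch

from colossalai.inference.modeling.layers.attention import convert_kvcache, copy_to_cache

cache = torch.zeros(4, 3, 3, 8)                # assumed: [num_blocks, num_heads, head_dim, block_size]
key = torch.ones(2, 1, 3, 3) + 1               # one new decoding-step key per sequence
lengths = torch.tensor([10, 9])                # per-sequence lengths including the new token
block_tables = torch.tensor([[0, 1], [2, 3]])  # blocks assigned to each sequence

# Write the decoding step into the paged cache first ...
copy_to_cache(key, cache=cache, lengths=lengths, block_tables=block_tables, type="decoding")
# ... then rebuild the padded contiguous view from the cache alone.
converted_cache = convert_kvcache(cache=cache, lengths=lengths, block_tables=block_tables)
assert converted_cache.shape == (2, 10, 3, 3)  # [bsz, max_len, num_heads, head_dim]
```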
@@ -40,7 +41,7 @@ def test_context_attention():
"""
test config: head_num = 4, head_size = 4
"""
-attn = PagedAttention(4, 4)
+attn = PagedAttention()
q = k = v = torch.randn(8, 4, 4)
k_cache = torch.empty(8, 4, 4, 8)
v_cache = torch.empty(8, 4, 4, 8)
@@ -61,48 +62,72 @@ def test_context_attention():
# test accuracy with LlamaAttention
hidden_states = torch.randn(1, 8, 16)
-proj_q = transformer_attn.q_proj(hidden_states).view(1, 8, 4, 4)
-proj_k = transformer_attn.k_proj(hidden_states).view(1, 8, 4, 4)
-proj_v = transformer_attn.v_proj(hidden_states).view(1, 8, 4, 4)
-pad_attn_output = attn.pad_context_forward(proj_q, proj_k, proj_v, k_cache, v_cache, context_lengths, block_tables)
-pad_attn_output = transformer_attn.o_proj(pad_attn_output)
+proj_q = transformer_attn.q_proj(hidden_states).view(1, 8, 4, 4).transpose(1, 2)
+proj_k = transformer_attn.k_proj(hidden_states).view(1, 8, 4, 4).transpose(1, 2)
+proj_v = transformer_attn.v_proj(hidden_states).view(1, 8, 4, 4).transpose(1, 2)
+position_ids = torch.arange(0, 8, dtype=torch.long, device=proj_q.device)
+position_ids = position_ids.unsqueeze(0)
+cos, sin = transformer_attn.rotary_emb(proj_v, 8)
+proj_q, proj_k = apply_rotary_pos_emb(proj_q, proj_k, cos, sin, position_ids)
+pad_attn_output = attn.pad_context_forward(
+proj_q.transpose(1, 2),
+proj_k.transpose(1, 2),
+proj_v.transpose(1, 2),
+k_cache,
+v_cache,
+context_lengths,
+block_tables,
+)
+pad_attn_output = transformer_attn.o_proj(pad_attn_output)
attn_mask = AttentionMaskConverter._make_causal_mask(
hidden_states.shape[:2], q.dtype, q.device, past_key_values_length=0
)
attn_mask += PagedAttention.generate_padding_mask(context_lengths, 8)
attn_output, _, _ = transformer_attn.forward(hidden_states, attention_mask=attn_mask)
-assert torch.allclose(pad_attn_output, attn_output, atol=1e-3, rtol=1e-2)
+assert torch.allclose(pad_attn_output, attn_output, atol=1e-3, rtol=1e-3)
def test_decoding_attention():
# test the pipeline of decoding attention
-attn = PagedAttention(4, 4)
-q = k = v = torch.randn(2, 1, 4, 4)
-k_cache = torch.empty(8, 4, 4, 8)
-v_cache = torch.empty(8, 4, 4, 8)
-past_kv = torch.randn(2, 8, 4, 4)
+attn = PagedAttention()
+q = k = v = torch.randn(2, 1, 4, 8)
+k_cache = torch.empty(8, 4, 8, 8)
+v_cache = torch.empty(8, 4, 8, 8)
+past_kv = torch.randn(2, 8, 4, 8)
context_lenghths = torch.tensor([8, 8])
lengths = context_lenghths + 1
block_tables = torch.tensor([[0, 1], [2, 3]])
copy_to_cache(past_kv, k_cache, lengths=context_lenghths, block_tables=block_tables)
copy_to_cache(past_kv, v_cache, lengths=context_lenghths, block_tables=block_tables)
attn.pad_decoding_forward(q, k, v, k_cache, v_cache, lengths=lengths, block_tables=block_tables)
# test decoding accuracy, past_kv is reused
-config = LlamaConfig(num_attention_heads=4, num_key_value_heads=None, hidden_size=16)
+config = LlamaConfig(num_attention_heads=4, num_key_value_heads=None, hidden_size=32)
transformer_attn = LlamaAttention(config)
transformer_attn.layer_idx = 0
transformer_attn.training = False
-hidden_states = torch.randn(2, 1, 16)
-proj_q = transformer_attn.q_proj(hidden_states).view(2, 1, 4, 4)
-proj_k = transformer_attn.k_proj(hidden_states).view(2, 1, 4, 4)
-proj_v = transformer_attn.v_proj(hidden_states).view(2, 1, 4, 4)
+hidden_states = torch.randn(2, 1, 32)
+proj_q = transformer_attn.q_proj(hidden_states).view(2, 1, 4, 8).transpose(1, 2)
+proj_k = transformer_attn.k_proj(hidden_states).view(2, 1, 4, 8).transpose(1, 2)
+proj_v = transformer_attn.v_proj(hidden_states).view(2, 1, 4, 8).transpose(1, 2)
+cos, sin = transformer_attn.rotary_emb(proj_v, 16)
+position_ids = lengths - 1
+position_ids = position_ids.unsqueeze(1)  # NOTE: this may be wrong
+proj_q, proj_k = apply_rotary_pos_emb(proj_q, proj_k, cos, sin, position_ids, unsqueeze_dim=2)
llama_past_kv = DynamicCache()
llama_past_kv.update(key_states=past_kv.transpose(1, 2), value_states=past_kv.transpose(1, 2), layer_idx=0)
# past_key_value shape in Llama: bsz, num_heads, seq_len, head_dim
-pad_attn_output = attn.pad_decoding_forward(proj_q, proj_k, proj_v, k_cache, v_cache, lengths, block_tables)
-attn_mask = AttentionMaskConverter._make_causal_mask(proj_q.shape[:2], q.dtype, q.device, past_key_values_length=8)
+pad_attn_output = attn.pad_decoding_forward(
+proj_q.transpose(1, 2), proj_k.transpose(1, 2), proj_v.transpose(1, 2), k_cache, v_cache, lengths, block_tables
+)
+attn_mask = AttentionMaskConverter._make_causal_mask(q.shape[:2], q.dtype, q.device, past_key_values_length=8)
attn_mask = attn_mask + PagedAttention.generate_padding_mask(lengths, 9).unsqueeze(1).unsqueeze(2)
pad_attn_output = transformer_attn.o_proj(pad_attn_output)
position_ids = context_lenghths.unsqueeze(1)
attn_output, _, _ = transformer_attn.forward(
@@ -112,9 +137,9 @@ def test_decoding_attention():
def check_attention_layer():
-# test_copy_to_cache()
-# test_convert_kvcache()
-# test_context_attention()
+test_copy_to_cache()
+test_convert_kvcache()
+test_context_attention()
test_decoding_attention()
@@ -124,6 +149,7 @@ def run_dist(rank, world_size, port):
@pytest.mark.dist
+@rerun_if_address_is_in_use()
def test_attention_layer():
spawn(run_dist, 1)
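Pulled together from the hunks above, the decoding side of the aligned API looks like this: the prompt KV is prefilled into the paged cache with `copy_to_cache`, and `pad_decoding_forward` is then called with the single-token q/k/v in `[bsz, 1, num_heads, head_dim]` layout and the updated per-sequence lengths. A minimal sketch (shapes and values copied from `test_decoding_attention`; the cache layout is an assumption read off the test tensors):

```python
import torch

from colossalai.inference.modeling.layers.attention import PagedAttention, copy_to_cache

attn = PagedAttention()

# Prefill: 8 past tokens per sequence, 4 heads, head_dim 8.
past_kv = torch.randn(2, 8, 4, 8)              # [bsz, past_len, num_heads, head_dim]
k_cache = torch.empty(8, 4, 8, 8)              # assumed: [num_blocks, num_heads, head_dim, block_size]
v_cache = torch.empty(8, 4, 8, 8)
context_lengths = torch.tensor([8, 8])
block_tables = torch.tensor([[0, 1], [2, 3]])
copy_to_cache(past_kv, k_cache, lengths=context_lengths, block_tables=block_tables)
copy_to_cache(past_kv, v_cache, lengths=context_lengths, block_tables=block_tables)

# Decode one step: q/k/v carry a single new token each, and lengths now include it.
q = k = v = torch.randn(2, 1, 4, 8)            # [bsz, 1, num_heads, head_dim]
lengths = context_lengths + 1
out = attn.pad_decoding_forward(q, k, v, k_cache, v_cache, lengths=lengths, block_tables=block_tables)
```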


@@ -6,7 +6,7 @@ import colossalai
from colossalai.inference.config import InferenceConfig
from colossalai.inference.core.request_handler import RequestHandler, RunningList
from colossalai.inference.struct import RequestStatus, Sequence
-from colossalai.testing import spawn
+from colossalai.testing import rerun_if_address_is_in_use, spawn
def check_running_list():
@@ -78,6 +78,7 @@ def run_dist(rank, world_size, port):
@pytest.mark.dist
+@rerun_if_address_is_in_use()
def test_running_list_and_request_handler():
spawn(run_dist, 1)