[Fix] resolve conflicts of rebasing feat/speculative-decoding (#5557)

- resolve conflicts of rebasing feat/speculative-decoding
Yuanheng Zhao
2024-04-07 14:53:30 +08:00
committed by ocd_with_naming
parent e1acb58423
commit e60d430cf5
6 changed files with 47 additions and 35 deletions


@@ -138,5 +138,6 @@ def test_flash_decoding(
    assert out_torch.shape == out_triton.shape
    assert torch.allclose(out_torch, out_triton, atol=1e-3, rtol=1e-4)


if __name__ == "__main__":
    test_flash_decoding(16, 32, 32, 16, 1, True)
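
For orientation, the assertions above check the Triton flash-decoding kernel against a plain PyTorch reference. Below is a minimal torch-only sketch of what such a reference computes for decoding (a single new query token attending over the cached K/V); the function name and tensor shapes are illustrative assumptions, not the repo's actual helpers.

import torch

def naive_decoding_attention(q, k_cache, v_cache):
    # q:        (bsz, num_heads, head_dim) -- the one new query token per sequence
    # k_cache:  (bsz, num_heads, kv_len, head_dim)
    # v_cache:  (bsz, num_heads, kv_len, head_dim)
    scale = q.shape[-1] ** -0.5
    # attention scores of the new token against every cached position
    scores = torch.einsum("bhd,bhld->bhl", q, k_cache) * scale
    probs = torch.softmax(scores, dim=-1)
    # weighted sum over the cached values -> (bsz, num_heads, head_dim)
    return torch.einsum("bhl,bhld->bhd", probs, v_cache)

if __name__ == "__main__":
    bsz, num_heads, kv_len, head_dim = 16, 32, 128, 32
    q = torch.randn(bsz, num_heads, head_dim)
    k = torch.randn(bsz, num_heads, kv_len, head_dim)
    v = torch.randn(bsz, num_heads, kv_len, head_dim)
    out = naive_decoding_attention(q, k, v)
    assert out.shape == (bsz, num_heads, head_dim)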


@@ -2,7 +2,6 @@ import pytest
import torch
from packaging import version

from colossalai.inference.modeling.layers.attention import copy_to_cache
from colossalai.kernel.triton import copy_k_to_blocked_cache, copy_kv_to_blocked_cache
from colossalai.utils import get_current_device
from tests.test_infer.test_ops.triton.kernel_utils import generate_caches_and_block_tables_v2, mock_alloc_single_token
@@ -28,8 +27,8 @@ def prepare_data(
     max_num_blocks_per_seq,
     same_context_len,
     max_seq_len,
-    n,
-    device,
+    n=1,
+    device="cuda",
     dtype=torch.float16,
 ):
     assert max_seq_len > n, "max_seq_len must be greater than n"
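
The new n=1 and device="cuda" defaults let prepare_data cover both paths: ordinary decoding copies one new token per sequence into the blocked KV cache, while speculative decoding copies several at once (hence copy_kv_to_blocked_cache imported alongside copy_k_to_blocked_cache). A rough torch-only sketch of that scatter follows, under an assumed paged-cache layout; the block-table convention, shapes, and function name are illustrative assumptions, not the kernel's exact interface.

import torch

def copy_tokens_to_blocked_cache(k_new, k_cache, block_tables, seq_lens, block_size):
    # k_new:        (bsz, n, num_heads, head_dim) -- n new key tokens per sequence
    # k_cache:      (num_blocks, num_heads, block_size, head_dim) -- paged cache
    # block_tables: (bsz, max_blocks_per_seq) -- logical block -> physical block
    # seq_lens:     (bsz,) -- context length of each sequence before the copy
    bsz, n = k_new.shape[:2]
    for b in range(bsz):
        for i in range(n):
            pos = seq_lens[b].item() + i          # global position of the new token
            block = block_tables[b, pos // block_size]
            k_cache[block, :, pos % block_size, :] = k_new[b, i]

if __name__ == "__main__":
    bsz, n, num_heads, head_dim, block_size = 2, 3, 4, 8, 16
    num_blocks, max_blocks_per_seq = 8, 4
    k_new = torch.randn(bsz, n, num_heads, head_dim)
    k_cache = torch.zeros(num_blocks, num_heads, block_size, head_dim)
    block_tables = torch.arange(bsz * max_blocks_per_seq).reshape(bsz, max_blocks_per_seq)
    seq_lens = torch.tensor([5, 14])
    copy_tokens_to_blocked_cache(k_new, k_cache, block_tables, seq_lens, block_size)

The Triton kernel performs the same scatter in parallel on the GPU; judging by the imports, the test builds the blocked caches with generate_caches_and_block_tables_v2 and presumably uses copy_to_cache as the reference implementation.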