mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-17 07:00:37 +00:00
[Fix] resolve conflicts of rebasing feat/speculative-decoding (#5557)
- resolve conflicts of rebasing feat/speculative-decoding
This commit is contained in:
committed by
ocd_with_naming
parent
e1acb58423
commit
e60d430cf5
@@ -138,5 +138,6 @@ def test_flash_decoding(
|
||||
assert out_torch.shape == out_triton.shape
|
||||
assert torch.allclose(out_torch, out_triton, atol=1e-3, rtol=1e-4)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_flash_decoding(16, 32, 32, 16, 1, True)
|
||||
|
@@ -2,7 +2,6 @@ import pytest
|
||||
import torch
|
||||
from packaging import version
|
||||
|
||||
from colossalai.inference.modeling.layers.attention import copy_to_cache
|
||||
from colossalai.kernel.triton import copy_k_to_blocked_cache, copy_kv_to_blocked_cache
|
||||
from colossalai.utils import get_current_device
|
||||
from tests.test_infer.test_ops.triton.kernel_utils import generate_caches_and_block_tables_v2, mock_alloc_single_token
|
||||
@@ -28,8 +27,8 @@ def prepare_data(
|
||||
max_num_blocks_per_seq,
|
||||
same_context_len,
|
||||
max_seq_len,
|
||||
n,
|
||||
device,
|
||||
n=1,
|
||||
device="cuda",
|
||||
dtype=torch.float16,
|
||||
):
|
||||
assert max_seq_len > n, "max_seq_len must be greater than n"
|
||||
|
Reference in New Issue
Block a user