[Fix] Fix Inference Example, Tests, and Requirements (#5688)
* clean requirements
* modify example inference struct
* add test ci scripts
* mark test_infer as submodule
* rm deprecated cls & deps
* import of HAS_FLASH_ATTN
* prune inference tests to be run
* prune triton kernel tests
* increment pytest timeout mins
* revert import path in openmoe
parent f9afe0addd
commit 55cc7f3df7
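The recurring change across the example scripts in this commit is dropping the legacy `config={}` argument from the launch helpers. A minimal sketch of the updated call pattern (the port value below is a placeholder, not taken from this commit):

    import colossalai

    # torchrun-style launch: rank / world size are read from the environment
    colossalai.launch_from_torch()

    # explicit launch, as used in the benchmark script below (placeholder values)
    # colossalai.launch(rank=0, world_size=1, host="localhost", port=29500, backend="nccl")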
.github/workflows/build_on_pr.yml (vendored, 2 changed lines)
@@ -91,7 +91,7 @@ jobs:
     container:
       image: hpcaitech/pytorch-cuda:2.1.0-12.1.0
       options: --gpus all --rm -v /dev/shm -v /data/scratch/llama-tiny:/data/scratch/llama-tiny
-    timeout-minutes: 60
+    timeout-minutes: 75
     defaults:
       run:
         shell: bash
@@ -81,7 +81,7 @@ import colossalai
 from colossalai.inference import InferenceEngine, InferenceConfig
 from pprint import pprint
 
-colossalai.launch_from_torch(config={})
+colossalai.launch_from_torch()
 
 # Step 1: create a model in "transformers" way
 model_path = "lmsys/vicuna-7b-v1.3"
@@ -23,7 +23,7 @@ from colossalai.inference.core.engine import InferenceEngine, GenerationConfig
 from colossalai.inference.modeling.models.glide_llama import GlideLlamaForCausalLM, GlideLlamaConfig
 
 # launch colossalai, setup distributed environment
-colossalai.launch_from_torch(config={})
+colossalai.launch_from_torch()
 
 # main model
 model_path_or_name = "REPLACE_TO_VICUNA_7B_PATH_OR_MODEL_CARD"
@@ -1,11 +1,7 @@
 import enum
 from dataclasses import dataclass
-from typing import Any, List, Tuple, Union
+from typing import Any, List
 
-import torch
-from ordered_set import OrderedSet
-
-from colossalai.inference.flash_decoding_utils import FDIntermTensors
 from colossalai.logging import get_dist_logger
 
 logger = get_dist_logger(__name__)
@@ -170,242 +166,6 @@ class Sequence:
         )
 
 
-@dataclass
-class BatchInfo:
-    """
-    Information to be passed and used for a batch of sequences.
-    """
-
-    max_batch_size: int
-    kv_max_split_num: int
-    num_heads: int
-    head_dim: int
-    sequences_set: OrderedSet[Sequence] = None
-    is_prompts: bool = True
-    device: torch.device = None
-    dtype: torch.dtype = None
-    fd_inter_tensor: FDIntermTensors = None
-
-    def __post_init__(self):
-        if self.device is None:
-            self.device = torch.cuda.current_device()
-        if self.sequences_set is None:
-            self.sequences_set = OrderedSet()
-        if self.fd_inter_tensor is None:
-            self.fd_inter_tensor = FDIntermTensors()
-
-    def init_fd_tensors(self):
-        if not self.fd_inter_tensor.is_initialized:
-            self.fd_inter_tensor.initialize(
-                max_batch_size=self.max_batch_size,
-                num_attn_heads=self.num_heads,
-                kv_max_split_num=self.kv_max_split_num,
-                head_dim=self.head_dim,
-                dtype=self.dtype,
-                device=self.device,
-            )
-
-    def get_block_table_tensor(self) -> None:
-        tesnor_list = []
-        block_table = None
-
-        assert len(self.sequences_set) > 0, "Batch has not been initialized yet. Please initialize batch first."
-
-        for seq in self.sequences_set:
-            block_table = seq.block_table
-            assert (
-                block_table is not None
-            ), f"The sequence(request_id {seq.request_id}) has not initialized the block_table."
-            tesnor_list.append(seq.block_table)
-
-        block_table = torch.stack(tesnor_list)
-        return block_table
-
-    def clear_batch(self) -> None:
-        """
-        Clear sequence set and block table if we need to abort this batch.
-        Prefill: clear sequence set and move them to running batch(external)
-        Decoding: mark unfinished sequences as aborted.
-        """
-        if self.is_prompts:
-            self.sequences_set.clear()
-        else:
-            for seq in self.sequences_set:
-                seq.mark_aborted()
-                if seq.check_finish():
-                    seq.mark_finished()
-
-            self.sequences_set.clear()
-
-    def fliter_batch(self) -> List["Sequence"]:
-        """
-        Remove completed sentences from a batch.
-
-        Returns:
-            List["Sequence"]: List of finished sequences.
-        """
-        finish_seqs = []
-        for seq in self.sequences_set:
-            if seq.check_finish():
-                finish_seqs.append(seq)
-        for finish_seq in finish_seqs:
-            self.sequences_set.discard(finish_seq)
-        return finish_seqs
-
-    def abort_seq(self, seq: "Sequence") -> "Sequence":
-        """
-        Remove sequence from the batch.
-        """
-        if not seq.check_finish():
-            seq.status = RequestStatus.ABORTED
-        self.sequences_set.discard(seq)
-        return seq
-
-    def add_seqs(self, seqs: Union[Sequence, List[Sequence]]) -> None:
-        """
-        Add new sequence to batch
-
-        Args:
-            seqs (List["Sequence"]): The list of new sequences.
-        """
-        # covnert single sequence to list
-        if isinstance(seqs, Sequence):
-            seqs = [seqs]
-
-        for seq in seqs:
-            if seq in self.sequences_set:
-                logger.warning(f"The sequence(request_id {seq.request_id}) is already in sequences_set.")
-                continue
-            self.sequences_set.add(seq)
-
-    def del_seq(self, seq: Sequence) -> Sequence:
-        """
-        Delete sequence in batch
-        """
-        self.sequences_set.discard(seq)
-
-    @property
-    def is_empty(self) -> None:
-        """
-        Check whether sequences_set is empty.
-        """
-        return not self.sequences_set
-
-    def update_batch_tokens(self, tokens: Union[List[int], List[List[int]], torch.Tensor]) -> None:
-        """
-        Add an output token for each sentence in the batch.
-
-        Args:
-            tokens (List[int]): A batch of tokens
-        """
-
-        if isinstance(tokens, torch.Tensor):
-            tokens = tokens.tolist()
-
-        assert self.get_batch_size() == len(tokens), "The number of tokens does not match batch_size."
-
-        for seq, token in zip(self.sequences_set, tokens):
-            if not isinstance(token, list):
-                if not isinstance(token, int):
-                    raise TypeError(f"The token type must be List[int] or int, but got {type(token)}.")
-                token = [token]
-            seq.output_token_id += token
-            seq.check_finish()
-
-    def get_batch_size(self) -> int:
-        """
-        Get batch_size of this batch
-        """
-        return len(self.sequences_set)
-
-    def get_batch_inputs(self) -> torch.LongTensor:
-        """
-        Get bacth inputs for forward inference computation.
-        """
-
-        input_list = []
-
-        assert len(self.sequences_set) > 0, "Batch has not been initialized yet. Please initialize batch first."
-
-        for seq in self.sequences_set:
-            if self.is_prompts:
-                if seq.output_len > 0:
-                    input_list.append(seq.input_token_id + seq.output_token_id)
-                else:
-                    input_list.append(seq.input_token_id)
-            else:
-                input_list.append([seq.output_token_id[-1]])
-
-        max_seq_len = max(len(sub_list) for sub_list in input_list)
-
-        # We assume that all the padding_id in seq are the same at present.
-        return _make_tensor_with_pad(input_list, max_seq_len, self.sequences_set[0].pad_token_id, dtype=torch.int)
-
-    def get_1D_inputs(self) -> Tuple[torch.LongTensor, torch.Tensor]:
-        """
-        Flattening the input tokens.
-        """
-        input_list = []
-
-        assert len(self.sequences_set) > 0, "Batch has not been initialized yet. Please initialize batch first."
-
-        for seq in self.sequences_set:
-            if self.is_prompts:
-                input_list.extend(seq.input_token_id)
-            else:
-                input_list.append(seq.output_token_id[-1])
-
-        return torch.tensor(input_list, dtype=torch.long, device=self.device)
-
-    def get_sequence_lengths(self):
-        """
-        Get the input_len of each sentence in this batch.
-        """
-        len_list = []
-
-        assert len(self.sequences_set) > 0, "Batch has not been initialized yet. Please initialize batch first."
-
-        for seq in self.sequences_set:
-            len_list.append(seq.sentence_len)
-
-        return torch.tensor(len_list, dtype=torch.int, device=self.device)
-
-    def get_attn_mask(self) -> torch.Tensor:
-        """
-        Generate and return attention mask.
-        """
-        assert len(self.sequences_set) > 0, "Batch has not been initialized yet. Please initialize batch first."
-
-        past_values = []
-        # We assume that all the padding_id in seq are the same at present.
-        padding_id = self.sequences_set[0].pad_token_id
-
-        for seq in self.sequences_set:
-            past_values.append(seq.input_token_id + seq.output_token_id)
-
-        max_seq_len = max(len(sub_list) for sub_list in past_values)
-        attn_mask = _make_tensor_with_pad(
-            past_values, max_seq_len, self.sequences_set[0].pad_token_id, dtype=torch.int, device=self.device
-        )
-
-        return attn_mask.ne(padding_id).long()
-
-    def __repr__(self) -> str:
-        return f"(sequences_set={self.sequences_set}, " f"is_prompts={self.is_prompts})"
-
-
 def _pad_to_max(x: List[int], max_len: int, pad: int) -> List[int]:
     assert len(x) <= max_len
     return [pad] * (max_len - len(x)) + x
-
-
-def _make_tensor_with_pad(
-    x: Union[List[List[int]], List[int]],
-    max_len: int,
-    pad: int,
-    dtype: torch.dtype,
-    device: Union[str, torch.device] = "cuda",
-    pin_memory: bool = False,
-):
-    padded_x = [_pad_to_max(x_i, max_len, pad) for x_i in x]
-    return torch.tensor(padded_x, dtype=dtype, device=device, pin_memory=pin_memory and str(device) == "cpu")
examples/inference/benchmark_ops/test_ci.sh (new empty file)
@@ -182,7 +182,7 @@ def benchmark_inference(args):
 
 
 def inference(rank, world_size, port, args):
-    colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+    colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
     benchmark_inference(args)
 
 
@@ -17,7 +17,7 @@ def infer(args):
     # ==============================
     # Launch colossalai, setup distributed environment
     # ==============================
-    colossalai.launch_from_torch(config={})
+    colossalai.launch_from_torch()
     coordinator = DistCoordinator()
 
     # ==============================
@@ -59,7 +59,7 @@ def infer(args):
     coordinator.print_on_master(out[0])
 
 
-# colossalai run --nproc_per_node 1 llama_gen.py -m MODEL_PATH
+# colossalai run --nproc_per_node 1 llama_generation.py -m MODEL_PATH
 if __name__ == "__main__":
     # ==============================
     # Parse Arguments
examples/inference/llama/test_ci.sh (new file)
@@ -0,0 +1,4 @@
+#!/bin/bash
+echo "Skip the test (this test is slow)"
+
+# bash ./run_benchmark.sh
@@ -35,7 +35,7 @@ from transformers.utils import (
     replace_return_docstrings,
 )
 
-from colossalai.kernel.extensions.pybind.flash_attention import HAS_FLASH_ATTN
+from colossalai.kernel.extensions.flash_attention import HAS_FLASH_ATTN
 from colossalai.kernel.triton.llama_act_combine_kernel import HAS_TRITON
 from colossalai.moe.layers import SparseMLP
 from colossalai.moe.manager import MOE_MANAGER
@@ -1,2 +0,0 @@
-ordered_set
-transformers==4.36.2
@@ -1,6 +1,4 @@
 diffusers
-fbgemm-gpu==0.2.0
-ordered_set
 pytest
 coverage==7.2.3
 git+https://github.com/hpcaitech/pytest-testmon
tests/test_infer/__init__.py (new empty file)
@@ -2,7 +2,7 @@ import pytest
 
 import colossalai
 from colossalai.inference.config import InferenceConfig
-from colossalai.inference.struct import BatchInfo, RequestStatus, Sequence
+from colossalai.inference.struct import RequestStatus, Sequence
 from colossalai.testing import rerun_if_address_is_in_use, spawn
 
 
@@ -20,27 +20,6 @@ def check_config_and_inference():
         max_output_len=256,
     )
 
-    sequence2 = Sequence(
-        request_id=2,
-        prompt="bcd",
-        input_token_id=[4, 5, 6],
-        block_size=16,
-        sample_params=None,
-        eos_token_id=2,
-        pad_token_id=2,
-        max_output_len=256,
-    )
-
-    sequence3 = Sequence(
-        request_id=3,
-        prompt="efg",
-        input_token_id=[7, 8, 9],
-        block_size=16,
-        sample_params=None,
-        eos_token_id=2,
-        pad_token_id=2,
-        max_output_len=256,
-    )
     sequence.mark_running()
     assert sequence.status == RequestStatus.RUNNING
     sequence.recycle()
@@ -51,33 +30,6 @@ def check_config_and_inference():
     assert sequence.output_len == 0
     assert sequence.check_finish() == False
 
-    batch = BatchInfo(
-        max_batch_size=8,
-        kv_max_split_num=16,
-        num_heads=2,
-        head_dim=128,
-    )
-    batch.add_seqs([sequence])
-    batch.add_seqs([sequence2, sequence3])
-
-    # add duplicated sequence to test that it will not be counted twice
-    batch.add_seqs([sequence])
-
-    assert batch.is_empty == False
-    assert batch.get_batch_size() == 3
-    batch.update_batch_tokens([1, 2, 3])
-    seq = batch.abort_seq(sequence)
-    seq2 = batch.fliter_batch()[0]
-
-    assert batch.get_batch_size() == 1
-    assert seq.output_len == 1
-    assert seq.output_token_id == [1]
-    assert seq2.output_len == 1
-    assert seq2.output_token_id == [2]
-
-    batch.clear_batch()
-    assert batch.is_empty == True
-
 
 def run_dist(rank, world_size, port):
     colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost")
@@ -86,7 +86,7 @@ def run_dist(rank, world_size, port):
     check_output_consistency(128)
 
 
-@pytest.mark.dist
+@pytest.mark.largedist
 @rerun_if_address_is_in_use()
 def test_cuda_graph_infer():
     spawn(run_dist, 1)
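Several tests in this commit are re-tagged from `dist` to `largedist`. A sketch of how such a custom marker would be declared in a conftest.py so that pytest recognizes it and CI can filter on it; this is an assumed setup, not part of this commit:

    # conftest.py (assumed): register the custom marker so `-m largedist`
    # selection and --strict-markers both work without warnings.
    def pytest_configure(config):
        config.addinivalue_line(
            "markers", "largedist: tests that need a large multi-GPU environment"
        )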
@@ -11,13 +11,16 @@ MAX_LEN = 100
 SPEC_NUM = 5
 
 
+@pytest.fixture(scope="module")
+def tokenizer():
+    return AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
+
+
 @pytest.mark.parametrize("spec_num", [SPEC_NUM])
-def test_drafter(spec_num: int):
+def test_drafter(tokenizer, spec_num: int):
     torch.manual_seed(123)
 
     device = get_current_device()
-
-    tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
     toy_config = LlamaConfig(num_hidden_layers=NUM_LAYERS)
     toy_config.pad_token_id = tokenizer.eos_token_id
     drafter_model = LlamaForCausalLM(toy_config)
@@ -39,10 +42,9 @@ def test_drafter(spec_num: int):
     assert trimmed_past_key_values[0][0].size(2) == past_kv_length - reject_num
 
 
-def test_spec_dec():
+def test_spec_dec(tokenizer):
     spec_num = SPEC_NUM
     device = get_current_device()
-    tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
     tokenizer.pad_token = tokenizer.eos_token
 
     # Dummy config for Glide Model
@@ -67,5 +69,6 @@ def test_spec_dec():
 
 
 if __name__ == "__main__":
-    test_drafter(spec_num=SPEC_NUM)
-    test_spec_dec()
+    dummy_tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
+    test_drafter(dummy_tokenizer, spec_num=SPEC_NUM)
+    test_spec_dec(dummy_tokenizer)
@@ -165,8 +165,10 @@ def run_dist(rank, world_size, port, func_to_run, ret=None, **kwargs):
     func_to_run(**kwargs)
 
 
+@pytest.mark.largedist
 @parameterize("prompt_template", [None, "llama"])
 @parameterize("do_sample", [False])
+@rerun_if_address_is_in_use()
 def test_tp_engine(prompt_template, do_sample):
     kwargs1 = {
         "use_engine": True,
@@ -186,18 +188,14 @@ def test_tp_engine(prompt_template, do_sample):
     assert s1 == s2, f"\nColossalAI TP=1 Output: {s1}\nColossalAI TP=2 Output: {s2}"
 
 
+@pytest.mark.largedist
 @parameterize("num_layers", [1])
 @parameterize("max_length", [64])
+@rerun_if_address_is_in_use()
 def test_spec_dec(num_layers, max_length):
     spawn(run_dist, 1, func_to_run=check_spec_dec, num_layers=num_layers, max_length=max_length)
 
 
-@pytest.mark.dist
-@rerun_if_address_is_in_use()
-def test_inference_engine():
+if __name__ == "__main__":
     test_tp_engine()
     test_spec_dec()
-
-
-if __name__ == "__main__":
-    test_inference_engine()
@@ -86,11 +86,11 @@ def torch_attn_unpad(
 
 
 @pytest.mark.skipif(not (HAS_TRITON and TRITON_CUDA_SUPPORT), reason="requires triton")
-@pytest.mark.parametrize("bsz", [4, 7, 32])
-@pytest.mark.parametrize("block_size", [16, 32, 64])
-@pytest.mark.parametrize("max_num_blocks_per_seq", [8, 32])
+@pytest.mark.parametrize("bsz", [7, 32])
+@pytest.mark.parametrize("block_size", [16, 32])
+@pytest.mark.parametrize("max_num_blocks_per_seq", [8, 16])
 @pytest.mark.parametrize("num_attn_heads", [16])
-@pytest.mark.parametrize("kv_group_num", [1, 2, 16])
+@pytest.mark.parametrize("kv_group_num", [1, 4])
 @pytest.mark.parametrize("same_context_len", [True, False])
 @pytest.mark.parametrize("use_alibi_slopes", [True, False])
 @pytest.mark.parametrize("use_new_kcache_layout", [True, False])
@@ -68,11 +68,11 @@ def prepare_data(
 
 
 @pytest.mark.skipif(not (HAS_TRITON and TRITON_CUDA_SUPPORT), reason="requires triton")
-@pytest.mark.parametrize("bsz", [4, 7, 32])
-@pytest.mark.parametrize("block_size", [16, 32, 64])
-@pytest.mark.parametrize("max_num_blocks_per_seq", [8, 32])
+@pytest.mark.parametrize("bsz", [7, 16])
+@pytest.mark.parametrize("block_size", [16, 32])
+@pytest.mark.parametrize("max_num_blocks_per_seq", [8, 16])
 @pytest.mark.parametrize("num_attn_heads", [16])
-@pytest.mark.parametrize("kv_group_num", [1, 2, 16])
+@pytest.mark.parametrize("kv_group_num", [1, 4])
 @pytest.mark.parametrize("same_context_len", [True, False])
 @pytest.mark.parametrize("q_len", [1, 5])
 @pytest.mark.parametrize("use_alibi_slopes", [True, False])
@@ -187,7 +187,7 @@ def test_flash_decoding(
 
     rtol = 1e-4
     # After the shape becomes larger, some data elements are too small, leading to excessively large relative errors.
-    if bsz == 32 and use_alibi_slopes:
+    if bsz >= 16 and use_alibi_slopes:
         rtol = 100
 
     numpy_allclose(out_torch, out_triton, atol=1e-3, rtol=rtol)
@@ -70,9 +70,9 @@ def prepare_data(
 
 
 @pytest.mark.skipif(not (HAS_TRITON and TRITON_CUDA_SUPPORT), reason="requires triton")
-@pytest.mark.parametrize("bsz", [4, 7, 32])
+@pytest.mark.parametrize("bsz", [7, 32])
 @pytest.mark.parametrize("block_size", [16, 32, 64])
-@pytest.mark.parametrize("max_num_blocks_per_seq", [8, 32])
+@pytest.mark.parametrize("max_num_blocks_per_seq", [16])
 @pytest.mark.parametrize("num_kv_heads", [16])
 @pytest.mark.parametrize("same_context_len", [True, False])
 @pytest.mark.parametrize("n_tokens", [1, 5])
@@ -1,3 +1,4 @@
+import pytest
 import torch
 from transformers.cache_utils import DynamicCache
 from transformers.modeling_attn_mask_utils import AttentionMaskConverter
@@ -7,6 +8,7 @@ from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotar
 from colossalai.inference.modeling.layers.attention import PagedAttention, convert_kvcache, copy_to_cache
 
 
+@pytest.mark.skip(reason="This test is not used in the current version.")
 def test_copy_to_cache():
     key = torch.ones((2, 11, 3, 3))
     key[0, 9, :, :] = 0
@@ -24,6 +26,7 @@ def test_copy_to_cache():
     assert cache[3, 0, 0, 0] == 1
 
 
+@pytest.mark.skip(reason="This test is not used in the current version.")
 def test_convert_kvcache():
     cache = torch.ones(8, 3, 8, 3)
     key = torch.ones(2, 1, 3, 3) + 1
@@ -34,6 +37,7 @@ def test_convert_kvcache():
     assert converted_cache.shape == (2, 10, 3, 3)
 
 
+@pytest.mark.skip(reason="This test is not used in the current version.")
 def test_context_attention():
     """
     test config: head_num = 4, head_size = 4
@@ -86,6 +90,7 @@ def test_context_attention():
     assert torch.allclose(pad_attn_output, attn_output, atol=1e-3, rtol=1e-3)
 
 
+@pytest.mark.skip(reason="This test is not used in the current version.")
 def test_decoding_attention():
     # test the pipeline of decoding attention
     attn = PagedAttention()
@@ -128,7 +128,7 @@ def check_tp_engine(prompt_template, do_sample, use_cuda_kernel):
     not os.path.exists(BAICHUAN_MODEL_NAME_OR_PATH),
     reason="There is no local model address included, please replace this address with a valid one.",
 )
-@pytest.mark.dist
+@pytest.mark.largedist
 @rerun_if_address_is_in_use()
 def test_inference_engine():
     check_tp_engine()