mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-05-21 22:53:34 +00:00
* [inference] support only TP (#4998) * support only tp * enable tp * add support for bloom (#5008) * [refactor] refactor gptq and smoothquant llama (#5012) * refactor gptq and smoothquant llama * fix import error * fix linear import torch-int * fix smoothquant llama import error * fix import accelerate error * fix bug * fix import smooth cuda * fix smoothcuda * [Inference Refactor] Merge chatglm2 with pp and tp (#5023) merge chatglm with pp and tp * [Refactor] remove useless inference code (#5022) * remove useless code * fix quant model * fix test import bug * mv original inference legacy * fix chatglm2 * [Refactor] refactor policy search and quant type controlling in inference (#5035) * [Refactor] refactor policy search and quant type controling in inference * [inference] update readme (#5051) * update readme * update readme * fix architecture * fix table * fix table * [inference] udpate example (#5053) * udpate example * fix run.sh * fix rebase bug * fix some errors * update readme * add some features * update interface * update readme * update benchmark * add requirements-infer --------- Co-authored-by: Bin Jia <45593998+FoolPlayer@users.noreply.github.com> Co-authored-by: Zhongkai Zhao <kanezz620@gmail.com>
67 lines
2.3 KiB
Python
67 lines
2.3 KiB
Python
import os
|
|
|
|
import pytest
|
|
import torch
|
|
from packaging import version
|
|
|
|
from colossalai.inference.kv_cache import MemoryManager
|
|
from colossalai.logging import disable_existing_loggers
|
|
from colossalai.testing import rerun_if_address_is_in_use, spawn
|
|
|
|
# Workload shape for the test: the cache under test is sized as
# BATCH_SIZE * (INPUT_LEN + OUTPUT_LEN) token slots.
BATCH_SIZE = 4
INPUT_LEN = 16
OUTPUT_LEN = 8
LAYER_NUM = 4
HEAD_NUM = 32
HEAD_DIM = 128

# torch.version.cuda is None on PyTorch builds without CUDA; parsing None would raise
# TypeError at import time and break test collection instead of letting skipif skip.
CUDA_SUPPORT = torch.version.cuda is not None and version.parse(torch.version.cuda) > version.parse("11.5")
|
|
|
|
|
|
def create_cache_manager(rank, world_size, port, batch_size, input_len, output_len, layer_num, head_num, head_dim):
    """Per-process worker: build a MemoryManager and check its allocation behaviour.

    Args:
        rank: index of this process; exported as RANK and LOCAL_RANK and passed to MemoryManager.
        world_size: number of processes; exported as WORLD_SIZE and used to divide ``head_num``.
        port: exported as MASTER_PORT.
        batch_size: number of sequences in the simulated batch.
        input_len: prompt length per sequence.
        output_len: generated length per sequence.
        layer_num: expected number of key buffers and of value buffers.
        head_num: total head count; ``head_num // world_size`` is given to MemoryManager.
        head_dim: per-head dimension given to MemoryManager.
    """
    # Distributed environment variables for this process.
    os.environ.update(
        {
            "RANK": str(rank),
            "LOCAL_RANK": str(rank),
            "WORLD_SIZE": str(world_size),
            "MASTER_ADDR": "localhost",
            "MASTER_PORT": str(port),
        }
    )
    disable_existing_loggers()

    capacity = batch_size * (input_len + output_len)
    heads_per_rank = head_num // world_size
    manager = MemoryManager(capacity, torch.float16, heads_per_rank, head_dim, layer_num, rank)

    keys = manager.key_buffer
    values = manager.value_buffer
    assert len(keys) == len(values) == layer_num
    assert keys[0].shape == values[0].shape

    # Requesting one slot more than the cache holds must yield None.
    oversized = manager.alloc_contiguous(capacity + 1)
    assert oversized is None

    # Prefill stage: alloc and alloc_contiguous should hand out the same locations.
    prefill_tokens = batch_size * input_len
    locs_plain = manager.alloc(prefill_tokens)
    manager.free_all()
    locs_contiguous = manager.alloc_contiguous(prefill_tokens)[0]
    assert torch.equal(locs_plain, locs_contiguous)
    assert manager.mem_state.sum().item() == capacity - prefill_tokens

    # One further contiguous allocation of batch_size slots; every slot up to the new
    # high-water mark must then read as False in mem_state.
    manager.alloc_contiguous(batch_size)
    occupied = manager.mem_state[: prefill_tokens + batch_size]
    assert torch.all(torch.logical_not(occupied))
|
|
|
|
|
|
@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5")
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_cache_manager_dist():
    """Launch create_cache_manager through spawn with the module-level workload sizes.

    The literal 4 is spawn's second positional argument (presumably the number of
    processes -- confirm against colossalai.testing.spawn).
    """
    worker_kwargs = {
        "batch_size": BATCH_SIZE,
        "input_len": INPUT_LEN,
        "output_len": OUTPUT_LEN,
        "layer_num": LAYER_NUM,
        "head_num": HEAD_NUM,
        "head_dim": HEAD_DIM,
    }
    spawn(create_cache_manager, 4, **worker_kwargs)
|
|
|
|
|
|
# Allow running this test module directly as a script, without the pytest runner.
if __name__ == "__main__":
    test_cache_manager_dist()
|