Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-06 19:40:28 +00:00)
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
@@ -1,7 +1,8 @@
 import os
-from packaging import version
+
 import pytest
 import torch
+from packaging import version
 
 from colossalai.inference.tensor_parallel import MemoryManager
 from colossalai.logging import disable_existing_loggers
@@ -14,14 +15,15 @@ LAYER_NUM = 4
 HEAD_NUM = 32
 HEAD_DIM = 128
 
-CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.5')
+CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.5")
+
 
 def create_cache_manager(rank, world_size, port, batch_size, input_len, output_len, layer_num, head_num, head_dim):
-    os.environ['RANK'] = str(rank)
-    os.environ['LOCAL_RANK'] = str(rank)
-    os.environ['WORLD_SIZE'] = str(world_size)
-    os.environ['MASTER_ADDR'] = 'localhost'
-    os.environ['MASTER_PORT'] = str(port)
+    os.environ["RANK"] = str(rank)
+    os.environ["LOCAL_RANK"] = str(rank)
+    os.environ["WORLD_SIZE"] = str(world_size)
+    os.environ["MASTER_ADDR"] = "localhost"
+    os.environ["MASTER_PORT"] = str(port)
     disable_existing_loggers()
 
     size = batch_size * (input_len + output_len)
@@ -41,21 +43,24 @@ def create_cache_manager(rank, world_size, port, batch_size, input_len, output_l
     assert torch.equal(prefill_locs, prefill_locs_contiguous)
     assert torch.sum(kvcache_manager.mem_state).item() == size - total_token_prefill
     kvcache_manager.alloc_contiguous(batch_size)
-    assert torch.all(kvcache_manager.mem_state[:total_token_prefill + batch_size] == False)
+    assert torch.all(kvcache_manager.mem_state[: total_token_prefill + batch_size] == False)
+
 
 @pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5")
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
 def test_cache_manager_dist():
-    spawn(create_cache_manager,
-          4,
-          batch_size=BATCH_SIZE,
-          input_len=INPUT_LEN,
-          output_len=OUTPUT_LEN,
-          layer_num=LAYER_NUM,
-          head_num=HEAD_NUM,
-          head_dim=HEAD_DIM)
+    spawn(
+        create_cache_manager,
+        4,
+        batch_size=BATCH_SIZE,
+        input_len=INPUT_LEN,
+        output_len=OUTPUT_LEN,
+        layer_num=LAYER_NUM,
+        head_num=HEAD_NUM,
+        head_dim=HEAD_DIM,
+    )
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     test_cache_manager_dist()
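For readers unfamiliar with the pattern in the last hunk: create_cache_manager is a per-rank worker, and spawn launches world_size copies of it, handing each one its rank and a shared rendezvous port, which the worker then exports through the RANK/LOCAL_RANK/WORLD_SIZE/MASTER_ADDR/MASTER_PORT environment variables seen above. A minimal sketch of what such a spawn-style launcher does (a simplified stand-in assuming this calling convention; ColossalAI's actual spawn helper may differ):

# Sketch of a spawn-style test launcher (assumption: simplified stand-in,
# not ColossalAI's actual implementation of spawn).
import multiprocessing as mp


def _worker(rank, func, world_size, port, kwargs):
    # Each worker receives its rank, the shared world size, and a common
    # port, mirroring the env-var setup in create_cache_manager above.
    func(rank, world_size, port, **kwargs)


def spawn_sketch(func, world_size, port=29500, **kwargs):
    procs = []
    for rank in range(world_size):
        p = mp.Process(target=_worker, args=(rank, func, world_size, port, kwargs))
        p.start()
        procs.append(p)
    for p in procs:
        p.join()
        assert p.exitcode == 0, f"worker rank failed with exit code {p.exitcode}"

One process per rank (rather than threads) matches how torch.distributed expects distributed workers to be laid out.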
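The assertions in the last hunk read MemoryManager.mem_state as a boolean free-list over size = batch_size * (input_len + output_len) cache slots: True marks a free slot, torch.sum counts the free slots, and alloc_contiguous reserves a consecutive run, which (as the prefix check implies) is handed out from the lowest free indices. A toy sketch of that bookkeeping (ToyCacheManager is hypothetical and for illustration only; it is not the real MemoryManager API from colossalai.inference.tensor_parallel):

# Toy sketch of the free-slot bookkeeping the test asserts against
# (assumption: illustrative only, not the real MemoryManager).
import torch


class ToyCacheManager:
    def __init__(self, size):
        # True = slot is free, matching the sum/all checks in the test.
        self.mem_state = torch.ones(size, dtype=torch.bool)

    def alloc_contiguous(self, num):
        free = self.mem_state.nonzero().flatten()
        # Scan for `num` consecutive free slots, lowest index first.
        for start in range(free.numel() - num + 1):
            lo = free[start].item()
            if self.mem_state[lo:lo + num].all():
                self.mem_state[lo:lo + num] = False
                return torch.arange(lo, lo + num)
        return None  # no contiguous run available


mgr = ToyCacheManager(size=8)
locs = mgr.alloc_contiguous(3)
assert torch.equal(locs, torch.arange(0, 3))
assert mgr.mem_state.sum().item() == 8 - 3    # free count shrinks by the allocation
assert torch.all(mgr.mem_state[:3] == False)  # allocated prefix is marked in use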