Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-22 01:48:07 +00:00)
[elixir] add elixir and its unit tests (#3835)
* [elixir] add elixir
* [elixir] add unit tests
* remove useless code
* fix python 3.8 issue
* fix typo
* add test skip
* add docstrings
* add docstrings
* add readme
* fix typo
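
For orientation, this is the wrapping pattern the new tests exercise, shown as a minimal sketch distilled from tests/test_elixir/compatibility_check.py below (`build_model` and `build_data` are hypothetical placeholder helpers, not part of this commit):

    import torch
    import torch.distributed as dist
    from colossalai.elixir import ElixirModule, ElixirOptimizer
    from colossalai.elixir.search import minimum_waste_search
    from colossalai.nn.optimizer import HybridAdam

    model = build_model()                                  # hypothetical helper
    optimizer = HybridAdam(model.parameters(), lr=1e-3)

    # plan how parameters are grouped into chunks, then wrap model and optimizer
    sr = minimum_waste_search(m=model, group_size=dist.get_world_size(), unified_dtype=torch.float16)
    model = ElixirModule(model, sr, dist.GroupMember.WORLD, dtype=torch.float16)
    optimizer = ElixirOptimizer(model, optimizer, initial_scale=32)

    loss = model(**build_data()).mean()                    # hypothetical data helper
    optimizer.backward(loss)                               # Elixir handles scaling and gradient reduction
    optimizer.step()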
tests/test_elixir/__init__.py (new file, 0 lines)
tests/test_elixir/compatibility_check.py (new file, 89 lines)
@@ -0,0 +1,89 @@
import torch
import torch.distributed as dist

import colossalai
from colossalai.elixir import ElixirModule, ElixirOptimizer
from colossalai.elixir.search import minimum_waste_search
from colossalai.nn.optimizer import HybridAdam
from colossalai.testing import rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo


def check_elixir_compatibility(early_stop: bool = True):
    """Check Elixir compatibility over the model zoo.

    Args:
        early_stop (bool, optional): Whether to stop when getting the first error. Defaults to True.
    """
    passed_models = []
    failed_info = {}    # (model_name, error) pair

    for name, (model_fn, data_gen_fn, output_transform_fn, _) in model_zoo.items():
        # These models lead to CUDA error
        if name in ('diffusers_auto_encoder_kl', 'diffusers_vq_model', 'diffusers_unet2d_model', 'timm_resmlp',
                    'timm_gmixer_12_224', 'timm_gmlp_b16_224', 'timm_mixer_b16_224', 'timm_convnext',
                    'torchaudio_wav2vec2_base', 'torchaudio_hubert_base', 'torchvision_convnext_base'):
            continue

        try:
            print(name)
            global_size = dist.get_world_size()
            global_group = dist.GroupMember.WORLD

            model = model_fn()
            optimizer = HybridAdam(model.parameters(), lr=1e-3)
            criterion = lambda x: x.mean()
            data = data_gen_fn()

            data = {
                k: v.to('cuda') if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v
                for k, v in data.items()
            }

            sr = minimum_waste_search(
                # pre-commit: do not rearrange
                m=model,
                group_size=global_size,
                unified_dtype=torch.float16,
                prefetch=False,
                verbose=True)

            model = ElixirModule(model, sr, global_group, prefetch=False, dtype=torch.float16)
            optimizer = ElixirOptimizer(model, optimizer, initial_scale=32)

            output = model(**data)
            output = output_transform_fn(output)
            output_key = list(output.keys())[0]
            loss = criterion(output[output_key])

            optimizer.backward(loss)
            optimizer.step()
            passed_models.append(name)

            del model, optimizer, criterion, data, output, loss
        except Exception as e:
            failed_info[name] = e
            if early_stop:
                raise e

        torch.cuda.empty_cache()

    if dist.get_rank() == 0:
        print(f'Passed models({len(passed_models)}): {passed_models}\n\n')
        print(f'Failed models({len(failed_info)}): {list(failed_info.keys())}\n\n')
    assert len(failed_info) == 0, '\n'.join([f'{k}: {v}' for k, v in failed_info.items()])


def run_dist(rank, world_size, port, early_stop: bool = True):
    # init dist env
    colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')
    check_elixir_compatibility(early_stop=early_stop)


@rerun_if_address_is_in_use()
def exam_compatibility(early_stop: bool = True):
    spawn(run_dist, 2, early_stop=early_stop)


if __name__ == '__main__':
    exam_compatibility(early_stop=False)
tests/test_elixir/test_chunk/__init__.py (new file, 0 lines)
tests/test_elixir/test_chunk/fetcher_utils.py (new file, 72 lines)
@@ -0,0 +1,72 @@
from functools import partial

import torch
import torch.distributed as dist
import torch.nn as nn

from colossalai.elixir.chunk import BlockRequire, ChunkFetcher, ChunkGroup, MemoryPool, TensorState
from colossalai.elixir.chunk.scheduler import FIFOScheduler
from colossalai.elixir.hook import BufferStore, HookParam
from colossalai.elixir.tensor import OutplaceTensor


def to_divide(a: int, b: int):
    return a + (-a % b)


def grad_handler(grad: torch.Tensor, param: nn.Parameter, fetcher: ChunkFetcher):
    empty_grad = torch.empty_like(grad)
    empty_grad.storage().resize_(0)

    with torch._C.DisableTorchFunction():
        chunk = fetcher.get_one_chunk(param)
        if chunk.tensors_info[param].state != TensorState.HOLD_AFTER_BWD:
            raise RuntimeError()
        fetcher.group.tensor_trans_state(param, TensorState.READY_FOR_REDUCE)
        chunk.copy_tensor_to_chunk_slice(param, grad)
        fetcher.reduce_chunk(chunk)

    return empty_grad


def hook_transform(model: nn.Module, process_group: dist.ProcessGroupGloo):
    pg_size = dist.get_world_size(process_group)

    private_list = list()
    for param in model.parameters():
        block_size = to_divide(param.numel(), pg_size)
        private_list.append(BlockRequire(block_size, param.dtype))

    mp = MemoryPool('cuda')
    mp.allocate(private_block_list=private_list)
    cg = ChunkGroup(rcache=mp)
    # allocate chunk group
    fused_config = dict(rcache_fused=True)
    for param in model.parameters():
        cg.allocate_chunk([param], to_divide(param.numel(), pg_size), param.dtype, process_group, fused_config)
    # initialize chunk fetcher
    scheduler = FIFOScheduler()
    fetcher = ChunkFetcher(scheduler, cg)
    buffer = BufferStore(0, torch.float32)
    # register fetcher and gradient handler
    HookParam.attach_fetcher(fetcher, buffer)
    for param in model.parameters():
        param.register_hook(partial(grad_handler, param=param, fetcher=fetcher))
        param.__class__ = HookParam
    # set inplace to False for all modules
    for module in model.modules():
        if hasattr(module, 'inplace'):
            module.inplace = False

    def transform_input(self_module, inputs):
        fetcher.reset()
        input_list = list()
        for t in inputs:
            if isinstance(t, torch.Tensor):
                t = OutplaceTensor(t)
            input_list.append(t)
        return tuple(input_list)

    model.register_forward_pre_hook(transform_input)

    return model, cg
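
Note that `to_divide(a, b)` above rounds `a` up to the nearest multiple of `b`; this is how each parameter's block size is padded so the chunk can be sharded evenly across the process group. A quick sanity check, assuming the function as written above:

    >>> to_divide(100, 8)    # 100 + (-100 % 8) = 100 + 4
    104
    >>> to_divide(64, 8)     # already a multiple of 8, nothing is added
    64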
tests/test_elixir/test_chunk/test_block.py (new file, 63 lines)
@@ -0,0 +1,63 @@
import torch

from colossalai.elixir.chunk import BlockRequire, MemoryPool, PrivateBlock, PublicBlock
from colossalai.testing import run_on_environment_flag


@run_on_environment_flag('ELX')
def test_block():
    b = PublicBlock(123, torch.float16, 'cuda')
    payload_b = b.payload

    assert payload_b.numel() == 123
    assert payload_b.dtype == torch.float16
    assert payload_b.device.type == 'cuda'
    assert payload_b.numel() * payload_b.element_size() == b.memo_occ

    c = PrivateBlock(77, torch.float, 'cpu')
    payload_c = c.payload

    assert payload_c.numel() == 77
    assert payload_c.dtype == torch.float
    assert payload_c.device.type == 'cpu'
    assert payload_c.numel() * payload_c.element_size() == c.memo_occ

    print('test_block: ok')


@run_on_environment_flag('ELX')
def test_memory_pool():
    mp = MemoryPool(device_type='cuda')
    private_list = [BlockRequire(5, torch.float), BlockRequire(81, torch.float16)]
    mp.allocate(public_block_number=4, private_block_list=private_list)

    block0 = mp.get_public_block()

    assert block0 in mp.public_used_blocks
    assert mp.public_used_cnt == 1
    assert mp.public_free_cnt == 3

    block1 = mp.get_public_block()

    assert block1 in mp.public_used_blocks
    assert mp.public_used_cnt == 2
    assert mp.public_free_cnt == 2

    mp.free_public_block(block0)
    mp.free_public_block(block1)

    assert block0 in mp.public_free_blocks
    assert block1 in mp.public_free_blocks
    assert mp.public_used_cnt == 0
    assert mp.public_free_cnt == 4

    block0 = mp.get_private_block(5, torch.float)
    assert block0.numel == 5
    assert block0.dtype == torch.float

    print('test_memory_pool: ok')


if __name__ == '__main__':
    test_block()
    test_memory_pool()
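
The `memo_occ` assertions above are plain byte accounting: a block's memory occupation is its element count times the element size of its dtype. For the two blocks constructed in `test_block` this works out to:

    # PublicBlock(123, torch.float16, 'cuda'): 123 elements * 2 bytes per fp16 element
    assert 123 * torch.tensor([], dtype=torch.float16).element_size() == 246
    # PrivateBlock(77, torch.float, 'cpu'): 77 elements * 4 bytes per fp32 element
    assert 77 * torch.tensor([], dtype=torch.float).element_size() == 308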
tests/test_elixir/test_chunk/test_chunk.py (new file, 155 lines)
@@ -0,0 +1,155 @@
import os
from functools import partial

import pytest
import torch
import torch.distributed as dist

from colossalai.elixir.chunk import BlockRequire, Chunk, MemoryPool, TensorState
from colossalai.elixir.utils import init_distributed
from colossalai.testing import run_on_environment_flag


def exam_chunk_functions(nproc, group):
    a = torch.randn(2, 64, device='cuda')
    copy_a = a.clone()
    b = torch.randn(2, 2, 128, device='cuda')
    copy_b = b.clone()
    c = torch.randn(128, device='cuda')
    copy_c = c.clone()
    d = torch.randn(4, 32, device='cuda')
    copy_d = d.clone()

    mp = MemoryPool('cuda')
    mp.allocate(public_block_number=1)

    chunk = Chunk(mp, 1024, torch.float, group)
    chunk.l2_norm_flag = True
    assert chunk.chunk_size == 1024
    assert chunk.chunk_dtype == torch.float
    assert chunk.shard_size == 1024 // nproc

    def check_tensors():
        assert torch.equal(a, copy_a)
        assert torch.equal(b, copy_b)
        assert torch.equal(c, copy_c)
        assert torch.equal(d, copy_d)

    chunk.append_tensor(a)
    chunk.append_tensor(b)
    chunk.append_tensor(c)
    chunk.append_tensor(d)
    check_tensors()

    chunk.close_chunk()
    assert chunk.is_replica is False
    # check function: get_cpu_copy
    cpu_copys = chunk.get_cpu_copy()
    for t_gpu, t_cpu in zip([copy_a, copy_b, copy_c, copy_d], cpu_copys):
        assert t_cpu.device.type == 'cpu'
        assert torch.equal(t_gpu.cpu(), t_cpu)
    # check function: access_chunk
    block = mp.get_public_block()
    chunk.access_chunk(block)
    assert chunk.is_replica
    assert chunk.scatter_check
    check_tensors()
    # check function: release_chunk
    chunk.optim_sync_flag = False
    block = chunk.release_chunk()
    assert block in mp.public_used_blocks
    assert chunk.is_replica is False
    assert chunk.optim_sync_flag is True
    # check function: access_chunk after release_chunk
    chunk.access_chunk(block)
    check_tensors()
    # check function: reduce_chunk
    norm = block.payload.float().norm(2)**2
    chunk.reduce_chunk()
    assert chunk.is_replica is False
    assert chunk.tensor_state_cnter[TensorState.HOLD] == 4

    test_norm = torch.Tensor([chunk.l2_norm]).cuda()
    dist.all_reduce(test_norm)
    assert torch.allclose(norm, test_norm)

    torch.cuda.synchronize()
    print('chunk functions are ok')


def exam_chunk_states(nproc, group):
    a = torch.randn(2, 64, device='cuda')
    copy_a = a.clone()
    b = torch.randn(2, 2, 128, device='cuda')
    copy_b = b.clone()
    c = torch.randn(128, device='cuda')
    copy_c = c.clone()
    d = torch.randn(4, 32, device='cuda')
    copy_d = d.clone()

    private = [BlockRequire(1024, torch.float)]
    mp = MemoryPool('cuda')
    mp.allocate(private_block_list=private)

    chunk = Chunk(mp, 1024, torch.float, group, rcache_fused=True)
    assert chunk.chunk_size == 1024
    assert chunk.chunk_dtype == torch.float
    assert chunk.shard_size == 1024 // nproc

    def check_tensors():
        assert torch.equal(a, copy_a)
        assert torch.equal(b, copy_b)
        assert torch.equal(c, copy_c)
        assert torch.equal(d, copy_d)

    chunk.append_tensor(a)
    chunk.append_tensor(b)
    chunk.append_tensor(c)
    chunk.append_tensor(d)
    check_tensors()

    chunk.close_chunk()
    assert chunk.is_replica is False

    chunk.access_chunk()
    assert chunk.is_replica
    check_tensors()

    assert chunk.tensor_state_cnter[TensorState.HOLD] == 4
    chunk.tensor_trans_state(a, TensorState.COMPUTE)
    assert chunk.tensor_state_cnter[TensorState.HOLD] == 3
    assert chunk.tensor_state_cnter[TensorState.COMPUTE] == 1

    tensor_list = [a, b, c, d]
    for t in tensor_list:
        chunk.tensor_trans_state(t, TensorState.COMPUTE)
        chunk.tensor_trans_state(t, TensorState.HOLD_AFTER_BWD)
        chunk.tensor_trans_state(t, TensorState.READY_FOR_REDUCE)
    assert chunk.tensor_state_cnter[TensorState.READY_FOR_REDUCE] == 4
    assert chunk.reduce_check

    torch.cuda.synchronize()
    print('chunk states are ok')


def run_dist(rank, world_size):
    os.environ['RANK'] = str(rank)
    os.environ['LOCAL_RANK'] = str(rank)
    os.environ['WORLD_SIZE'] = str(world_size)
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = str(29512)
    init_distributed()
    exam_chunk_functions(nproc=world_size, group=dist.GroupMember.WORLD)
    exam_chunk_states(nproc=world_size, group=dist.GroupMember.WORLD)


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 2, 4])
@run_on_environment_flag('ELX')
def test_chunk_functions(world_size):
    run_func = partial(run_dist, world_size=world_size)
    torch.multiprocessing.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_chunk_functions(world_size=4)
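
Taken together, the two exams above walk a chunk through its full lifecycle. A condensed sketch of that lifecycle, with the expected condition after each call (derived from the assertions above, not from the Chunk implementation itself):

    chunk.append_tensor(t)      # register tensors while the chunk is still open
    chunk.close_chunk()         # chunk becomes a shard: is_replica == False
    chunk.access_chunk(block)   # gather into an rcache block: is_replica == True, tensors in HOLD
    chunk.release_chunk()       # hand the block back: is_replica == False again
    # per-tensor state machine driven during backward:
    #   HOLD -> COMPUTE -> HOLD_AFTER_BWD -> READY_FOR_REDUCE
    chunk.reduce_chunk()        # reduce gradients once every tensor is READY_FOR_REDUCE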
tests/test_elixir/test_chunk/test_fetcher.py (new file, 71 lines)
@@ -0,0 +1,71 @@
import copy
import os
from functools import partial

import pytest
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing import assert_close

from colossalai.elixir.chunk import ChunkGroup
from colossalai.elixir.utils import init_distributed, seed_all
from colossalai.testing import run_on_environment_flag
from tests.test_elixir.test_chunk.fetcher_utils import hook_transform
from tests.test_elixir.utils import TEST_MODELS, to_cuda


def check_gradient(ddp_model, my_model, cg: ChunkGroup):
    for chunk in cg.fused_chunks:
        cg.access_chunk(chunk)

    for (name, p0), p1 in zip(ddp_model.named_parameters(), my_model.parameters()):
        torch.cuda.synchronize()
        print(f'checking parameter {name}')
        assert_close(p0.grad.data, p1.data)


def exam_chunk_fetcher(nproc, group):
    model_fn, data_fn = TEST_MODELS.get('resnet')
    torch_model = model_fn().cuda()
    test_model = copy.deepcopy(torch_model)

    rank = dist.get_rank(group)
    # get different data
    seed_all(1001 + rank)
    data = to_cuda(data_fn())

    seed_all(1001, cuda_deterministic=True)
    ddp_model = DDP(torch_model)
    ddp_loss = ddp_model(**data)
    ddp_loss.backward()

    hook_model, cg = hook_transform(test_model, group)
    my_loss = hook_model(**data)
    my_loss.backward()

    assert_close(ddp_loss, my_loss)
    check_gradient(ddp_model, hook_model, cg)
    print('private chunk fetcher is ok')


def run_dist(rank, world_size):
    os.environ['RANK'] = str(rank)
    os.environ['LOCAL_RANK'] = str(rank)
    os.environ['WORLD_SIZE'] = str(world_size)
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = str(29512)
    init_distributed()
    exam_chunk_fetcher(nproc=world_size, group=dist.GroupMember.WORLD)


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 2, 4])
@run_on_environment_flag('ELX')
def test_chunk_fetcher(world_size):
    run_func = partial(run_dist, world_size=world_size)
    torch.multiprocessing.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_chunk_fetcher(world_size=2)
tests/test_elixir/test_chunk/test_group.py (new file, 98 lines)
@@ -0,0 +1,98 @@
import os
from functools import partial

import pytest
import torch
import torch.distributed as dist

from colossalai.elixir.chunk import BlockRequire, ChunkGroup, MemoryPool, TensorState
from colossalai.elixir.utils import init_distributed
from colossalai.testing import run_on_environment_flag


def exam_chunk_group_functions(nproc, group):
    a = torch.randn(3, 64, device='cuda')
    copy_a = a.clone()
    b = torch.randn(2, 32, device='cuda')
    copy_b = b.clone()
    c = torch.randn(256, device='cuda')
    copy_c = c.clone()
    d = torch.randn(2, 2, 64, device='cuda')
    copy_d = d.clone()
    e = torch.randn(2, 33, device='cuda')
    copy_e = e.clone()

    mp = MemoryPool('cuda')
    mp.allocate(public_block_size=256, public_block_number=2, private_block_list=[BlockRequire(68, torch.float)])
    cg = ChunkGroup(rcache=mp)
    c0 = cg.allocate_chunk([a, b], 256, torch.float, group)
    c1 = cg.allocate_chunk([c], 256, torch.float, group)
    c2 = cg.allocate_chunk([d], 256, torch.float, group)

    fused_config = dict(rcache_fused=True)
    c3 = cg.allocate_chunk([e], 68, torch.float, group, fused_config)

    def check_chunk_0():
        assert torch.equal(a, copy_a)
        assert torch.equal(b, copy_b)

    def check_chunk_1():
        assert torch.equal(c, copy_c)

    def check_chunk_2():
        assert torch.equal(d, copy_d)

    def check_chunk_3():
        assert torch.equal(e, copy_e)

    # check tensors_to_chunks
    chunks = cg.tensors_to_chunks([e, a])
    assert chunks[0] == c0
    assert chunks[1] == c3
    # check access_chunk for unfused chunks
    cg.access_chunk(c0)
    cg.access_chunk(c1)
    check_chunk_0()
    check_chunk_1()
    assert not cg.rcache_enough_check(c2)
    assert cg.rcache_enough_check(c3)
    # check access_chunk for fused chunks
    cg.access_chunk(c3)
    check_chunk_3()
    # check release_chunk for unfused chunks
    cg.release_chunk(c1)
    assert cg.rcache_enough_check(c2)
    # check access_chunk
    cg.access_chunk(c2)
    check_chunk_2()

    cg.tensor_trans_state(e, TensorState.COMPUTE)
    cg.tensor_trans_state(e, TensorState.HOLD_AFTER_BWD)
    cg.tensor_trans_state(e, TensorState.READY_FOR_REDUCE)
    cg.reduce_chunk(c3)
    assert not c3.is_replica

    torch.cuda.synchronize()
    print('chunk group functions are ok')


def run_dist(rank, world_size):
    os.environ['RANK'] = str(rank)
    os.environ['LOCAL_RANK'] = str(rank)
    os.environ['WORLD_SIZE'] = str(world_size)
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = str(29512)
    init_distributed()
    exam_chunk_group_functions(nproc=world_size, group=dist.GroupMember.WORLD)


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 2, 4])
@run_on_environment_flag('ELX')
def test_chunk_group(world_size):
    run_func = partial(run_dist, world_size=world_size)
    torch.multiprocessing.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_chunk_group(world_size=2)
tests/test_elixir/test_chunk/test_scheduler.py (new file, 130 lines)
@@ -0,0 +1,130 @@
import os
from functools import partial

import pytest
import torch
import torch.distributed as dist

from colossalai.elixir.chunk import Chunk, MemoryPool
from colossalai.elixir.chunk.scheduler import FIFOScheduler, PrefetchScheduler
from colossalai.elixir.utils import init_distributed
from colossalai.testing import run_on_environment_flag


def exam_fifo(nproc, group):
    mp = MemoryPool('cuda')
    mp.allocate(public_block_number=1)
    c0 = Chunk(mp, 1024, torch.float, group)
    c1 = Chunk(mp, 1024, torch.float, group)
    c2 = Chunk(mp, 1024, torch.float, group)

    sdl = FIFOScheduler()
    sdl.reset()

    sdl.add(c0)
    sdl.add(c1)
    sdl.add(c2)
    sdl.add(c0)    # nothing happens here
    assert sdl.top() == c0

    sdl.remove(c0)
    assert sdl.top() == c1, f'{sdl.top()}'
    sdl.remove(c0)
    assert sdl.top() == c1, f'{sdl.top()}'

    sdl.add(c0)
    assert sdl.top() == c1
    sdl.remove(c1)
    assert sdl.top() == c2
    sdl.remove(c2)
    assert sdl.top() == c0


def exam_prefetch(nproc, group):
    mp = MemoryPool('cuda')
    mp.allocate()
    c0 = Chunk(mp, 1024, torch.float, group)
    c1 = Chunk(mp, 1024, torch.float, group)
    c2 = Chunk(mp, 1024, torch.float, group)

    chunk_called_per_step = [[c0], [c1], [c2], [c0], [c0], [c1], [c2], [c2], [c1], [c0]]

    sdl = PrefetchScheduler(chunk_called_per_step=chunk_called_per_step)
    print(sdl.next_step_dict)
    sdl.reset()

    sdl.step()
    sdl.add(c0)
    assert sdl.top() == c0

    sdl.step()
    sdl.add(c1)
    assert sdl.top() == c1

    sdl.step()
    sdl.add(c2)
    assert sdl.top() == c2

    sdl.remove(c0)
    sdl.step()
    sdl.add(c0)
    assert sdl.top() == c2

    sdl.remove(c0)
    sdl.step()
    sdl.add(c0)
    assert sdl.top() == c0
    sdl.remove(c0)    # notice here

    sdl.remove(c1)
    sdl.step()
    sdl.add(c1)
    assert sdl.top() == c1

    sdl.remove(c2)
    sdl.step()
    sdl.add(c2)
    assert sdl.top() == c1

    sdl.remove(c2)
    sdl.step()
    sdl.add(c2)
    assert sdl.top() == c2
    sdl.remove(c2)    # notice here
    sdl.add(c0)    # notice here

    sdl.remove(c1)
    sdl.step()
    sdl.add(c1)
    assert sdl.top() == c1
    sdl.remove(c1)    # notice here

    sdl.remove(c0)
    sdl.step()
    sdl.add(c0)
    assert sdl.top() == c0

    sdl.remove(c0)
    sdl.clear()


def run_dist(rank, world_size):
    os.environ['RANK'] = str(rank)
    os.environ['LOCAL_RANK'] = str(rank)
    os.environ['WORLD_SIZE'] = str(world_size)
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = str(29512)
    init_distributed()
    exam_fifo(nproc=world_size, group=dist.GroupMember.WORLD)
    exam_prefetch(nproc=world_size, group=dist.GroupMember.WORLD)


@pytest.mark.dist
@run_on_environment_flag('ELX')
def test_chunk_scheduler(world_size=1):
    run_func = partial(run_dist, world_size=world_size)
    torch.multiprocessing.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_chunk_scheduler()
tests/test_elixir/test_ctx/test_meta_ctx.py (new file, 22 lines)
@@ -0,0 +1,22 @@
from colossalai.elixir.ctx import MetaContext
from colossalai.testing import run_on_environment_flag
from tests.test_elixir.utils import TEST_MODELS


@run_on_environment_flag('ELX')
def test_meta_context():
    builder, *_ = TEST_MODELS.get('resnet')
    with MetaContext():
        model = builder()

    for name, param in model.named_parameters():
        assert param.device.type == 'meta'
        print(name, param)

    for name, buffer in model.named_buffers():
        assert buffer.device.type == 'meta'
        print(name, buffer)


if __name__ == '__main__':
    test_meta_context()
tests/test_elixir/test_hook.py (new file, 56 lines)
@@ -0,0 +1,56 @@
from copy import deepcopy

import torch
import torch.nn as nn

from colossalai.elixir.hook import BufferStore, HookParam
from colossalai.elixir.tensor import FakeTensor


def test_hook():
    x = nn.Parameter(torch.randn(4, 4))

    ori_numel = x.numel()
    ori_size = x.size()
    ori_stride = x.stride()
    ori_offset = x.storage_offset()

    fake_data = FakeTensor(x.data)
    x.data = fake_data
    x.__class__ = HookParam

    assert x.numel() == ori_numel
    assert x.size() == ori_size
    assert x.stride() == ori_stride
    assert x.storage_offset() == ori_offset


def test_store():
    buffer = BufferStore(1024, torch.float16)
    print(buffer)

    x = torch.randn(4, 128, dtype=torch.float16, device='cuda')
    original_ptr_x = x.data_ptr()
    copy_x = deepcopy(x)

    y = torch.randn(512, dtype=torch.float16, device='cuda')
    original_ptr_y = y.data_ptr()
    copy_y = deepcopy(y)

    offset = 0
    offset = buffer.insert(x, offset)
    assert offset == x.numel()
    assert torch.equal(x, copy_x)

    offset = buffer.insert(y, offset)
    assert offset == 1024
    assert torch.equal(y, copy_y)

    buffer.erase(x)
    buffer.erase(y)
    assert x.data_ptr() == original_ptr_x
    assert y.data_ptr() == original_ptr_y


if __name__ == '__main__':
    test_store()
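
The offsets checked in `test_store` are simple element arithmetic: the store holds 1024 fp16 elements, and (as the assertions above imply) `insert` copies a tensor into the buffer starting at the given element offset and returns the offset just past it. With the two tensors above:

    # x: 4 * 128 = 512 elements -> inserted at offset 0,   next offset 512
    # y: 512 elements           -> inserted at offset 512, next offset 1024 (buffer exactly full)
    # erase() then points each tensor back at its original storage, hence the data_ptr checks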
tests/test_elixir/test_kernels/test_attn.py (new file, 35 lines)
@@ -0,0 +1,35 @@
from copy import deepcopy

import pytest
from torch.testing import assert_close

from tests.test_elixir.utils import TEST_MODELS, to_cuda


def exam_one_model(model_fn, data_fn):
    from colossalai.elixir.kernels.attn_wrapper import wrap_attention

    torch_model = model_fn().cuda()
    test_model = deepcopy(torch_model)
    test_model = wrap_attention(test_model)

    data = to_cuda(data_fn())
    torch_out = torch_model(**data)
    torch_out.backward()

    test_out = test_model(**data)
    test_out.backward()

    assert_close(torch_out, test_out)
    for (name, p_torch), p_test in zip(torch_model.named_parameters(), test_model.parameters()):
        assert_close(p_torch.grad, p_test.grad)


@pytest.mark.skip(reason="Need to install xformers")
def test_gpt_atten_kernel():
    exam_one_model(*TEST_MODELS.get('gpt2_micro'))
    exam_one_model(*TEST_MODELS.get('opt_micro'))


if __name__ == '__main__':
    test_gpt_atten_kernel()
tests/test_elixir/test_kernels/test_ln.py (new file, 58 lines)
@@ -0,0 +1,58 @@
import os
from copy import deepcopy
from functools import partial

import pytest
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.testing import assert_close

from colossalai.elixir.search import simple_search
from colossalai.elixir.utils import init_distributed
from colossalai.elixir.wrapper import ElixirModule


def exam_fused_layernorm(nproc, group):
    torch_model = nn.LayerNorm(2048)
    fused_model = deepcopy(torch_model)

    torch_model = torch_model.cuda()
    sr = simple_search(fused_model, nproc, 1, 1.0, verbose=True)
    fused_model = ElixirModule(fused_model, sr, group, use_fused_kernels=True)

    data = torch.randn(2, 2048, device='cuda')

    torch_loss = torch_model(data).sum()
    torch_loss.backward()

    fused_loss = fused_model(data).sum()
    fused_model.backward(fused_loss)

    assert_close(torch_loss, fused_loss)

    grad_state = fused_model.state_dict(from_param=True)
    for name, param in torch_model.named_parameters():
        assert_close(param.grad.cpu(), grad_state[name])


def run_dist(rank, world_size):
    os.environ['RANK'] = str(rank)
    os.environ['LOCAL_RANK'] = str(rank)
    os.environ['WORLD_SIZE'] = str(world_size)
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = str(29512)
    init_distributed()
    exam_fused_layernorm(nproc=world_size, group=dist.GroupMember.WORLD)


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1])
@pytest.mark.skip(reason='need to install apex')
def test_fused_layernorm(world_size):
    run_func = partial(run_dist, world_size=world_size)
    torch.multiprocessing.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_fused_layernorm(world_size=1)
tests/test_elixir/test_search/test_mini_waste.py (new file, 38 lines)
@@ -0,0 +1,38 @@
from copy import deepcopy

import torch

from colossalai.elixir.cuda import gpu_device
from colossalai.elixir.search import minimum_waste_search
from colossalai.testing import run_on_environment_flag
from tests.test_elixir.utils import TEST_MODELS


def step_fn(model, inp):
    model(**inp).backward()


@run_on_environment_flag('ELX')
def test_mini_waste_search():
    model_fn, data_fn = TEST_MODELS.get('gpt2_small')
    model = model_fn()
    data = data_fn()

    sr = minimum_waste_search(model,
                              1,
                              unified_dtype=torch.float16,
                              cpu_offload=True,
                              prefetch=True,
                              verbose=True,
                              inp=data,
                              step_fn=step_fn)

    chunk_plans = deepcopy(sr.param_chunk_plans)
    for plan in chunk_plans:
        assert plan.chunk_dtype == torch.float16
        assert plan.kwargs.get('shard_device') == torch.device('cpu')
        assert plan.kwargs.get('cpu_pin_memory') == True


if __name__ == '__main__':
    test_mini_waste_search()
tests/test_elixir/test_search/test_optimal.py (new file, 30 lines)
@@ -0,0 +1,30 @@
from copy import deepcopy

import torch

from colossalai.elixir.cuda import gpu_device
from colossalai.elixir.search import optimal_search
from colossalai.testing import run_on_environment_flag
from tests.test_elixir.utils import TEST_MODELS


def step_fn(model, inp):
    model(**inp).backward()


@run_on_environment_flag('ELX')
def test_optimal_search():
    model_fn, data_fn = TEST_MODELS.get('gpt2_small')
    model = model_fn()
    data = data_fn()

    sr = optimal_search(model, 1, unified_dtype=torch.float16, overlap=True, verbose=True, inp=data, step_fn=step_fn)

    chunk_plans = deepcopy(sr.param_chunk_plans)
    for plan in chunk_plans:
        assert plan.chunk_dtype == torch.float16
        assert plan.kwargs.get('shard_device') == gpu_device()


if __name__ == '__main__':
    test_optimal_search()
tests/test_elixir/test_search/test_simple.py (new file, 48 lines)
@@ -0,0 +1,48 @@
from copy import deepcopy

import torch

from colossalai.elixir.cuda import gpu_device
from colossalai.elixir.search import simple_search
from colossalai.testing import run_on_environment_flag
from tests.test_elixir.utils import TEST_MODELS


def step_fn(model, inp):
    model(**inp).backward()


@run_on_environment_flag('ELX')
def test_simple_search():
    model_fn, data_fn = TEST_MODELS.get('small')
    model = model_fn()
    data = data_fn()

    sr = simple_search(model,
                       1,
                       split_number=5,
                       shard_device=gpu_device(),
                       prefetch=True,
                       verbose=True,
                       inp=data,
                       step_fn=step_fn)

    chunk_plans = deepcopy(sr.param_chunk_plans)
    private_plan = chunk_plans.pop(0)
    assert private_plan.name_list == ['embed.weight']
    assert private_plan.chunk_size == 320
    assert private_plan.kwargs.get('shard_device') == gpu_device()

    assert chunk_plans[0].name_list == ['norm1.weight', 'norm1.bias']
    assert chunk_plans[1].name_list == ['mlp.proj1.weight', 'mlp.proj1.bias']
    assert chunk_plans[2].name_list == ['mlp.proj2.weight', 'mlp.proj2.bias']
    assert chunk_plans[3].name_list == ['norm2.weight']
    assert chunk_plans[4].name_list == ['norm2.bias']

    for plan in chunk_plans:
        assert plan.chunk_size == 1088
        assert plan.kwargs.get('shard_device') == gpu_device()


if __name__ == '__main__':
    test_simple_search()
tests/test_elixir/test_src/test_move.py (new file, 13 lines)
@@ -0,0 +1,13 @@
from colossalai.elixir.simulator import move_count
from colossalai.testing import run_on_environment_flag


@run_on_environment_flag('ELX')
def test_move_count():
    steps = [[0], [1, 2], [3], [3], [1, 2], [0]]
    size = 2
    assert move_count(steps, size) == 12


if __name__ == '__main__':
    test_move_count()
tests/test_elixir/test_tools/test_registry.py (new file, 23 lines)
@@ -0,0 +1,23 @@
import pytest
import torch

from colossalai.testing import run_on_environment_flag
from tests.test_elixir.utils import to_cuda


@run_on_environment_flag('ELX')
def test_registry():
    from tests.test_elixir.utils.registry import TEST_MODELS
    for name, model_tuple in TEST_MODELS:
        torch.cuda.synchronize()
        print(f'model `{name}` is in testing')

        model_fn, data_fn = model_tuple
        model = model_fn().cuda()
        data = to_cuda(data_fn())
        loss = model(**data)
        loss.backward()


if __name__ == '__main__':
    test_registry()
tests/test_elixir/test_tracer/test_cuda_profiler.py (new file, 46 lines)
@@ -0,0 +1,46 @@
import pytest
import torch

from colossalai.elixir.tracer.memory_tracer import cuda_memory_profiling
from colossalai.testing import run_on_environment_flag
from tests.test_elixir.utils import TEST_MODELS, to_cuda


def one_step(model, inp):
    loss = model(**inp)
    loss.backward()
    return loss


def try_one_model(model_fn, data_fn):
    model = model_fn().cuda()
    data = to_cuda(data_fn())
    one_step(model, data)    # generate gradients

    pre_cuda_alc = torch.cuda.memory_allocated()
    torch.cuda.reset_peak_memory_stats()
    one_step(model, data)
    aft_cuda_alc = torch.cuda.max_memory_allocated()
    torch_activation_occ = aft_cuda_alc - pre_cuda_alc
    model.zero_grad(set_to_none=True)
    print('normal', torch_activation_occ)

    before = torch.cuda.memory_allocated()
    profiling_dict = cuda_memory_profiling(model, data, one_step)
    after = torch.cuda.memory_allocated()
    print('profiling', profiling_dict)
    assert before == after
    assert torch_activation_occ == profiling_dict['activation_occ']
    print('Check is ok.')


@run_on_environment_flag('ELX')
def test_cuda_profiler():
    model_list = ['resnet', 'gpt2_micro']
    for name in model_list:
        model_fn, data_fn = TEST_MODELS.get(name)
        try_one_model(model_fn, data_fn)


if __name__ == '__main__':
    test_cuda_profiler()
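
The reference measurement in `try_one_model` uses a standard PyTorch pattern for isolating activation memory: record allocated memory before a step, reset the peak counter, run the step, and take the difference between the new peak and the baseline. A minimal standalone sketch of that pattern (independent of Elixir's profiler; `measure_activation_bytes` is a hypothetical helper name):

    import torch

    def measure_activation_bytes(model, data, step_fn):
        baseline = torch.cuda.memory_allocated()      # parameters and gradients already resident
        torch.cuda.reset_peak_memory_stats()
        step_fn(model, data)                          # forward + backward
        peak = torch.cuda.max_memory_allocated()
        return peak - baseline                        # transient activation memory of one step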
tests/test_elixir/test_tracer/test_op_cache.py (new file, 145 lines)
@@ -0,0 +1,145 @@
import pytest
import torch

from colossalai.elixir.tracer.memory_tracer import MTensor
from colossalai.elixir.tracer.memory_tracer.op_cache import addmm_cache, bmm_cache, mm_cache
from colossalai.elixir.tracer.utils import get_cuda_allocated, get_cuda_max_allocated
from colossalai.testing import run_on_environment_flag


def op_mm(x, y):
    u = torch.matmul(x, y)
    return u.shape


def op_addmm(x, y, z):
    u = torch.addmm(x, y, z)
    return u.shape


def op_bmm(x, y):
    u = torch.bmm(x, y)
    return u.shape


@pytest.mark.parametrize('dtype', [torch.float, torch.float16, torch.bfloat16])
@run_on_environment_flag('ELX')
def test_mm(dtype, size0=(4, 256), size1=(256, 1024)):
    torch.cuda.reset_peak_memory_stats()
    assert get_cuda_allocated() == 0

    x = torch.randn(size0, dtype=dtype, device='cuda')
    y = torch.randn(size1, dtype=dtype, device='cuda')
    torch_pre_alc = get_cuda_allocated()

    torch_z_size = op_mm(x, y)
    torch_temp_alc = get_cuda_max_allocated() - torch_pre_alc

    del x
    del y

    assert get_cuda_allocated() == 0
    x = MTensor(torch.randn(size0, dtype=dtype, device='cuda'))
    y = MTensor(torch.randn(size1, dtype=dtype, device='cuda'))
    op1_pre_alc = get_cuda_allocated()

    MTensor.reset_peak_memory()
    op1_z_size = op_mm(x, y)
    op1_temp_alc = MTensor.current_peak_memory() - op1_pre_alc

    assert torch_z_size == op1_z_size
    assert torch_pre_alc == op1_pre_alc
    assert torch_temp_alc == op1_temp_alc
    assert len(mm_cache.temp_memory) > 0

    MTensor.reset_peak_memory()
    op2_z_size = op_mm(x, y)
    op2_temp_alc = MTensor.current_peak_memory() - op1_pre_alc

    assert torch_z_size == op2_z_size
    assert torch_temp_alc == op2_temp_alc


@pytest.mark.parametrize('dtype', [torch.float, torch.float16, torch.bfloat16])
@run_on_environment_flag('ELX')
def test_addmm(dtype, size0=(4, 16), size1=(16, 64)):
    torch.cuda.reset_peak_memory_stats()
    assert get_cuda_allocated() == 0

    x = torch.randn(size0, dtype=dtype, device='cuda')
    y = torch.randn(size1, dtype=dtype, device='cuda')
    u = torch.randn(size1[-1], dtype=dtype, device='cuda')
    torch_pre_alc = get_cuda_allocated()

    torch_z_size = op_addmm(u, x, y)
    torch_temp_alc = get_cuda_max_allocated() - torch_pre_alc

    del x
    del y
    del u

    assert get_cuda_allocated() == 0
    x = MTensor(torch.randn(size0, dtype=dtype, device='cuda'))
    y = MTensor(torch.randn(size1, dtype=dtype, device='cuda'))
    u = MTensor(torch.randn(size1[-1], dtype=dtype, device='cuda'))
    op1_pre_alc = get_cuda_allocated()

    MTensor.reset_peak_memory()
    op1_z_size = op_addmm(u, x, y)
    op1_temp_alc = MTensor.current_peak_memory() - op1_pre_alc

    assert torch_z_size == op1_z_size
    assert torch_pre_alc == op1_pre_alc
    assert torch_temp_alc == op1_temp_alc
    assert len(addmm_cache.temp_memory) > 0

    MTensor.reset_peak_memory()
    op2_z_size = op_addmm(u, x, y)
    op2_temp_alc = MTensor.current_peak_memory() - op1_pre_alc

    assert torch_z_size == op2_z_size
    assert torch_temp_alc == op2_temp_alc


@pytest.mark.parametrize('dtype', [torch.float, torch.float16, torch.bfloat16])
@run_on_environment_flag('ELX')
def test_bmm(dtype, size0=(10, 4, 15), size1=(10, 15, 64)):
    torch.cuda.reset_peak_memory_stats()
    assert get_cuda_allocated() == 0

    x = torch.randn(size0, dtype=dtype, device='cuda')
    y = torch.randn(size1, dtype=dtype, device='cuda')
    torch_pre_alc = get_cuda_allocated()

    torch_z_size = op_bmm(x, y)
    torch_temp_alc = get_cuda_max_allocated() - torch_pre_alc

    del x
    del y

    assert get_cuda_allocated() == 0
    x = MTensor(torch.randn(size0, dtype=dtype, device='cuda'))
    y = MTensor(torch.randn(size1, dtype=dtype, device='cuda'))
    op1_pre_alc = get_cuda_allocated()

    MTensor.reset_peak_memory()
    op1_z_size = op_bmm(x, y)
    op1_temp_alc = MTensor.current_peak_memory() - op1_pre_alc

    assert torch_z_size == op1_z_size
    assert torch_pre_alc == op1_pre_alc
    assert torch_temp_alc == op1_temp_alc
    assert len(bmm_cache.temp_memory) > 0

    bmm_cache.print()

    MTensor.reset_peak_memory()
    op2_z_size = op_bmm(x, y)
    op2_temp_alc = MTensor.current_peak_memory() - op1_pre_alc

    assert torch_z_size == op2_z_size
    assert torch_temp_alc == op2_temp_alc


if __name__ == '__main__':
    test_addmm(dtype=torch.float)
tests/test_elixir/test_tracer/test_tf_order.py (new file, 37 lines)
@@ -0,0 +1,37 @@
from colossalai.elixir.tracer.param_tracer import generate_tf_order
from colossalai.testing import run_on_environment_flag
from tests.test_elixir.utils import TEST_MODELS


@run_on_environment_flag('ELX')
def test_tf_forward_backward():
    model_fn, data_fn = TEST_MODELS.get('gpt2_micro')
    model = model_fn()
    data = data_fn()

    def forward_backward_fn(local_model, local_input):
        local_model(**local_input).backward()

    # model.gradient_checkpointing_enable()
    tf_order = generate_tf_order(model, data, forward_backward_fn)
    params_per_step = tf_order['params_per_step']
    assert len(params_per_step) == 32

    model.gradient_checkpointing_enable()
    tf_order = generate_tf_order(model, data, forward_backward_fn)
    params_per_step = tf_order['params_per_step']
    checkpoint_info = tf_order['checkpoint_info']
    for i, step in enumerate(params_per_step):
        print(f'step {i}: {step}')
    for c in checkpoint_info:
        print(f'checkpoint info: {c}')
    assert len(params_per_step) == 44

    assert data['input_ids'].device.type == 'cpu'
    assert data['attention_mask'].device.type == 'cpu'
    for param in model.parameters():
        assert param.device.type == 'cpu'


if __name__ == '__main__':
    test_tf_forward_backward()
tests/test_elixir/test_wrapper/test_amp.py (new file, 94 lines)
@@ -0,0 +1,94 @@
import copy
import os
from functools import partial

import pytest
import torch
import torch.distributed as dist
from apex import amp
from apex.parallel import DistributedDataParallel as DDP
from torch.testing import assert_close

from colossalai.elixir.cuda import gpu_device
from colossalai.elixir.search import simple_search
from colossalai.elixir.utils import init_distributed, seed_all
from colossalai.elixir.wrapper import ElixirModule, ElixirOptimizer
from colossalai.nn.optimizer import HybridAdam
from colossalai.testing import run_on_environment_flag
from tests.test_elixir.utils import TEST_MODELS, to_cuda


def amp_check_model_states(ddp_optim, test_model):
    test_states = test_model.state_dict()
    for (name, _), p in zip(test_model.module.named_parameters(), amp.master_params(ddp_optim)):
        test_p = test_states[name]
        copy_p = p.to(test_p.device)
        print(f'checking parameter `{name}`: {test_p.dtype} {copy_p.dtype}')
        assert_close(test_p.data, copy_p.data)


def exam_amp_one_model(model_fn, data_fn, nproc, group, exam_seed=2261):
    ddp_model = model_fn().cuda()
    test_model = copy.deepcopy(ddp_model)
    # important here, since apex has a lazy fp32 init after the first optimizer step
    test_model = test_model.half()

    ddp_optim = HybridAdam(ddp_model.parameters(), lr=1e-1, weight_decay=0)
    ddp_model, ddp_optim = amp.initialize(ddp_model,
                                          ddp_optim,
                                          opt_level='O2',
                                          loss_scale=1.0,
                                          keep_batchnorm_fp32=False)
    ddp_model = DDP(ddp_model, message_size=0, allreduce_always_fp32=True)

    test_optim = HybridAdam(test_model.parameters(), lr=1e-1, weight_decay=0)
    sr = simple_search(test_model, nproc, shard_device=gpu_device(), unified_dtype=torch.float16, verbose=True)
    test_model = ElixirModule(test_model, sr, group, dtype=torch.float16, reduce_always_fp32=True, output_fp32=True)
    test_optim = ElixirOptimizer(test_model, test_optim, initial_scale=1.0)

    # get different data
    seed_all(exam_seed + dist.get_rank(group), cuda_deterministic=True)
    for _ in range(2):
        data = to_cuda(data_fn())

        ddp_optim.zero_grad()
        ddp_loss = ddp_model(**data)
        with amp.scale_loss(ddp_loss, ddp_optim) as scaled_loss:
            scaled_loss.backward()
        ddp_optim.step()

        test_optim.zero_grad()
        test_loss = test_model(**data)
        test_optim.backward(test_loss)
        test_optim.step()

        assert_close(ddp_loss, test_loss)
        amp_check_model_states(ddp_optim, test_model)


def exam_amp_in_models(nproc, group):
    model_fn, data_fn = TEST_MODELS.get('gpt2_micro')
    exam_amp_one_model(model_fn, data_fn, nproc, group)


def run_dist(rank, world_size):
    os.environ['RANK'] = str(rank)
    os.environ['LOCAL_RANK'] = str(rank)
    os.environ['WORLD_SIZE'] = str(world_size)
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = str(29512)
    init_distributed()
    exam_amp_in_models(nproc=world_size, group=dist.GroupMember.WORLD)


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 2, 4])
@run_on_environment_flag('ELX')
def test_elixir_amp(world_size):
    run_func = partial(run_dist, world_size=world_size)
    torch.multiprocessing.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_elixir_amp(world_size=2)
tests/test_elixir/test_wrapper/test_module.py (new file, 95 lines)
@@ -0,0 +1,95 @@
import copy
import os
from functools import partial

import pytest
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing import assert_close

from colossalai.elixir.search import simple_search
from colossalai.elixir.utils import init_distributed, seed_all
from colossalai.elixir.wrapper import ElixirModule
from colossalai.testing import run_on_environment_flag
from tests.test_elixir.utils import TEST_MODELS, assert_dict_values, to_cuda


def check_gradient(ddp_model: nn.Module, test_model: ElixirModule):
    grad_state = test_model.state_dict(from_param=True)
    for name, param in ddp_model.named_parameters():
        assert_close(param.grad.cpu(), grad_state[name])


def exam_module_init(nproc, group, grad_flag):
    model_fn, data_fn = TEST_MODELS.get('resnet')
    torch_model = model_fn().cuda()
    test_model = model_fn().cuda()

    for p1, p2 in zip(torch_model.parameters(), test_model.parameters()):
        p1.requires_grad = p2.requires_grad = grad_flag

    sr = simple_search(test_model, nproc)
    model = ElixirModule(test_model, sr, group)
    # check function: ElixirModule.load_state_dict after ElixirModule.__init__
    torch_st = torch_model.state_dict()
    if dist.get_rank() != 0:
        torch_st = None
    test_st = model.load_state_dict(torch_st, only_rank_0=True)
    # check function: ElixirModule.state_dict after ElixirModule.__init__
    torch_st = torch_model.state_dict()
    test_st = model.state_dict()
    assert_dict_values(torch_st, test_st, fn=torch.equal)


def exam_one_module_fwd_bwd(model_fn, data_fn, nproc, group, exam_seed=2261):
    ddp_model = model_fn().cuda()
    test_model = copy.deepcopy(ddp_model)
    sr = simple_search(test_model, nproc, allocate_factor=0.6)
    test_model = ElixirModule(test_model, sr, group)

    # get different data
    seed_all(exam_seed + dist.get_rank(group))
    data = data_fn()
    data = to_cuda(data)

    seed_all(exam_seed, cuda_deterministic=True)
    ddp_model = DDP(ddp_model)
    ddp_loss = ddp_model(**data)
    ddp_loss.backward()

    test_loss = test_model(**data)
    test_model.backward(test_loss)

    assert_close(ddp_loss, test_loss)
    check_gradient(ddp_model.module, test_model)


def exam_modules_fwd_bwd(nproc, group):
    model_fn, data_fn = TEST_MODELS.get('resnet')
    exam_one_module_fwd_bwd(model_fn, data_fn, nproc, group)


def run_dist(rank, world_size):
    os.environ['RANK'] = str(rank)
    os.environ['LOCAL_RANK'] = str(rank)
    os.environ['WORLD_SIZE'] = str(world_size)
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = str(29512)
    init_distributed()
    exam_module_init(nproc=world_size, group=dist.GroupMember.WORLD, grad_flag=False)
    exam_module_init(nproc=world_size, group=dist.GroupMember.WORLD, grad_flag=True)
    exam_modules_fwd_bwd(nproc=world_size, group=dist.GroupMember.WORLD)


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 2, 4])
@run_on_environment_flag('ELX')
def test_elixir_module(world_size):
    run_func = partial(run_dist, world_size=world_size)
    torch.multiprocessing.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_elixir_module(world_size=2)
tests/test_elixir/test_wrapper/test_optimizer.py (new file, 77 lines)
@@ -0,0 +1,77 @@
import copy
import os
from functools import partial

import pytest
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing import assert_close

from colossalai.elixir.cuda import gpu_device
from colossalai.elixir.search import simple_search
from colossalai.elixir.utils import init_distributed, seed_all
from colossalai.elixir.wrapper import ElixirModule, ElixirOptimizer
from colossalai.nn.optimizer import HybridAdam
from colossalai.testing import run_on_environment_flag
from tests.test_elixir.utils import TEST_MODELS, allclose, assert_dict_values, to_cuda


def exam_optimizer_one_model(model_fn, data_fn, nproc, group, exam_seed=2261):
    ddp_model = model_fn().cuda()
    test_model = copy.deepcopy(ddp_model)

    ddp_model = DDP(ddp_model)
    ddp_optim = HybridAdam(ddp_model.parameters(), lr=1e-1, weight_decay=0)

    test_optim = HybridAdam(test_model.parameters(), lr=1e-1, weight_decay=0)
    sr = simple_search(test_model, nproc, shard_device=gpu_device())
    test_model = ElixirModule(test_model, sr, group)
    test_optim = ElixirOptimizer(test_model, test_optim)

    # get different data
    seed_all(exam_seed + dist.get_rank(group))
    data = to_cuda(data_fn())

    seed_all(exam_seed, cuda_deterministic=True)
    ddp_optim.zero_grad()
    ddp_loss = ddp_model(**data)
    ddp_loss.backward()
    ddp_optim.step()

    test_optim.zero_grad()
    test_loss = test_model(**data)
    test_optim.backward(test_loss)
    test_optim.step()

    assert_close(ddp_loss, test_loss)
    torch_st = ddp_model.module.state_dict()
    test_st = test_model.state_dict()
    assert_dict_values(torch_st, test_st, fn=partial(allclose, rtol=2e-6, atol=2e-5))


def exam_optimizer_in_models(nproc, group):
    model_fn, data_fn = TEST_MODELS.get('resnet')
    exam_optimizer_one_model(model_fn, data_fn, nproc, group)


def run_dist(rank, world_size):
    os.environ['RANK'] = str(rank)
    os.environ['LOCAL_RANK'] = str(rank)
    os.environ['WORLD_SIZE'] = str(world_size)
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = str(29512)
    init_distributed()
    exam_optimizer_in_models(nproc=world_size, group=dist.GroupMember.WORLD)


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 2, 4])
@run_on_environment_flag('ELX')
def test_elixir_optimizer(world_size):
    run_func = partial(run_dist, world_size=world_size)
    torch.multiprocessing.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_elixir_optimizer(world_size=4)
tests/test_elixir/test_wrapper/test_prefetch.py (new file, 89 lines)
@@ -0,0 +1,89 @@
import copy
import os
from functools import partial

import pytest
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing import assert_close

from colossalai.elixir.cuda import gpu_device
from colossalai.elixir.search import simple_search
from colossalai.elixir.utils import init_distributed, seed_all
from colossalai.elixir.wrapper import ElixirModule
from colossalai.testing import run_on_environment_flag
from tests.test_elixir.utils import TEST_MODELS, to_cuda


def check_gradient(ddp_model: nn.Module, test_model: ElixirModule):
    grad_state = test_model.state_dict(from_param=True)
    for name, param in ddp_model.named_parameters():
        assert_close(param.grad.cpu(), grad_state[name])


def exam_one_module_fwd_bwd(model_fn, data_fn, nproc, group, exam_seed=2263):

    def one_step(local_model, local_input):
        loss = local_model(**local_input)
        loss.backward()
        return loss

    ddp_model = model_fn().cuda()
    test_model = copy.deepcopy(ddp_model)

    # get different data
    seed_all(exam_seed + dist.get_rank(group))
    data = to_cuda(data_fn())

    # wrap as DDP model
    ddp_model = DDP(ddp_model)
    # search how to initialize chunks
    sr = simple_search(test_model,
                       nproc,
                       shard_device=gpu_device(),
                       prefetch=True,
                       verbose=True,
                       inp=data,
                       step_fn=one_step)
    test_model = ElixirModule(test_model, sr, group, prefetch=True)

    seed_all(exam_seed, cuda_deterministic=True)
    ddp_loss = one_step(ddp_model, data)

    with torch.no_grad():
        test_loss = test_model(**data)
        assert_close(ddp_loss, test_loss)

    test_loss = test_model(**data)
    test_model.backward(test_loss)
    assert_close(ddp_loss, test_loss)
    check_gradient(ddp_model.module, test_model)


def exam_modules_fwd_bwd(nproc, group):
    model_fn, data_fn = TEST_MODELS.get('resnet')
    exam_one_module_fwd_bwd(model_fn, data_fn, nproc, group)


def run_dist(rank, world_size):
    os.environ['RANK'] = str(rank)
    os.environ['LOCAL_RANK'] = str(rank)
    os.environ['WORLD_SIZE'] = str(world_size)
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = str(29512)
    init_distributed()
    exam_modules_fwd_bwd(nproc=world_size, group=dist.GroupMember.WORLD)


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 2, 4])
@run_on_environment_flag('ELX')
def test_module_prefetch(world_size):
    run_func = partial(run_dist, world_size=world_size)
    torch.multiprocessing.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_module_prefetch(world_size=2)
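With prefetch=True, simple_search is handed a sample batch (inp=data) and a step function (step_fn=one_step) so that it can run one forward/backward pass while planning the chunks; presumably this is what lets the wrapper learn the order in which parameters are needed. Below is a hedged sketch of supplying a custom step function for a model that returns logits rather than a loss; my_model, my_batch, my_group and nproc are placeholders, not part of the test suite.

# --- illustrative sketch, not part of the committed file ---
def logits_step(local_model, local_input):
    # reduce the logits to a scalar so backward() can run during the search
    logits = local_model(**local_input)
    loss = logits.float().mean()
    loss.backward()
    return loss


sr = simple_search(my_model,
                   nproc,
                   shard_device=gpu_device(),
                   prefetch=True,
                   inp=my_batch,
                   step_fn=logits_step)
wrapped = ElixirModule(my_model, sr, my_group, prefetch=True)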
41
tests/test_elixir/utils/__init__.py
Normal file
@@ -0,0 +1,41 @@
import torch
from torch.testing import assert_close
from torch.utils._pytree import tree_map

from . import gpt, mlp, opt, resnet, small
from .registry import TEST_MODELS


def to_cuda(input_dict):

    def local_fn(t):
        if isinstance(t, torch.Tensor):
            t = t.cuda()
        return t

    ret = tree_map(local_fn, input_dict)
    return ret


def allclose(ta, tb, **kwargs):
    assert_close(ta, tb, **kwargs)
    return True


def assert_dict_keys(test_dict, keys):
    assert len(test_dict) == len(keys)
    for k in keys:
        assert k in test_dict


def assert_dict_values(da, db, fn):
    assert len(da) == len(db)
    for k, v in da.items():
        assert k in db
        if not torch.is_tensor(v):
            continue
        u = db.get(k)
        if u.device != v.device:
            v = v.to(u.device)
        # print(f"checking key {k}: {u.shape} vs {v.shape}")
        assert fn(u.data, v.data), f'max diff {torch.max(torch.abs(u.data - v.data))}'
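A quick usage sketch of to_cuda: tree_map visits every leaf of the input structure, so tensors are moved to the GPU while non-tensor values pass through unchanged. It requires a CUDA device; the batch below is illustrative.

# --- illustrative sketch, not part of the committed file ---
batch = dict(input_ids=torch.randint(low=0, high=128, size=(4, 64)),
             attention_mask=torch.ones(4, 64),
             note='non-tensor leaves are left untouched')
cuda_batch = to_cuda(batch)
assert cuda_batch['input_ids'].is_cuda
assert cuda_batch['note'] == 'non-tensor leaves are left untouched'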
79
tests/test_elixir/utils/gpt.py
Normal file
@@ -0,0 +1,79 @@
from functools import partial

import torch
import torch.nn as nn
from transformers import GPT2Config, GPT2LMHeadModel

from tests.test_elixir.utils.registry import TEST_MODELS

MICRO_VS = 128
MICRO_BS = 4
MICRO_SL = 64

MACRO_VS = 50257
MACRO_BS = 2
MACRO_SL = 1024


def micro_data_fn():
    input_ids = torch.randint(low=0, high=MICRO_VS, size=(MICRO_BS, MICRO_SL))
    attn_mask = torch.ones_like(input_ids)
    return dict(input_ids=input_ids, attention_mask=attn_mask)


def small_data_fn():
    input_ids = torch.randint(low=0, high=MACRO_VS, size=(MACRO_BS, MACRO_SL))
    attn_mask = torch.ones_like(input_ids)
    return dict(input_ids=input_ids, attention_mask=attn_mask)


class GPTLMLoss(nn.Module):

    def __init__(self):
        super().__init__()
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, logits, labels):
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        # Flatten the tokens
        return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))


class GPTLMModel(nn.Module):

    def __init__(self, hidden_size=768, num_layers=12, num_attention_heads=12, max_seq_len=1024, vocab_size=50257):
        super().__init__()
        self.enable_gc = False
        self.config = GPT2Config(
            # pre-commit: do not rearrange
            n_embd=hidden_size,
            n_layer=num_layers,
            n_head=num_attention_heads,
            n_positions=max_seq_len,
            n_ctx=max_seq_len,
            vocab_size=vocab_size,
            resid_pdrop=0.0,
            embd_pdrop=0.0,
            attn_pdrop=0.0)
        self.module = GPT2LMHeadModel(config=self.config)
        self.criterion = GPTLMLoss()

    def gradient_checkpointing_enable(self):
        self.module.gradient_checkpointing_enable()
        self.enable_gc = True

    def forward(self, input_ids, attention_mask):
        # Only return lm_logits
        output = self.module(input_ids=input_ids, attention_mask=attention_mask, use_cache=(not self.enable_gc))[0]
        loss = self.criterion(output, input_ids)
        return loss


gpt2_micro = partial(GPTLMModel, hidden_size=32, num_layers=2, num_attention_heads=4, max_seq_len=64, vocab_size=128)
gpt2_small = GPTLMModel
gpt2_base = partial(GPTLMModel, hidden_size=1024, num_layers=24, num_attention_heads=16)

TEST_MODELS.register('gpt2_micro', gpt2_micro, micro_data_fn)
TEST_MODELS.register('gpt2_small', gpt2_small, small_data_fn)
TEST_MODELS.register('gpt2_base', gpt2_base, small_data_fn)
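A minimal sketch of how the test harness consumes these registrations, mirroring the TEST_MODELS.get(...) pattern used in the test files above; it assumes transformers is installed and runs fine on CPU.

# --- illustrative sketch, not part of the committed file ---
model_fn, data_fn = TEST_MODELS.get('gpt2_micro')
model = model_fn()      # tiny GPT-2: hidden_size=32, 2 layers, 4 heads, vocab 128
batch = data_fn()       # dict(input_ids=..., attention_mask=...)
loss = model(**batch)   # GPTLMModel returns the language-modeling loss directly
loss.backward()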
34
tests/test_elixir/utils/mlp.py
Normal file
@@ -0,0 +1,34 @@
import torch
import torch.nn as nn

from tests.test_elixir.utils.registry import TEST_MODELS


def mlp_data_fn():
    return dict(x=torch.randn(4, 16))


class MlpModule(nn.Module):

    def __init__(self, hidden_dim: int = 16) -> None:
        super().__init__()
        self.proj1 = nn.Linear(hidden_dim, 4 * hidden_dim)
        self.act = nn.GELU()
        self.proj2 = nn.Linear(4 * hidden_dim, hidden_dim)

    def forward(self, x):
        return x + (self.proj2(self.act(self.proj1(x))))


class MlpModel(nn.Module):

    def __init__(self, hidden_dim: int = 16) -> None:
        super().__init__()
        self.mlp = MlpModule(hidden_dim)

    def forward(self, x):
        output = self.mlp(x)
        return output.sum()


TEST_MODELS.register('mlp', MlpModel, mlp_data_fn)
46
tests/test_elixir/utils/opt.py
Normal file
@@ -0,0 +1,46 @@
import torch.nn as nn
from transformers import OPTConfig, OPTForCausalLM

from tests.test_elixir.utils.registry import TEST_MODELS

from .gpt import micro_data_fn


class OPTLMModel(nn.Module):

    def __init__(self, config) -> None:
        super().__init__()
        self.config = config
        self.module = OPTForCausalLM(config=config)
        self.enable_gc = False

    def gradient_checkpointing_enable(self):
        self.module.gradient_checkpointing_enable()
        self.enable_gc = True

    def forward(self, input_ids, attention_mask):
        loss = self.module(
            # pre-commit: do not rearrange
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=input_ids,
            use_cache=(not self.enable_gc))['loss']
        return loss


def opt_micro():
    opt_config = OPTConfig(
        # pre-commit: do not rearrange
        vocab_size=128,
        activation_dropout=0.0,
        dropout=0,
        hidden_size=32,
        num_hidden_layers=4,
        ffn_dim=128,
        num_attention_heads=4,
        word_embed_proj_dim=32,
        output_projection=True)
    return OPTLMModel(opt_config)


TEST_MODELS.register('opt_micro', opt_micro, micro_data_fn)
26
tests/test_elixir/utils/registry.py
Normal file
@@ -0,0 +1,26 @@
from collections import OrderedDict
from typing import Callable


class Registry(object):

    def __init__(self) -> None:
        super().__init__()
        self._registry_dict = OrderedDict()

    def register(self, name: str, model_fn: Callable, data_fn: Callable):
        assert name not in self._registry_dict

        model_tuple = (model_fn, data_fn)
        self._registry_dict[name] = model_tuple

    def get(self, name: str):
        return self._registry_dict[name]

    def __iter__(self):
        return iter(self._registry_dict.items())


TEST_MODELS = Registry()

__all__ = ['TEST_MODELS']
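Usage is deliberately minimal: a name maps to a (model_fn, data_fn) pair, duplicate names are rejected by the assert, and iteration yields the registered entries in insertion order. A short sketch follows; the 'dummy' entry is illustrative only.

# --- illustrative sketch, not part of the committed file ---
import torch
import torch.nn as nn

REG = Registry()
REG.register('dummy', model_fn=lambda: nn.Linear(2, 2), data_fn=lambda: dict(x=torch.randn(1, 2)))

model_fn, data_fn = REG.get('dummy')
for name, (m_fn, d_fn) in REG:    # __iter__ yields (name, (model_fn, data_fn)) pairs
    print(name, m_fn().__class__.__name__)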
23
tests/test_elixir/utils/resnet.py
Normal file
@@ -0,0 +1,23 @@
import torch
import torch.nn as nn
from torchvision.models import resnet18

from tests.test_elixir.utils.registry import TEST_MODELS


def resnet_data_fn():
    return dict(x=torch.randn(4, 3, 32, 32))


class ResNetModel(nn.Module):

    def __init__(self) -> None:
        super().__init__()
        self.r = resnet18()

    def forward(self, x):
        output = self.r(x)
        return output.sum()


TEST_MODELS.register('resnet', ResNetModel, resnet_data_fn)
31
tests/test_elixir/utils/small.py
Normal file
@@ -0,0 +1,31 @@
import torch
import torch.nn as nn

from tests.test_elixir.utils.mlp import MlpModule
from tests.test_elixir.utils.registry import TEST_MODELS


def small_data_fn():
    return dict(x=torch.randint(low=0, high=20, size=(4, 8)))


class SmallModel(nn.Module):

    def __init__(self, num_embeddings: int = 20, hidden_dim: int = 16) -> None:
        super().__init__()
        self.embed = nn.Embedding(num_embeddings, hidden_dim)
        self.norm1 = nn.LayerNorm(hidden_dim)
        self.mlp = MlpModule(hidden_dim=hidden_dim)
        self.norm2 = nn.LayerNorm(hidden_dim)
        self.proj = nn.Linear(hidden_dim, num_embeddings, bias=False)
        self.proj.weight = self.embed.weight

    def forward(self, x):
        x = self.embed(x)
        x = x + self.norm1(self.mlp(x))
        x = self.proj(self.norm2(x))
        x = x.mean(dim=-2)
        return x.sum()


TEST_MODELS.register('small', SmallModel, small_data_fn)
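SmallModel ties the output projection to the embedding table (self.proj.weight = self.embed.weight), which gives the suite a model with a shared parameter. A small sketch verifying the tie and running one step, for illustration only:

# --- illustrative sketch, not part of the committed file ---
model = SmallModel()
assert model.proj.weight is model.embed.weight    # one Parameter object, shared storage

loss = model(small_data_fn()['x'])    # scalar output, usable as a loss
loss.backward()
assert model.embed.weight.grad is not None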