[elixir] add elixir and its unit tests (#3835)

* [elixir] add elixir

* [elixir] add unit tests

* remove useless code

* fix python 3.8 issue

* fix typo

* add test skip

* add docstrings

* add docstrings

* add readme

* fix typo
Haichen Huang
2023-05-29 09:32:37 +08:00
committed by GitHub
parent 34966378e8
commit 206280408a
86 changed files with 6627 additions and 2 deletions


@@ -0,0 +1,89 @@
import torch
import torch.distributed as dist
import colossalai
from colossalai.elixir import ElixirModule, ElixirOptimizer
from colossalai.elixir.search import minimum_waste_search
from colossalai.nn.optimizer import HybridAdam
from colossalai.testing import rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
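# Rough flow for each model in the zoo: build a chunk plan with minimum_waste_search, wrap the
# model into an fp16 ElixirModule sharded over the world group, wrap its HybridAdam into an
# ElixirOptimizer, then run one forward/backward/step to make sure the whole pipeline executes.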
def check_elixir_compatibility(early_stop: bool = True):
"""check gemini plugin over model zoo
Args:
early_stop (bool, optional): Whether to stop when getting the first error. Defaults to True.
"""
passed_models = []
failed_info = {} # (model_name, error) pair
for name, (model_fn, data_gen_fn, output_transform_fn, _) in model_zoo.items():
# These models lead to CUDA errors
if name in ('diffusers_auto_encoder_kl', 'diffusers_vq_model', 'diffusers_unet2d_model', 'timm_resmlp',
'timm_gmixer_12_224', 'timm_gmlp_b16_224', 'timm_mixer_b16_224', 'timm_convnext',
'torchaudio_wav2vec2_base', 'torchaudio_hubert_base', 'torchvision_convnext_base'):
continue
try:
print(name)
global_size = dist.get_world_size()
global_group = dist.GroupMember.WORLD
model = model_fn()
optimizer = HybridAdam(model.parameters(), lr=1e-3)
criterion = lambda x: x.mean()
data = data_gen_fn()
data = {
k: v.to('cuda') if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v
for k, v in data.items()
}
sr = minimum_waste_search(
# pre-commit: do not rearrange
m=model,
group_size=global_size,
unified_dtype=torch.float16,
prefetch=False,
verbose=True)
model = ElixirModule(model, sr, global_group, prefetch=False, dtype=torch.float16)
optimizer = ElixirOptimizer(model, optimizer, initial_scale=32)
output = model(**data)
output = output_transform_fn(output)
output_key = list(output.keys())[0]
loss = criterion(output[output_key])
optimizer.backward(loss)
optimizer.step()
passed_models.append(name)
del model, optimizer, criterion, data, output, loss
except Exception as e:
failed_info[name] = e
if early_stop:
raise e
torch.cuda.empty_cache()
if dist.get_rank() == 0:
print(f'Passed models({len(passed_models)}): {passed_models}\n\n')
print(f'Failed models({len(failed_info)}): {list(failed_info.keys())}\n\n')
assert len(failed_info) == 0, '\n'.join([f'{k}: {v}' for k, v in failed_info.items()])
def run_dist(rank, world_size, port, early_stop: bool = True):
# init dist env
colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')
check_elixir_compatibility(early_stop=early_stop)
@rerun_if_address_is_in_use()
def exam_compatibility(early_stop: bool = True):
spawn(run_dist, 2, early_stop=early_stop)
if __name__ == '__main__':
exam_compatibility(early_stop=False)


@@ -0,0 +1,72 @@
from functools import partial
import torch
import torch.distributed as dist
import torch.nn as nn
from colossalai.elixir.chunk import BlockRequire, ChunkFetcher, ChunkGroup, MemoryPool, TensorState
from colossalai.elixir.chunk.scheduler import FIFOScheduler
from colossalai.elixir.hook import BufferStore, HookParam
from colossalai.elixir.tensor import OutplaceTensor
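# to_divide rounds its first argument up to the nearest multiple of the second,
# e.g. to_divide(10, 4) == 12; here it pads parameter sizes to the process group size.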
def to_divide(a: int, b: int):
return a + (-a % b)
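# grad_handler intercepts a parameter's gradient: the gradient is copied into the parameter's
# chunk slice, the tensor is marked READY_FOR_REDUCE, the chunk reduction is triggered, and an
# empty storage-freed tensor is returned so no dense .grad is kept on the parameter.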
def grad_handler(grad: torch.Tensor, param: nn.Parameter, fetcher: ChunkFetcher):
empty_grad = torch.empty_like(grad)
empty_grad.storage().resize_(0)
with torch._C.DisableTorchFunction():
chunk = fetcher.get_one_chunk(param)
if chunk.tensors_info[param].state != TensorState.HOLD_AFTER_BWD:
raise RuntimeError('the parameter should be in HOLD_AFTER_BWD state before its gradient is reduced')
fetcher.group.tensor_trans_state(param, TensorState.READY_FOR_REDUCE)
chunk.copy_tensor_to_chunk_slice(param, grad)
fetcher.reduce_chunk(chunk)
return empty_grad
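# hook_transform prepares a model for chunk-based execution: each parameter gets a fused private
# block in a MemoryPool, the resulting chunks are managed by a ChunkGroup and a ChunkFetcher with
# a FIFO scheduler attached through HookParam, gradient hooks route grads into the chunks, and a
# forward pre-hook wraps inputs into OutplaceTensor (presumably to avoid in-place mutation).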
def hook_transform(model: nn.Module, process_group: dist.ProcessGroup):
pg_size = dist.get_world_size(process_group)
private_list = list()
for param in model.parameters():
block_size = to_divide(param.numel(), pg_size)
private_list.append(BlockRequire(block_size, param.dtype))
mp = MemoryPool('cuda')
mp.allocate(private_block_list=private_list)
cg = ChunkGroup(rcache=mp)
# allocate chunk group
fused_config = dict(rcache_fused=True)
for param in model.parameters():
cg.allocate_chunk([param], to_divide(param.numel(), pg_size), param.dtype, process_group, fused_config)
# initialize chunk fetcher
scheduler = FIFOScheduler()
fetcher = ChunkFetcher(scheduler, cg)
buffer = BufferStore(0, torch.float32)
# register fetcher and gradient handler
HookParam.attach_fetcher(fetcher, buffer)
for param in model.parameters():
param.register_hook(partial(grad_handler, param=param, fetcher=fetcher))
param.__class__ = HookParam
# set inplace to False for all modules
for module in model.modules():
if hasattr(module, 'inplace'):
module.inplace = False
def transform_input(self_module, inputs):
fetcher.reset()
input_list = list()
for t in inputs:
if isinstance(t, torch.Tensor):
t = OutplaceTensor(t)
input_list.append(t)
return tuple(input_list)
model.register_forward_pre_hook(transform_input)
return model, cg


@@ -0,0 +1,63 @@
import torch
from colossalai.elixir.chunk import BlockRequire, MemoryPool, PrivateBlock, PublicBlock
from colossalai.testing import run_on_environment_flag
@run_on_environment_flag('ELX')
def test_block():
b = PublicBlock(123, torch.float16, 'cuda')
payload_b = b.payload
assert payload_b.numel() == 123
assert payload_b.dtype == torch.float16
assert payload_b.device.type == 'cuda'
assert payload_b.numel() * payload_b.element_size() == b.memo_occ
c = PrivateBlock(77, torch.float, 'cpu')
payload_c = c.payload
assert payload_c.numel() == 77
assert payload_c.dtype == torch.float
assert payload_c.device.type == 'cpu'
assert payload_c.numel() * payload_c.element_size() == c.memo_occ
print('test_block: ok')
@run_on_environment_flag('ELX')
def test_memory_pool():
mp = MemoryPool(device_type='cuda')
private_list = [BlockRequire(5, torch.float), BlockRequire(81, torch.float16)]
mp.allocate(public_block_number=4, private_block_list=private_list)
block0 = mp.get_public_block()
assert block0 in mp.public_used_blocks
assert mp.public_used_cnt == 1
assert mp.public_free_cnt == 3
block1 = mp.get_public_block()
assert block1 in mp.public_used_blocks
assert mp.public_used_cnt == 2
assert mp.public_free_cnt == 2
mp.free_public_block(block0)
mp.free_public_block(block1)
assert block0 in mp.public_free_blocks
assert block1 in mp.public_free_blocks
assert mp.public_used_cnt == 0
assert mp.public_free_cnt == 4
block0 = mp.get_private_block(5, torch.float)
assert block0.numel == 5
assert block0.dtype == torch.float
print('test_memory_pool: ok')
if __name__ == '__main__':
test_block()
test_memory_pool()


@@ -0,0 +1,155 @@
import os
from functools import partial
import pytest
import torch
import torch.distributed as dist
from colossalai.elixir.chunk import BlockRequire, Chunk, MemoryPool, TensorState
from colossalai.elixir.utils import init_distributed
from colossalai.testing import run_on_environment_flag
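# These tests walk a Chunk through its lifecycle: append_tensor -> close_chunk -> access_chunk
# (gather a replica into an rcache block) -> release_chunk -> reduce_chunk, checking the tensor
# payloads, the state counters, and the accumulated L2 norm along the way.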
def exam_chunk_functions(nproc, group):
a = torch.randn(2, 64, device='cuda')
copy_a = a.clone()
b = torch.randn(2, 2, 128, device='cuda')
copy_b = b.clone()
c = torch.randn(128, device='cuda')
copy_c = c.clone()
d = torch.randn(4, 32, device='cuda')
copy_d = d.clone()
mp = MemoryPool('cuda')
mp.allocate(public_block_number=1)
chunk = Chunk(mp, 1024, torch.float, group)
chunk.l2_norm_flag = True
assert chunk.chunk_size == 1024
assert chunk.chunk_dtype == torch.float
assert chunk.shard_size == 1024 // nproc
def check_tensors():
assert torch.equal(a, copy_a)
assert torch.equal(b, copy_b)
assert torch.equal(c, copy_c)
assert torch.equal(d, copy_d)
chunk.append_tensor(a)
chunk.append_tensor(b)
chunk.append_tensor(c)
chunk.append_tensor(d)
check_tensors()
chunk.close_chunk()
assert chunk.is_replica is False
# check function: get_cpu_copy
cpu_copys = chunk.get_cpu_copy()
for t_gpu, t_cpu in zip([copy_a, copy_b, copy_c, copy_d], cpu_copys):
assert t_cpu.device.type == 'cpu'
assert torch.equal(t_gpu.cpu(), t_cpu)
# check function: access_chunk
block = mp.get_public_block()
chunk.access_chunk(block)
assert chunk.is_replica
assert chunk.scatter_check
check_tensors()
# check function: release_chunk
chunk.optim_sync_flag = False
block = chunk.release_chunk()
assert block in mp.public_used_blocks
assert chunk.is_replica is False
assert chunk.optim_sync_flag is True
# check function: access_chunk after release_chunk
chunk.access_chunk(block)
check_tensors()
# check function: reduce_chunk
norm = block.payload.float().norm(2)**2
chunk.reduce_chunk()
assert chunk.is_replica is False
assert chunk.tensor_state_cnter[TensorState.HOLD] == 4
test_norm = torch.Tensor([chunk.l2_norm]).cuda()
dist.all_reduce(test_norm)
assert torch.allclose(norm, test_norm)
torch.cuda.synchronize()
print('chunk functions are ok')
def exam_chunk_states(nproc, group):
a = torch.randn(2, 64, device='cuda')
copy_a = a.clone()
b = torch.randn(2, 2, 128, device='cuda')
copy_b = b.clone()
c = torch.randn(128, device='cuda')
copy_c = c.clone()
d = torch.randn(4, 32, device='cuda')
copy_d = d.clone()
private = [BlockRequire(1024, torch.float)]
mp = MemoryPool('cuda')
mp.allocate(private_block_list=private)
chunk = Chunk(mp, 1024, torch.float, group, rcache_fused=True)
assert chunk.chunk_size == 1024
assert chunk.chunk_dtype == torch.float
assert chunk.shard_size == 1024 // nproc
def check_tensors():
assert torch.equal(a, copy_a)
assert torch.equal(b, copy_b)
assert torch.equal(c, copy_c)
assert torch.equal(d, copy_d)
chunk.append_tensor(a)
chunk.append_tensor(b)
chunk.append_tensor(c)
chunk.append_tensor(d)
check_tensors()
chunk.close_chunk()
assert chunk.is_replica is False
chunk.access_chunk()
assert chunk.is_replica
check_tensors()
assert chunk.tensor_state_cnter[TensorState.HOLD] == 4
chunk.tensor_trans_state(a, TensorState.COMPUTE)
assert chunk.tensor_state_cnter[TensorState.HOLD] == 3
assert chunk.tensor_state_cnter[TensorState.COMPUTE] == 1
tensor_list = [a, b, c, d]
for t in tensor_list:
chunk.tensor_trans_state(t, TensorState.COMPUTE)
chunk.tensor_trans_state(t, TensorState.HOLD_AFTER_BWD)
chunk.tensor_trans_state(t, TensorState.READY_FOR_REDUCE)
assert chunk.tensor_state_cnter[TensorState.READY_FOR_REDUCE] == 4
assert chunk.reduce_check
torch.cuda.synchronize()
print('chunk states are ok')
def run_dist(rank, world_size):
os.environ['RANK'] = str(rank)
os.environ['LOCAL_RANK'] = str(rank)
os.environ['WORLD_SIZE'] = str(world_size)
os.environ['MASTER_ADDR'] = '127.0.0.1'
os.environ['MASTER_PORT'] = str(29512)
init_distributed()
exam_chunk_functions(nproc=world_size, group=dist.GroupMember.WORLD)
exam_chunk_states(nproc=world_size, group=dist.GroupMember.WORLD)
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 2, 4])
@run_on_environment_flag('ELX')
def test_chunk_functions(world_size):
run_func = partial(run_dist, world_size=world_size)
torch.multiprocessing.spawn(run_func, nprocs=world_size)
if __name__ == '__main__':
test_chunk_functions(world_size=4)


@@ -0,0 +1,71 @@
import copy
import os
from functools import partial
import pytest
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing import assert_close
from colossalai.elixir.chunk import ChunkGroup
from colossalai.elixir.utils import init_distributed, seed_all
from colossalai.testing import run_on_environment_flag
from tests.test_elixir.test_chunk.fetcher_utils import hook_transform
from tests.test_elixir.utils import TEST_MODELS, to_cuda
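# The hook-transformed model should produce the same loss and gradients as a plain DDP copy of
# the same ResNet; the gradients live inside fused chunks, so check_gradient re-accesses every
# fused chunk before comparing against DDP's .grad tensors.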
def check_gradient(ddp_model, my_model, cg: ChunkGroup):
for chunk in cg.fused_chunks:
cg.access_chunk(chunk)
for (name, p0), p1 in zip(ddp_model.named_parameters(), my_model.parameters()):
torch.cuda.synchronize()
print(f'checking parameter {name}')
assert_close(p0.grad.data, p1.data)
def exam_chunk_fetcher(nproc, group):
model_fn, data_fn = TEST_MODELS.get('resnet')
torch_model = model_fn().cuda()
test_model = copy.deepcopy(torch_model)
rank = dist.get_rank(group)
# get different data
seed_all(1001 + rank)
data = to_cuda(data_fn())
seed_all(1001, cuda_deterministic=True)
ddp_model = DDP(torch_model)
ddp_loss = ddp_model(**data)
ddp_loss.backward()
hook_model, cg = hook_transform(test_model, group)
my_loss = hook_model(**data)
my_loss.backward()
assert_close(ddp_loss, my_loss)
check_gradient(ddp_model, hook_model, cg)
print('private chunk fetcher is ok')
def run_dist(rank, world_size):
os.environ['RANK'] = str(rank)
os.environ['LOCAL_RANK'] = str(rank)
os.environ['WORLD_SIZE'] = str(world_size)
os.environ['MASTER_ADDR'] = '127.0.0.1'
os.environ['MASTER_PORT'] = str(29512)
init_distributed()
exam_chunk_fetcher(nproc=world_size, group=dist.GroupMember.WORLD)
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 2, 4])
@run_on_environment_flag('ELX')
def test_chunk_fetcher(world_size):
run_func = partial(run_dist, world_size=world_size)
torch.multiprocessing.spawn(run_func, nprocs=world_size)
if __name__ == '__main__':
test_chunk_fetcher(world_size=2)


@@ -0,0 +1,98 @@
import os
from functools import partial
import pytest
import torch
import torch.distributed as dist
from colossalai.elixir.chunk import BlockRequire, ChunkGroup, MemoryPool, TensorState
from colossalai.elixir.utils import init_distributed
from colossalai.testing import run_on_environment_flag
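# exam_chunk_group_functions allocates two public rcache blocks of 256 elements plus one fused
# private block, then checks that unfused chunks only fit when a public block is free
# (rcache_enough_check) while the fused chunk does not consume the public cache.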
def exam_chunk_group_functions(nproc, group):
a = torch.randn(3, 64, device='cuda')
copy_a = a.clone()
b = torch.randn(2, 32, device='cuda')
copy_b = b.clone()
c = torch.randn(256, device='cuda')
copy_c = c.clone()
d = torch.randn(2, 2, 64, device='cuda')
copy_d = d.clone()
e = torch.randn(2, 33, device='cuda')
copy_e = e.clone()
mp = MemoryPool('cuda')
mp.allocate(public_block_size=256, public_block_number=2, private_block_list=[BlockRequire(68, torch.float)])
cg = ChunkGroup(rcache=mp)
c0 = cg.allocate_chunk([a, b], 256, torch.float, group)
c1 = cg.allocate_chunk([c], 256, torch.float, group)
c2 = cg.allocate_chunk([d], 256, torch.float, group)
fused_config = dict(rcache_fused=True)
c3 = cg.allocate_chunk([e], 68, torch.float, group, fused_config)
def check_chunk_0():
assert torch.equal(a, copy_a)
assert torch.equal(b, copy_b)
def check_chunk_1():
assert torch.equal(c, copy_c)
def check_chunk_2():
assert torch.equal(d, copy_d)
def check_chunk_3():
assert torch.equal(e, copy_e)
# check tensors_to_chunks
chunks = cg.tensors_to_chunks([e, a])
assert chunks[0] == c0
assert chunks[1] == c3
# check access_chunk for unfused chunks
cg.access_chunk(c0)
cg.access_chunk(c1)
check_chunk_0()
check_chunk_1()
assert not cg.rcache_enough_check(c2)
assert cg.rcache_enough_check(c3)
# check access_chunk for fused chunks
cg.access_chunk(c3)
check_chunk_3()
# check release_chunk for unfused chunks
cg.release_chunk(c1)
assert cg.rcache_enough_check(c2)
# check access_chunk
cg.access_chunk(c2)
check_chunk_2()
cg.tensor_trans_state(e, TensorState.COMPUTE)
cg.tensor_trans_state(e, TensorState.HOLD_AFTER_BWD)
cg.tensor_trans_state(e, TensorState.READY_FOR_REDUCE)
cg.reduce_chunk(c3)
assert not c3.is_replica
torch.cuda.synchronize()
print('chunk group functions are ok')
def run_dist(rank, world_size):
os.environ['RANK'] = str(rank)
os.environ['LOCAL_RANK'] = str(rank)
os.environ['WORLD_SIZE'] = str(world_size)
os.environ['MASTER_ADDR'] = '127.0.0.1'
os.environ['MASTER_PORT'] = str(29512)
init_distributed()
exam_chunk_group_functions(nproc=world_size, group=dist.GroupMember.WORLD)
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 2, 4])
@run_on_environment_flag('ELX')
def test_chunk_group(world_size):
run_func = partial(run_dist, world_size=world_size)
torch.multiprocessing.spawn(run_func, nprocs=world_size)
if __name__ == '__main__':
test_chunk_group(world_size=2)


@@ -0,0 +1,130 @@
import os
from functools import partial
import pytest
import torch
import torch.distributed as dist
from colossalai.elixir.chunk import Chunk, MemoryPool
from colossalai.elixir.chunk.scheduler import FIFOScheduler, PrefetchScheduler
from colossalai.elixir.utils import init_distributed
from colossalai.testing import run_on_environment_flag
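# FIFOScheduler evicts chunks in the order they were added: top() returns the oldest tracked
# chunk, and adding a chunk that is already tracked is a no-op (see the comment below).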
def exam_fifo(nproc, group):
mp = MemoryPool('cuda')
mp.allocate(public_block_number=1)
c0 = Chunk(mp, 1024, torch.float, group)
c1 = Chunk(mp, 1024, torch.float, group)
c2 = Chunk(mp, 1024, torch.float, group)
sdl = FIFOScheduler()
sdl.reset()
sdl.add(c0)
sdl.add(c1)
sdl.add(c2)
sdl.add(c0) # nothing happens here
assert sdl.top() == c0
sdl.remove(c0)
assert sdl.top() == c1, f'{sdl.top()}'
sdl.remove(c0)
assert sdl.top() == c1, f'{sdl.top()}'
sdl.add(c0)
assert sdl.top() == c1
sdl.remove(c1)
assert sdl.top() == c2
sdl.remove(c2)
assert sdl.top() == c0
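# PrefetchScheduler is driven by the recorded per-step chunk trace; top() appears to return the
# resident chunk whose next use lies furthest in the future (a Belady-style eviction policy),
# which is what the step-by-step assertions below exercise.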
def exam_prefetch(nproc, group):
mp = MemoryPool('cuda')
mp.allocate()
c0 = Chunk(mp, 1024, torch.float, group)
c1 = Chunk(mp, 1024, torch.float, group)
c2 = Chunk(mp, 1024, torch.float, group)
chunk_called_per_step = [[c0], [c1], [c2], [c0], [c0], [c1], [c2], [c2], [c1], [c0]]
sdl = PrefetchScheduler(chunk_called_per_step=chunk_called_per_step)
print(sdl.next_step_dict)
sdl.reset()
sdl.step()
sdl.add(c0)
assert sdl.top() == c0
sdl.step()
sdl.add(c1)
assert sdl.top() == c1
sdl.step()
sdl.add(c2)
assert sdl.top() == c2
sdl.remove(c0)
sdl.step()
sdl.add(c0)
assert sdl.top() == c2
sdl.remove(c0)
sdl.step()
sdl.add(c0)
assert sdl.top() == c0
sdl.remove(c0) # notice here
sdl.remove(c1)
sdl.step()
sdl.add(c1)
assert sdl.top() == c1
sdl.remove(c2)
sdl.step()
sdl.add(c2)
assert sdl.top() == c1
sdl.remove(c2)
sdl.step()
sdl.add(c2)
assert sdl.top() == c2
sdl.remove(c2) # notice here
sdl.add(c0) # notice here
sdl.remove(c1)
sdl.step()
sdl.add(c1)
assert sdl.top() == c1
sdl.remove(c1) # notice here
sdl.remove(c0)
sdl.step()
sdl.add(c0)
assert sdl.top() == c0
sdl.remove(c0)
sdl.clear()
def run_dist(rank, world_size):
os.environ['RANK'] = str(rank)
os.environ['LOCAL_RANK'] = str(rank)
os.environ['WORLD_SIZE'] = str(world_size)
os.environ['MASTER_ADDR'] = '127.0.0.1'
os.environ['MASTER_PORT'] = str(29512)
init_distributed()
exam_fifo(nproc=world_size, group=dist.GroupMember.WORLD)
exam_prefetch(nproc=world_size, group=dist.GroupMember.WORLD)
@pytest.mark.dist
@run_on_environment_flag('ELX')
def test_chunk_scheduler(world_size=1):
run_func = partial(run_dist, world_size=world_size)
torch.multiprocessing.spawn(run_func, nprocs=world_size)
if __name__ == '__main__':
test_chunk_scheduler()


@@ -0,0 +1,22 @@
from colossalai.elixir.ctx import MetaContext
from colossalai.testing import run_on_environment_flag
from tests.test_elixir.utils import TEST_MODELS
@run_on_environment_flag('ELX')
def test_meta_context():
builder, *_ = TEST_MODELS.get('resnet')
with MetaContext():
model = builder()
for name, param in model.named_parameters():
assert param.device.type == 'meta'
print(name, param)
for name, buffer in model.named_buffers():
assert buffer.device.type == 'meta'
print(name, buffer)
if __name__ == '__main__':
test_meta_context()


@@ -0,0 +1,56 @@
from copy import deepcopy
import torch
import torch.nn as nn
from colossalai.elixir.hook import BufferStore, HookParam
from colossalai.elixir.tensor import FakeTensor
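# Swapping a parameter's payload for a FakeTensor and its class for HookParam should keep all
# shape metadata (numel, size, stride, storage offset) observable without real storage behind it.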
def test_hook():
x = nn.Parameter(torch.randn(4, 4))
ori_numel = x.numel()
ori_size = x.size()
ori_stride = x.stride()
ori_offset = x.storage_offset()
fake_data = FakeTensor(x.data)
x.data = fake_data
x.__class__ = HookParam
assert x.numel() == ori_numel
assert x.size() == ori_size
assert x.stride() == ori_stride
assert x.storage_offset() == ori_offset
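# BufferStore.insert copies a tensor's payload into the shared fp16 buffer at the given offset
# and returns the next free offset, apparently pointing the tensor at that slice; erase is
# expected to hand the data back so the original data_ptr is restored.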
def test_store():
buffer = BufferStore(1024, torch.float16)
print(buffer)
x = torch.randn(4, 128, dtype=torch.float16, device='cuda')
original_ptr_x = x.data_ptr()
copy_x = deepcopy(x)
y = torch.randn(512, dtype=torch.float16, device='cuda')
original_ptr_y = y.data_ptr()
copy_y = deepcopy(y)
offset = 0
offset = buffer.insert(x, offset)
assert offset == x.numel()
assert torch.equal(x, copy_x)
offset = buffer.insert(y, offset)
assert offset == 1024
assert torch.equal(y, copy_y)
buffer.erase(x)
buffer.erase(y)
assert x.data_ptr() == original_ptr_x
assert y.data_ptr() == original_ptr_y
if __name__ == '__main__':
test_store()


@@ -0,0 +1,35 @@
from copy import deepcopy
import pytest
from torch.testing import assert_close
from tests.test_elixir.utils import TEST_MODELS, to_cuda
def exam_one_model(model_fn, data_fn):
from colossalai.elixir.kernels.attn_wrapper import wrap_attention
torch_model = model_fn().cuda()
test_model = deepcopy(torch_model)
test_model = wrap_attention(test_model)
data = to_cuda(data_fn())
torch_out = torch_model(**data)
torch_out.backward()
test_out = test_model(**data)
test_out.backward()
assert_close(torch_out, test_out)
for (name, p_torch), p_test in zip(torch_model.named_parameters(), test_model.parameters()):
assert_close(p_torch.grad, p_test.grad)
@pytest.mark.skip(reason="Need to install xformers")
def test_gpt_atten_kernel():
exam_one_model(*TEST_MODELS.get('gpt2_micro'))
exam_one_model(*TEST_MODELS.get('opt_micro'))
if __name__ == '__main__':
test_gpt_atten_kernel()


@@ -0,0 +1,58 @@
import os
from copy import deepcopy
from functools import partial
import pytest
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.testing import assert_close
from colossalai.elixir.search import simple_search
from colossalai.elixir.utils import init_distributed
from colossalai.elixir.wrapper import ElixirModule
def exam_fused_layernorm(nproc, group):
torch_model = nn.LayerNorm(2048)
fused_model = deepcopy(torch_model)
torch_model = torch_model.cuda()
sr = simple_search(fused_model, nproc, 1, 1.0, verbose=True)
fused_model = ElixirModule(fused_model, sr, group, use_fused_kernels=True)
data = torch.randn(2, 2048, device='cuda')
torch_loss = torch_model(data).sum()
torch_loss.backward()
fused_loss = fused_model(data).sum()
fused_model.backward(fused_loss)
assert_close(torch_loss, fused_loss)
grad_state = fused_model.state_dict(from_param=True)
for name, param in torch_model.named_parameters():
assert_close(param.grad.cpu(), grad_state[name])
def run_dist(rank, world_size):
os.environ['RANK'] = str(rank)
os.environ['LOCAL_RANK'] = str(rank)
os.environ['WORLD_SIZE'] = str(world_size)
os.environ['MASTER_ADDR'] = '127.0.0.1'
os.environ['MASTER_PORT'] = str(29512)
init_distributed()
exam_fused_layernorm(nproc=world_size, group=dist.GroupMember.WORLD)
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1])
@pytest.mark.skip(reason='need to install apex')
def test_fused_layernorm(world_size):
run_func = partial(run_dist, world_size=world_size)
torch.multiprocessing.spawn(run_func, nprocs=world_size)
if __name__ == '__main__':
test_fused_layernorm(world_size=1)


@@ -0,0 +1,38 @@
from copy import deepcopy
import torch
from colossalai.elixir.cuda import gpu_device
from colossalai.elixir.search import minimum_waste_search
from colossalai.testing import run_on_environment_flag
from tests.test_elixir.utils import TEST_MODELS
def step_fn(model, inp):
model(**inp).backward()
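# With cpu_offload=True and unified_dtype=float16, every produced chunk plan is expected to be
# fp16 with its shard placed on pinned CPU memory, which is exactly what the assertions verify.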
@run_on_environment_flag('ELX')
def test_mini_waste_search():
model_fn, data_fn = TEST_MODELS.get('gpt2_small')
model = model_fn()
data = data_fn()
sr = minimum_waste_search(model,
1,
unified_dtype=torch.float16,
cpu_offload=True,
prefetch=True,
verbose=True,
inp=data,
step_fn=step_fn)
chunk_plans = deepcopy(sr.param_chunk_plans)
for plan in chunk_plans:
assert plan.chunk_dtype == torch.float16
assert plan.kwargs.get('shard_device') == torch.device('cpu')
assert plan.kwargs.get('cpu_pin_memory') == True
if __name__ == '__main__':
test_mini_waste_search()


@@ -0,0 +1,30 @@
from copy import deepcopy
import torch
from colossalai.elixir.cuda import gpu_device
from colossalai.elixir.search import optimal_search
from colossalai.testing import run_on_environment_flag
from tests.test_elixir.utils import TEST_MODELS
def step_fn(model, inp):
model(**inp).backward()
@run_on_environment_flag('ELX')
def test_optimal_search():
model_fn, data_fn = TEST_MODELS.get('gpt2_small')
model = model_fn()
data = data_fn()
sr = optimal_search(model, 1, unified_dtype=torch.float16, overlap=True, verbose=True, inp=data, step_fn=step_fn)
chunk_plans = deepcopy(sr.param_chunk_plans)
for plan in chunk_plans:
assert plan.chunk_dtype == torch.float16
assert plan.kwargs.get('shard_device') == gpu_device()
if __name__ == '__main__':
test_optimal_search()


@@ -0,0 +1,48 @@
from copy import deepcopy
import torch
from colossalai.elixir.cuda import gpu_device
from colossalai.elixir.search import simple_search
from colossalai.testing import run_on_environment_flag
from tests.test_elixir.utils import TEST_MODELS
def step_fn(model, inp):
model(**inp).backward()
@run_on_environment_flag('ELX')
def test_simple_search():
model_fn, data_fn = TEST_MODELS.get('small')
model = model_fn()
data = data_fn()
sr = simple_search(model,
1,
split_number=5,
shard_device=gpu_device(),
prefetch=True,
verbose=True,
inp=data,
step_fn=step_fn)
chunk_plans = deepcopy(sr.param_chunk_plans)
private_plan = chunk_plans.pop(0)
assert private_plan.name_list == ['embed.weight']
assert private_plan.chunk_size == 320
assert private_plan.kwargs.get('shard_device') == gpu_device()
assert chunk_plans[0].name_list == ['norm1.weight', 'norm1.bias']
assert chunk_plans[1].name_list == ['mlp.proj1.weight', 'mlp.proj1.bias']
assert chunk_plans[2].name_list == ['mlp.proj2.weight', 'mlp.proj2.bias']
assert chunk_plans[3].name_list == ['norm2.weight']
assert chunk_plans[4].name_list == ['norm2.bias']
for plan in chunk_plans:
assert plan.chunk_size == 1088
assert plan.kwargs.get('shard_device') == gpu_device()
if __name__ == '__main__':
test_simple_search()


@@ -0,0 +1,13 @@
from colossalai.elixir.simulator import move_count
from colossalai.testing import run_on_environment_flag
@run_on_environment_flag('ELX')
def test_move_count():
steps = [[0], [1, 2], [3], [3], [1, 2], [0]]
size = 2
assert move_count(steps, size) == 12
if __name__ == '__main__':
test_move_count()


@@ -0,0 +1,23 @@
import pytest
import torch
from colossalai.testing import run_on_environment_flag
from tests.test_elixir.utils import to_cuda
@run_on_environment_flag('ELX')
def test_registry():
from tests.test_elixir.utils.registry import TEST_MODELS
for name, model_tuple in TEST_MODELS:
torch.cuda.synchronize()
print(f'model `{name}` is in testing')
model_fn, data_fn = model_tuple
model = model_fn().cuda()
data = to_cuda(data_fn())
loss = model(**data)
loss.backward()
if __name__ == '__main__':
test_registry()


@@ -0,0 +1,46 @@
import pytest
import torch
from colossalai.elixir.tracer.memory_tracer import cuda_memory_profiling
from colossalai.testing import run_on_environment_flag
from tests.test_elixir.utils import TEST_MODELS, to_cuda
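# cuda_memory_profiling should report the activation memory of one training step; the test first
# measures the same quantity directly from torch.cuda counters, then checks that profiling gives
# an identical number and leaves the allocated memory unchanged.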
def one_step(model, inp):
loss = model(**inp)
loss.backward()
return loss
def try_one_model(model_fn, data_fn):
model = model_fn().cuda()
data = to_cuda(data_fn())
one_step(model, data) # generate gradients
pre_cuda_alc = torch.cuda.memory_allocated()
torch.cuda.reset_peak_memory_stats()
one_step(model, data)
aft_cuda_alc = torch.cuda.max_memory_allocated()
torch_activation_occ = aft_cuda_alc - pre_cuda_alc
model.zero_grad(set_to_none=True)
print('normal', torch_activation_occ)
before = torch.cuda.memory_allocated()
profiling_dict = cuda_memory_profiling(model, data, one_step)
after = torch.cuda.memory_allocated()
print('profiling', profiling_dict)
assert before == after
assert torch_activation_occ == profiling_dict['activation_occ']
print('Check is ok.')
@run_on_environment_flag('ELX')
def test_cuda_profiler():
model_list = ['resnet', 'gpt2_micro']
for name in model_list:
model_fn, data_fn = TEST_MODELS.get(name)
try_one_model(model_fn, data_fn)
if __name__ == '__main__':
test_cuda_profiler()


@@ -0,0 +1,145 @@
import pytest
import torch
from colossalai.elixir.tracer.memory_tracer import MTensor
from colossalai.elixir.tracer.memory_tracer.op_cache import addmm_cache, bmm_cache, mm_cache
from colossalai.elixir.tracer.utils import get_cuda_allocated, get_cuda_max_allocated
from colossalai.testing import run_on_environment_flag
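# Each op_* helper runs a single matmul-family op. The tests first measure the real CUDA peak
# memory with plain tensors, then repeat the op with MTensor wrappers and expect the traced peak
# (MTensor.current_peak_memory) to match, with the second MTensor run served from the op cache.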
def op_mm(x, y):
u = torch.matmul(x, y)
return u.shape
def op_addmm(x, y, z):
u = torch.addmm(x, y, z)
return u.shape
def op_bmm(x, y):
u = torch.bmm(x, y)
return u.shape
@pytest.mark.parametrize('dtype', [torch.float, torch.float16, torch.bfloat16])
@run_on_environment_flag('ELX')
def test_mm(dtype, size0=(4, 256), size1=(256, 1024)):
torch.cuda.reset_peak_memory_stats()
assert get_cuda_allocated() == 0
x = torch.randn(size0, dtype=dtype, device='cuda')
y = torch.randn(size1, dtype=dtype, device='cuda')
torch_pre_alc = get_cuda_allocated()
torch_z_size = op_mm(x, y)
torch_temp_alc = get_cuda_max_allocated() - torch_pre_alc
del x
del y
assert get_cuda_allocated() == 0
x = MTensor(torch.randn(size0, dtype=dtype, device='cuda'))
y = MTensor(torch.randn(size1, dtype=dtype, device='cuda'))
op1_pre_alc = get_cuda_allocated()
MTensor.reset_peak_memory()
op1_z_size = op_mm(x, y)
op1_temp_alc = MTensor.current_peak_memory() - op1_pre_alc
assert torch_z_size == op1_z_size
assert torch_pre_alc == op1_pre_alc
assert torch_temp_alc == op1_temp_alc
assert len(mm_cache.temp_memory) > 0
MTensor.reset_peak_memory()
op2_z_size = op_mm(x, y)
op2_temp_alc = MTensor.current_peak_memory() - op1_pre_alc
assert torch_z_size == op2_z_size
assert torch_temp_alc == op2_temp_alc
@pytest.mark.parametrize('dtype', [torch.float, torch.float16, torch.bfloat16])
@run_on_environment_flag('ELX')
def test_addmm(dtype, size0=(4, 16), size1=(16, 64)):
torch.cuda.reset_peak_memory_stats()
assert get_cuda_allocated() == 0
x = torch.randn(size0, dtype=dtype, device='cuda')
y = torch.randn(size1, dtype=dtype, device='cuda')
u = torch.randn(size1[-1], dtype=dtype, device='cuda')
torch_pre_alc = get_cuda_allocated()
torch_z_size = op_addmm(u, x, y)
torch_temp_alc = get_cuda_max_allocated() - torch_pre_alc
del x
del y
del u
assert get_cuda_allocated() == 0
x = MTensor(torch.randn(size0, dtype=dtype, device='cuda'))
y = MTensor(torch.randn(size1, dtype=dtype, device='cuda'))
u = MTensor(torch.randn(size1[-1], dtype=dtype, device='cuda'))
op1_pre_alc = get_cuda_allocated()
MTensor.reset_peak_memory()
op1_z_size = op_addmm(u, x, y)
op1_temp_alc = MTensor.current_peak_memory() - op1_pre_alc
assert torch_z_size == op1_z_size
assert torch_pre_alc == op1_pre_alc
assert torch_temp_alc == op1_temp_alc
assert len(addmm_cache.temp_memory) > 0
MTensor.reset_peak_memory()
op2_z_size = op_addmm(u, x, y)
op2_temp_alc = MTensor.current_peak_memory() - op1_pre_alc
assert torch_z_size == op2_z_size
assert torch_temp_alc == op2_temp_alc
@pytest.mark.parametrize('dtype', [torch.float, torch.float16, torch.bfloat16])
@run_on_environment_flag('ELX')
def test_bmm(dtype, size0=(10, 4, 15), size1=(10, 15, 64)):
torch.cuda.reset_peak_memory_stats()
assert get_cuda_allocated() == 0
x = torch.randn(size0, dtype=dtype, device='cuda')
y = torch.randn(size1, dtype=dtype, device='cuda')
torch_pre_alc = get_cuda_allocated()
torch_z_size = op_bmm(x, y)
torch_temp_alc = get_cuda_max_allocated() - torch_pre_alc
del x
del y
assert get_cuda_allocated() == 0
x = MTensor(torch.randn(size0, dtype=dtype, device='cuda'))
y = MTensor(torch.randn(size1, dtype=dtype, device='cuda'))
op1_pre_alc = get_cuda_allocated()
MTensor.reset_peak_memory()
op1_z_size = op_bmm(x, y)
op1_temp_alc = MTensor.current_peak_memory() - op1_pre_alc
assert torch_z_size == op1_z_size
assert torch_pre_alc == op1_pre_alc
assert torch_temp_alc == op1_temp_alc
assert len(bmm_cache.temp_memory) > 0
bmm_cache.print()
MTensor.reset_peak_memory()
op2_z_size = op_bmm(x, y)
op2_temp_alc = MTensor.current_peak_memory() - op1_pre_alc
assert torch_z_size == op2_z_size
assert torch_temp_alc == op2_temp_alc
if __name__ == '__main__':
test_addmm(dtype=torch.float)


@@ -0,0 +1,37 @@
from colossalai.elixir.tracer.param_tracer import generate_tf_order
from colossalai.testing import run_on_environment_flag
from tests.test_elixir.utils import TEST_MODELS
@run_on_environment_flag('ELX')
def test_tf_forward_backward():
model_fn, data_fn = TEST_MODELS.get('gpt2_micro')
model = model_fn()
data = data_fn()
def forward_backward_fn(local_model, local_input):
local_model(**local_input).backward()
# model.gradient_checkpointing_enable()
tf_order = generate_tf_order(model, data, forward_backward_fn)
params_per_step = tf_order['params_per_step']
assert len(params_per_step) == 32
model.gradient_checkpointing_enable()
tf_order = generate_tf_order(model, data, forward_backward_fn)
params_per_step = tf_order['params_per_step']
checkpoint_info = tf_order['checkpoint_info']
for i, step in enumerate(params_per_step):
print(f'step {i}: {step}')
for c in checkpoint_info:
print(f'checkpoint info: {c}')
assert len(params_per_step) == 44
assert data['input_ids'].device.type == 'cpu'
assert data['attention_mask'].device.type == 'cpu'
for param in model.parameters():
assert param.device.type == 'cpu'
if __name__ == '__main__':
test_tf_forward_backward()


@@ -0,0 +1,94 @@
import copy
import os
from functools import partial
import pytest
import torch
import torch.distributed as dist
from apex import amp
from apex.parallel import DistributedDataParallel as DDP
from torch.testing import assert_close
from colossalai.elixir.cuda import gpu_device
from colossalai.elixir.search import simple_search
from colossalai.elixir.utils import init_distributed, seed_all
from colossalai.elixir.wrapper import ElixirModule, ElixirOptimizer
from colossalai.nn.optimizer import HybridAdam
from colossalai.testing import run_on_environment_flag
from tests.test_elixir.utils import TEST_MODELS, to_cuda
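# This test compares ElixirModule/ElixirOptimizer in fp16 against an apex O2 + DDP reference:
# both sides consume the same data for two steps, and the ElixirModule state dict is checked
# against the fp32 master parameters exposed by amp.master_params.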
def amp_check_model_states(ddp_optim, test_model):
test_states = test_model.state_dict()
for (name, _), p in zip(test_model.module.named_parameters(), amp.master_params(ddp_optim)):
test_p = test_states[name]
copy_p = p.to(test_p.device)
print(f'checking parameter `{name}`: {test_p.dtype} {copy_p.dtype}')
assert_close(test_p.data, copy_p.data)
def exam_amp_one_model(model_fn, data_fn, nproc, group, exam_seed=2261):
ddp_model = model_fn().cuda()
test_model = copy.deepcopy(ddp_model)
# important here, since apex has a lazy fp32 init after the first optimizer step
test_model = test_model.half()
ddp_optim = HybridAdam(ddp_model.parameters(), lr=1e-1, weight_decay=0)
ddp_model, ddp_optim = amp.initialize(ddp_model,
ddp_optim,
opt_level='O2',
loss_scale=1.0,
keep_batchnorm_fp32=False)
ddp_model = DDP(ddp_model, message_size=0, allreduce_always_fp32=True)
print("ok")
exit(0)
test_optim = HybridAdam(test_model.parameters(), lr=1e-1, weight_decay=0)
sr = simple_search(test_model, nproc, shard_device=gpu_device(), unified_dtype=torch.float16, verbose=True)
test_model = ElixirModule(test_model, sr, group, dtype=torch.float16, reduce_always_fp32=True, output_fp32=True)
test_optim = ElixirOptimizer(test_model, test_optim, initial_scale=1.0)
# get different data
seed_all(exam_seed + dist.get_rank(group), cuda_deterministic=True)
for _ in range(2):
data = to_cuda(data_fn())
ddp_optim.zero_grad()
ddp_loss = ddp_model(**data)
with amp.scale_loss(ddp_loss, ddp_optim) as scaled_loss:
scaled_loss.backward()
ddp_optim.step()
test_optim.zero_grad()
test_loss = test_model(**data)
test_optim.backward(test_loss)
test_optim.step()
assert_close(ddp_loss, test_loss)
amp_check_model_states(ddp_optim, test_model)
def exam_amp_in_models(nproc, group):
model_fn, data_fn = TEST_MODELS.get('gpt2_micro')
exam_amp_one_model(model_fn, data_fn, nproc, group)
def run_dist(rank, world_size):
os.environ['RANK'] = str(rank)
os.environ['LOCAL_RANK'] = str(rank)
os.environ['WORLD_SIZE'] = str(world_size)
os.environ['MASTER_ADDR'] = '127.0.0.1'
os.environ['MASTER_PORT'] = str(29512)
init_distributed()
exam_amp_in_models(nproc=world_size, group=dist.GroupMember.WORLD)
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 2, 4])
@run_on_environment_flag('ELX')
def test_elixir_amp(world_size):
run_func = partial(run_dist, world_size=world_size)
torch.multiprocessing.spawn(run_func, nprocs=world_size)
if __name__ == '__main__':
test_elixir_amp(world_size=2)


@@ -0,0 +1,95 @@
import copy
import os
from functools import partial
import pytest
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing import assert_close
from colossalai.elixir.search import simple_search
from colossalai.elixir.utils import init_distributed, seed_all
from colossalai.elixir.wrapper import ElixirModule
from colossalai.testing import run_on_environment_flag
from tests.test_elixir.utils import TEST_MODELS, assert_dict_values, to_cuda
def check_gradient(ddp_model: nn.Module, test_model: ElixirModule):
grad_state = test_model.state_dict(from_param=True)
for name, param in ddp_model.named_parameters():
assert_close(param.grad.cpu(), grad_state[name])
def exam_module_init(nproc, group, grad_flag):
model_fn, data_fn = TEST_MODELS.get('resnet')
torch_model = model_fn().cuda()
test_model = model_fn().cuda()
for p1, p2 in zip(torch_model.parameters(), test_model.parameters()):
p1.requires_grad = p2.requires_grad = grad_flag
sr = simple_search(test_model, nproc)
model = ElixirModule(test_model, sr, group)
# check function: ElixirModule.load_state_dict after ElixirModule.__init__
torch_st = torch_model.state_dict()
if dist.get_rank() != 0:
torch_st = None
test_st = model.load_state_dict(torch_st, only_rank_0=True)
# check function: ElixirModule.state_dict after ElixirModule.__init__
torch_st = torch_model.state_dict()
test_st = model.state_dict()
assert_dict_values(torch_st, test_st, fn=torch.equal)
def exam_one_module_fwd_bwd(model_fn, data_fn, nproc, group, exam_seed=2261):
ddp_model = model_fn().cuda()
test_model = copy.deepcopy(ddp_model)
sr = simple_search(test_model, nproc, allocate_factor=0.6)
test_model = ElixirModule(test_model, sr, group)
# get different data
seed_all(exam_seed + dist.get_rank(group))
data = data_fn()
data = to_cuda(data)
seed_all(exam_seed, cuda_deterministic=True)
ddp_model = DDP(ddp_model)
ddp_loss = ddp_model(**data)
ddp_loss.backward()
test_loss = test_model(**data)
test_model.backward(test_loss)
assert_close(ddp_loss, test_loss)
check_gradient(ddp_model.module, test_model)
def exam_modules_fwd_bwd(nproc, group):
model_fn, data_fn = TEST_MODELS.get('resnet')
exam_one_module_fwd_bwd(model_fn, data_fn, nproc, group)
def run_dist(rank, world_size):
os.environ['RANK'] = str(rank)
os.environ['LOCAL_RANK'] = str(rank)
os.environ['WORLD_SIZE'] = str(world_size)
os.environ['MASTER_ADDR'] = '127.0.0.1'
os.environ['MASTER_PORT'] = str(29512)
init_distributed()
exam_module_init(nproc=world_size, group=dist.GroupMember.WORLD, grad_flag=False)
exam_module_init(nproc=world_size, group=dist.GroupMember.WORLD, grad_flag=True)
exam_modules_fwd_bwd(nproc=world_size, group=dist.GroupMember.WORLD)
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 2, 4])
@run_on_environment_flag('ELX')
def test_elixir_module(world_size):
run_func = partial(run_dist, world_size=world_size)
torch.multiprocessing.spawn(run_func, nprocs=world_size)
if __name__ == '__main__':
test_elixir_module(world_size=2)


@@ -0,0 +1,77 @@
import copy
import os
from functools import partial
import pytest
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing import assert_close
from colossalai.elixir.cuda import gpu_device
from colossalai.elixir.search import simple_search
from colossalai.elixir.utils import init_distributed, seed_all
from colossalai.elixir.wrapper import ElixirModule, ElixirOptimizer
from colossalai.nn.optimizer import HybridAdam
from colossalai.testing import run_on_environment_flag
from tests.test_elixir.utils import TEST_MODELS, allclose, assert_dict_values, to_cuda
def exam_optimizer_one_model(model_fn, data_fn, nproc, group, exam_seed=2261):
ddp_model = model_fn().cuda()
test_model = copy.deepcopy(ddp_model)
ddp_model = DDP(ddp_model)
ddp_optim = HybridAdam(ddp_model.parameters(), lr=1e-1, weight_decay=0)
test_optim = HybridAdam(test_model.parameters(), lr=1e-1, weight_decay=0)
sr = simple_search(test_model, nproc, shard_device=gpu_device())
test_model = ElixirModule(test_model, sr, group)
test_optim = ElixirOptimizer(test_model, test_optim)
# get different data
seed_all(exam_seed + dist.get_rank(group))
data = to_cuda(data_fn())
seed_all(exam_seed, cuda_deterministic=True)
ddp_optim.zero_grad()
ddp_loss = ddp_model(**data)
ddp_loss.backward()
ddp_optim.step()
test_optim.zero_grad()
test_loss = test_model(**data)
test_optim.backward(test_loss)
test_optim.step()
assert_close(ddp_loss, test_loss)
torch_st = ddp_model.module.state_dict()
test_st = test_model.state_dict()
assert_dict_values(torch_st, test_st, fn=partial(allclose, rtol=2e-6, atol=2e-5))
def exam_optimizer_in_models(nproc, group):
model_fn, data_fn = TEST_MODELS.get('resnet')
exam_optimizer_one_model(model_fn, data_fn, nproc, group)
def run_dist(rank, world_size):
os.environ['RANK'] = str(rank)
os.environ['LOCAL_RANK'] = str(rank)
os.environ['WORLD_SIZE'] = str(world_size)
os.environ['MASTER_ADDR'] = '127.0.0.1'
os.environ['MASTER_PORT'] = str(29512)
init_distributed()
exam_optimizer_in_models(nproc=world_size, group=dist.GroupMember.WORLD)
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 2, 4])
@run_on_environment_flag('ELX')
def test_elixir_optimizer(world_size):
run_func = partial(run_dist, world_size=world_size)
torch.multiprocessing.spawn(run_func, nprocs=world_size)
if __name__ == '__main__':
test_elixir_optimizer(world_size=4)


@@ -0,0 +1,89 @@
import copy
import os
from functools import partial
import pytest
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing import assert_close
from colossalai.elixir.cuda import gpu_device
from colossalai.elixir.search import simple_search
from colossalai.elixir.utils import init_distributed, seed_all
from colossalai.elixir.wrapper import ElixirModule
from colossalai.testing import run_on_environment_flag
from tests.test_elixir.utils import TEST_MODELS, to_cuda
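# Here simple_search runs with prefetch=True and is given an example input plus a step function,
# so the parameter access order can be traced ahead of time; the wrapped module is then checked
# against DDP both in a no_grad forward and in a full forward/backward pass.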
def check_gradient(ddp_model: nn.Module, test_model: ElixirModule):
grad_state = test_model.state_dict(from_param=True)
for name, param in ddp_model.named_parameters():
assert_close(param.grad.cpu(), grad_state[name])
def exam_one_module_fwd_bwd(model_fn, data_fn, nproc, group, exam_seed=2263):
def one_step(local_model, local_input):
loss = local_model(**local_input)
loss.backward()
return loss
ddp_model = model_fn().cuda()
test_model = copy.deepcopy(ddp_model)
# get different data
seed_all(exam_seed + dist.get_rank(group))
data = to_cuda(data_fn())
# wrap as DDP model
ddp_model = DDP(ddp_model)
# search how to initialize chunks
sr = simple_search(test_model,
nproc,
shard_device=gpu_device(),
prefetch=True,
verbose=True,
inp=data,
step_fn=one_step)
test_model = ElixirModule(test_model, sr, group, prefetch=True)
seed_all(exam_seed, cuda_deterministic=True)
ddp_loss = one_step(ddp_model, data)
with torch.no_grad():
test_loss = test_model(**data)
assert_close(ddp_loss, test_loss)
test_loss = test_model(**data)
test_model.backward(test_loss)
assert_close(ddp_loss, test_loss)
check_gradient(ddp_model.module, test_model)
def exam_modules_fwd_bwd(nproc, group):
model_fn, data_fn = TEST_MODELS.get('resnet')
exam_one_module_fwd_bwd(model_fn, data_fn, nproc, group)
def run_dist(rank, world_size):
os.environ['RANK'] = str(rank)
os.environ['LOCAL_RANK'] = str(rank)
os.environ['WORLD_SIZE'] = str(world_size)
os.environ['MASTER_ADDR'] = '127.0.0.1'
os.environ['MASTER_PORT'] = str(29512)
init_distributed()
exam_modules_fwd_bwd(nproc=world_size, group=dist.GroupMember.WORLD)
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 2, 4])
@run_on_environment_flag('ELX')
def test_module_prefetch(world_size):
run_func = partial(run_dist, world_size=world_size)
torch.multiprocessing.spawn(run_func, nprocs=world_size)
if __name__ == '__main__':
test_module_prefetch(world_size=2)


@@ -0,0 +1,41 @@
import torch
from torch.testing import assert_close
from torch.utils._pytree import tree_map
from . import gpt, mlp, opt, resnet, small
from .registry import TEST_MODELS
def to_cuda(input_dict):
def local_fn(t):
if isinstance(t, torch.Tensor):
t = t.cuda()
return t
ret = tree_map(local_fn, input_dict)
return ret
def allclose(ta, tb, **kwargs):
assert_close(ta, tb, **kwargs)
return True
def assert_dict_keys(test_dict, keys):
assert len(test_dict) == len(keys)
for k in keys:
assert k in test_dict
def assert_dict_values(da, db, fn):
assert len(da) == len(db)
for k, v in da.items():
assert k in db
if not torch.is_tensor(v):
continue
u = db.get(k)
if u.device != v.device:
v = v.to(u.device)
# print(f"checking key {k}: {u.shape} vs {v.shape}")
assert fn(u.data, v.data), f'max diff {torch.max(torch.abs(u.data - v.data))}'


@@ -0,0 +1,79 @@
from functools import partial
import torch
import torch.nn as nn
from transformers import GPT2Config, GPT2LMHeadModel
from tests.test_elixir.utils.registry import TEST_MODELS
MICRO_VS = 128
MICRO_BS = 4
MICRO_SL = 64
MACRO_VS = 50257
MACRO_BS = 2
MACRO_SL = 1024
def micro_data_fn():
input_ids = torch.randint(low=0, high=MICRO_VS, size=(MICRO_BS, MICRO_SL))
attn_mask = torch.ones_like(input_ids)
return dict(input_ids=input_ids, attention_mask=attn_mask)
def small_data_fn():
input_ids = torch.randint(low=0, high=MACRO_VS, size=(MACRO_BS, MACRO_SL))
attn_mask = torch.ones_like(input_ids)
return dict(input_ids=input_ids, attention_mask=attn_mask)
class GPTLMLoss(nn.Module):
def __init__(self):
super().__init__()
self.loss_fn = nn.CrossEntropyLoss()
def forward(self, logits, labels):
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
class GPTLMModel(nn.Module):
def __init__(self, hidden_size=768, num_layers=12, num_attention_heads=12, max_seq_len=1024, vocab_size=50257):
super().__init__()
self.enable_gc = False
self.config = GPT2Config(
# pre-commit: do not rearrange
n_embd=hidden_size,
n_layer=num_layers,
n_head=num_attention_heads,
n_positions=max_seq_len,
n_ctx=max_seq_len,
vocab_size=vocab_size,
resid_pdrop=0.0,
embd_pdrop=0.0,
attn_pdrop=0.0)
self.module = GPT2LMHeadModel(config=self.config)
self.criterion = GPTLMLoss()
def gradient_checkpointing_enable(self):
self.module.gradient_checkpointing_enable()
self.enable_gc = True
def forward(self, input_ids, attention_mask):
# Only return lm_logits
output = self.module(input_ids=input_ids, attention_mask=attention_mask, use_cache=(not self.enable_gc))[0]
loss = self.criterion(output, input_ids)
return loss
gpt2_micro = partial(GPTLMModel, hidden_size=32, num_layers=2, num_attention_heads=4, max_seq_len=64, vocab_size=128)
gpt2_small = GPTLMModel
gpt2_base = partial(GPTLMModel, hidden_size=1024, num_layers=24, num_attention_heads=16)
TEST_MODELS.register('gpt2_micro', gpt2_micro, micro_data_fn)
TEST_MODELS.register('gpt2_small', gpt2_small, small_data_fn)
TEST_MODELS.register('gpt2_base', gpt2_base, small_data_fn)


@@ -0,0 +1,34 @@
import torch
import torch.nn as nn
from tests.test_elixir.utils.registry import TEST_MODELS
def mlp_data_fn():
return dict(x=torch.randn(4, 16))
class MlpModule(nn.Module):
def __init__(self, hidden_dim: int = 16) -> None:
super().__init__()
self.proj1 = nn.Linear(hidden_dim, 4 * hidden_dim)
self.act = nn.GELU()
self.proj2 = nn.Linear(4 * hidden_dim, hidden_dim)
def forward(self, x):
return x + (self.proj2(self.act(self.proj1(x))))
class MlpModel(nn.Module):
def __init__(self, hidden_dim: int = 16) -> None:
super().__init__()
self.mlp = MlpModule(hidden_dim)
def forward(self, x):
output = self.mlp(x)
return output.sum()
TEST_MODELS.register('mlp', MlpModel, mlp_data_fn)


@@ -0,0 +1,46 @@
import torch.nn as nn
from transformers import OPTConfig, OPTForCausalLM
from tests.test_elixir.utils.registry import TEST_MODELS
from .gpt import micro_data_fn
class OPTLMModel(nn.Module):
def __init__(self, config) -> None:
super().__init__()
self.config = config
self.module = OPTForCausalLM(config=config)
self.enable_gc = False
def gradient_checkpointing_enable(self):
self.module.gradient_checkpointing_enable()
self.enable_gc = True
def forward(self, input_ids, attention_mask):
loss = self.module(
# pre-commit: do not rearrange
input_ids=input_ids,
attention_mask=attention_mask,
labels=input_ids,
use_cache=(not self.enable_gc))['loss']
return loss
def opt_micro():
opt_config = OPTConfig(
# pre-commit: do not rearrange
vocab_size=128,
activation_dropout=0.0,
dropout=0,
hidden_size=32,
num_hidden_layers=4,
ffn_dim=128,
num_attention_heads=4,
word_embed_proj_dim=32,
output_projection=True)
return OPTLMModel(opt_config)
TEST_MODELS.register('opt_micro', opt_micro, micro_data_fn)


@@ -0,0 +1,26 @@
from collections import OrderedDict
from typing import Callable
class Registry(object):
def __init__(self) -> None:
super().__init__()
self._registry_dict = OrderedDict()
def register(self, name: str, model_fn: Callable, data_fn: Callable):
assert name not in self._registry_dict
model_tuple = (model_fn, data_fn)
self._registry_dict[name] = model_tuple
def get(self, name: str):
return self._registry_dict[name]
def __iter__(self):
return iter(self._registry_dict.items())
TEST_MODELS = Registry()
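# Typical usage: TEST_MODELS.register('name', model_fn, data_fn) at import time, then
# model_fn, data_fn = TEST_MODELS.get('name') inside a test, or iterate over TEST_MODELS to get
# (name, (model_fn, data_fn)) pairs.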
__all__ = ['TEST_MODELS']


@@ -0,0 +1,23 @@
import torch
import torch.nn as nn
from torchvision.models import resnet18
from tests.test_elixir.utils.registry import TEST_MODELS
def resnet_data_fn():
return dict(x=torch.randn(4, 3, 32, 32))
class ResNetModel(nn.Module):
def __init__(self) -> None:
super().__init__()
self.r = resnet18()
def forward(self, x):
output = self.r(x)
return output.sum()
TEST_MODELS.register('resnet', ResNetModel, resnet_data_fn)


@@ -0,0 +1,31 @@
import torch
import torch.nn as nn
from tests.test_elixir.utils.mlp import MlpModule
from tests.test_elixir.utils.registry import TEST_MODELS
def small_data_fn():
return dict(x=torch.randint(low=0, high=20, size=(4, 8)))
class SmallModel(nn.Module):
def __init__(self, num_embeddings: int = 20, hidden_dim: int = 16) -> None:
super().__init__()
self.embed = nn.Embedding(num_embeddings, hidden_dim)
self.norm1 = nn.LayerNorm(hidden_dim)
self.mlp = MlpModule(hidden_dim=hidden_dim)
self.norm2 = nn.LayerNorm(hidden_dim)
self.proj = nn.Linear(hidden_dim, num_embeddings, bias=False)
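# tie the output projection weight to the embedding weight (weight sharing)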
self.proj.weight = self.embed.weight
def forward(self, x):
x = self.embed(x)
x = x + self.norm1(self.mlp(x))
x = self.proj(self.norm2(x))
x = x.mean(dim=-2)
return x.sum()
TEST_MODELS.register('small', SmallModel, small_data_fn)