Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-02 01:28:31 +00:00
[legacy] clean up legacy code (#4743)
* [legacy] remove outdated codes of pipeline (#4692)
* [legacy] remove cli of benchmark and update optim (#4690)
* [legacy] remove cli of benchmark and update optim
* [doc] fix cli doc test
* [legacy] fix engine clip grad norm
* [legacy] remove outdated colo tensor (#4694)
* [legacy] remove outdated colo tensor
* [test] fix test import
* [legacy] move outdated zero to legacy (#4696)
* [legacy] clean up utils (#4700)
* [legacy] clean up utils
* [example] update examples
* [legacy] clean up amp
* [legacy] fix amp module
* [legacy] clean up gpc (#4742)
* [legacy] clean up context
* [legacy] clean core, constants and global vars
* [legacy] refactor initialize
* [example] fix examples ci
* [example] fix examples ci
* [legacy] fix tests
* [example] fix gpt example
* [example] fix examples ci
* [devops] fix ci installation
* [example] fix examples ci
@@ -1,124 +0,0 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import pytest
import torch
import torch.nn.functional as F

from colossalai.context.parallel_mode import ParallelMode
from colossalai.context.random import add_seed, reset_seeds, seed, set_mode
from colossalai.testing import clear_cache_before_run, parameterize
from colossalai.utils.activation_checkpoint import checkpoint


def forward(x, weight):
    out = torch.matmul(x, weight)
    with seed(ParallelMode.DATA):
        out_ = F.dropout(out, p=0.4, training=True)
    return out_


def forward_inplace_ckpt(x, weight, cpu_offload=False):
    out = torch.matmul(x, weight)
    bn = torch.nn.BatchNorm1d(4, affine=False)
    bn = bn.to(device="cuda")
    out = bn(out)

    def ckpt0(x):
        return F.relu(x, inplace=True)

    out = checkpoint(ckpt0, cpu_offload, out, use_reentrant=False)
    return out


def forward_inplace(x, weight):
    out = torch.matmul(x, weight)
    bn = torch.nn.BatchNorm1d(4, affine=False)
    bn = bn.to(device="cuda")
    out = bn(out)
    out = F.relu(out, inplace=True)
    return out


@clear_cache_before_run()
@parameterize("use_reentrant", [True, False])
@parameterize("cpu_offload", [True, False])
def test_activation_checkpointing(cpu_offload, use_reentrant):
    # the seed manager is a singleton, so reset it here
    # in case other tests have already modified its state
    reset_seeds()

    # initialize the inputs here to avoid changing the CUDA RNG state below
    inputs = torch.rand(2, 2, requires_grad=True, device="cuda")
    weight = torch.rand(2, 4, requires_grad=True, device="cuda")

    # get a copy of the input tensors
    inputs_ = torch.empty(2, 2, requires_grad=True, device="cuda")
    inputs_.data.copy_(inputs.data)
    weight_ = torch.empty(2, 4, requires_grad=True, device="cuda")
    weight_.data.copy_(weight.data)

    add_seed(ParallelMode.GLOBAL, 1024)
    add_seed(ParallelMode.DATA, 1026)
    set_mode(ParallelMode.GLOBAL)
    global_cuda_rng_state = torch.cuda.get_rng_state()
    set_mode(ParallelMode.DATA)
    data_parallel_cuda_rng_state = torch.cuda.get_rng_state()
    set_mode(ParallelMode.GLOBAL)

    out = forward(inputs, weight)
    loss = out.sum()
    loss.backward()

    # recover the CUDA RNG states
    set_mode(ParallelMode.GLOBAL)
    torch.cuda.set_rng_state(global_cuda_rng_state)
    set_mode(ParallelMode.DATA)
    torch.cuda.set_rng_state(data_parallel_cuda_rng_state)
    set_mode(ParallelMode.GLOBAL)

    out = checkpoint(forward, cpu_offload, inputs_, weight_, use_reentrant=use_reentrant)
    loss = out.sum()
    loss.backward()

    assert torch.all(inputs.grad == inputs_.grad), "Gradient of the input does not match"
    torch.cuda.empty_cache()

    # extra test for use_reentrant=False
    if not use_reentrant:
        # recover the CUDA RNG states
        set_mode(ParallelMode.GLOBAL)
        torch.cuda.set_rng_state(global_cuda_rng_state)
        set_mode(ParallelMode.DATA)
        torch.cuda.set_rng_state(data_parallel_cuda_rng_state)
        set_mode(ParallelMode.GLOBAL)

        out = forward_inplace(inputs, weight)
        loss = out.sum()
        loss.backward()

        # recover the CUDA RNG states
        set_mode(ParallelMode.GLOBAL)
        torch.cuda.set_rng_state(global_cuda_rng_state)
        set_mode(ParallelMode.DATA)
        torch.cuda.set_rng_state(data_parallel_cuda_rng_state)
        set_mode(ParallelMode.GLOBAL)

        out = forward_inplace_ckpt(inputs_, weight_, cpu_offload=cpu_offload)
        loss = out.sum()
        loss.backward()

        assert torch.all(inputs.grad == inputs_.grad), "Gradient of the input does not match"
        torch.cuda.empty_cache()

    # reset the singleton seed manager again so that tests running after this
    # one are not stuck with the seeds set above
    reset_seeds()


if __name__ == "__main__":
    test_activation_checkpointing(False, False)
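The deleted test above exercises Colossal-AI's seed-manager-aware activation checkpoint. The same gradient-equivalence property can be sketched with stock torch.utils.checkpoint, which stashes and restores RNG state around recomputation; this is a plain-PyTorch sketch on CPU, not the removed colossalai.utils.activation_checkpoint API:

import torch
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint


def forward(x, w):
    # dropout makes the forward stochastic, so gradients only match if the
    # RNG state seen during recomputation equals that of the first pass
    return F.dropout(x @ w, p=0.4, training=True)


x = torch.rand(2, 2, requires_grad=True)
w = torch.rand(2, 4, requires_grad=True)
x_ = x.detach().clone().requires_grad_()
w_ = w.detach().clone().requires_grad_()

torch.manual_seed(1234)
forward(x, w).sum().backward()    # reference pass

torch.manual_seed(1234)
# checkpoint() drops the activations, then recomputes the forward during
# backward with the saved RNG state restored, so the dropout mask matches
checkpoint(forward, x_, w_, use_reentrant=False).sum().backward()

assert torch.allclose(x.grad, x_.grad)
assert torch.allclose(w.grad, w_.grad)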
@@ -1,77 +0,0 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import pprint

import pytest
import torch
import torch.nn as nn

import colossalai.legacy.nn as col_nn
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import rerun_if_address_is_in_use, skip_if_not_enough_gpus, spawn
from colossalai.utils import is_using_pp
from colossalai.utils.checkpointing import gather_pipeline_parallel_state_dict, load_checkpoint, save_checkpoint


def build_pipeline(model):
    from colossalai.pipeline.utils import partition_uniform

    pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE)
    pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
    depth = len(model)
    start, end = partition_uniform(depth, pipeline_size, 1)[pipeline_rank][0]
    layers = []
    for i in range(depth):
        if start <= i < end:
            layers.append(model[i])
        else:
            layers.append(nn.Identity())
    return nn.Sequential(*layers)


def check_equal(A, B):
    assert torch.allclose(A, B, rtol=1e-3, atol=1e-2)


def check_checkpoint_1d(rank, world_size, port):
    config = dict(parallel=dict(pipeline=dict(size=2), tensor=dict(size=4, mode="1d")))

    disable_existing_loggers()
    launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")

    m1 = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 4))
    sd1 = m1.state_dict()
    if gpc.get_global_rank() == 0:
        print(f"Rank {gpc.get_global_rank()}:\n{pprint.pformat(sd1)}\n")
    save_checkpoint("test.pt", 0, m1)

    m2 = nn.Sequential(col_nn.Linear(4, 8), col_nn.Linear(8, 4))
    if is_using_pp():
        m2 = build_pipeline(m2)

    load_checkpoint("test.pt", m2)
    sd2 = m2.state_dict()
    if is_using_pp() and gpc.get_local_rank(ParallelMode.TENSOR) == 0:
        sd2 = gather_pipeline_parallel_state_dict(sd2)
    print(f"Rank {gpc.get_global_rank()}:\n{pprint.pformat(sd2)}\n")

    if gpc.get_global_rank() == 0:
        for k, v in sd1.items():
            assert k in sd2
            check_equal(v, sd2[k].to(torch.device("cpu")))


@pytest.mark.dist
@pytest.mark.skip("takes too long")
@skip_if_not_enough_gpus(min_gpus=8)
@rerun_if_address_is_in_use()
def test_checkpoint_1d():
    spawn(check_checkpoint_1d, 8)


if __name__ == "__main__":
    test_checkpoint_1d()
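build_pipeline above relies on partition_uniform to split the layer stack into one contiguous range per pipeline stage. As a rough illustration of that balancing, here is a hypothetical re-implementation; the real colossalai.pipeline.utils.partition_uniform returns per-rank lists of ranges and may distribute remainders differently:

def partition_uniform_sketch(depth: int, pipeline_size: int):
    # split `depth` layers into `pipeline_size` contiguous (start, end) ranges,
    # handing any remainder layers to the later stages
    base, remainder = divmod(depth, pipeline_size)
    parts, start = [], 0
    for stage in range(pipeline_size):
        size = base + (1 if stage >= pipeline_size - remainder else 0)
        parts.append((start, start + size))
        start += size
    return parts


print(partition_uniform_sketch(7, 3))    # [(0, 2), (2, 4), (4, 7)]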
@@ -1,77 +0,0 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import pprint

import pytest
import torch
import torch.nn as nn

import colossalai.legacy.nn as col_nn
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import rerun_if_address_is_in_use, skip_if_not_enough_gpus, spawn
from colossalai.utils import is_using_pp
from colossalai.utils.checkpointing import gather_pipeline_parallel_state_dict, load_checkpoint, save_checkpoint


def build_pipeline(model):
    from colossalai.pipeline.utils import partition_uniform

    pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE)
    pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
    depth = len(model)
    start, end = partition_uniform(depth, pipeline_size, 1)[pipeline_rank][0]
    layers = []
    for i in range(depth):
        if start <= i < end:
            layers.append(model[i])
        else:
            layers.append(nn.Identity())
    return nn.Sequential(*layers)


def check_equal(A, B):
    assert torch.allclose(A, B, rtol=1e-3, atol=1e-2)


def check_checkpoint_2d(rank, world_size, port):
    config = dict(parallel=dict(pipeline=dict(size=2), tensor=dict(size=4, mode="2d")))

    disable_existing_loggers()
    launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")

    m1 = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 4))
    sd1 = m1.state_dict()
    if gpc.get_global_rank() == 0:
        print(f"Rank {gpc.get_global_rank()}:\n{pprint.pformat(sd1)}\n")
    save_checkpoint("test.pt", 0, m1)

    m2 = nn.Sequential(col_nn.Linear(4, 8), col_nn.Linear(8, 4))
    if is_using_pp():
        m2 = build_pipeline(m2)

    load_checkpoint("test.pt", m2)
    sd2 = m2.state_dict()
    if is_using_pp() and gpc.get_local_rank(ParallelMode.TENSOR) == 0:
        sd2 = gather_pipeline_parallel_state_dict(sd2)
    print(f"Rank {gpc.get_global_rank()}:\n{pprint.pformat(sd2)}\n")

    if gpc.get_global_rank() == 0:
        for k, v in sd1.items():
            assert k in sd2
            check_equal(v, sd2[k].to(torch.device("cpu")))


@pytest.mark.dist
@pytest.mark.skip("takes too long")
@skip_if_not_enough_gpus(min_gpus=8)
@rerun_if_address_is_in_use()
def test_checkpoint_2d():
    spawn(check_checkpoint_2d, 8)


if __name__ == "__main__":
    test_checkpoint_2d()
@@ -1,77 +0,0 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import pprint

import pytest
import torch
import torch.nn as nn

import colossalai.legacy.nn as col_nn
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import rerun_if_address_is_in_use, skip_if_not_enough_gpus, spawn
from colossalai.utils import is_using_pp
from colossalai.utils.checkpointing import gather_pipeline_parallel_state_dict, load_checkpoint, save_checkpoint


def build_pipeline(model):
    from colossalai.pipeline.utils import partition_uniform

    pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE)
    pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
    depth = len(model)
    start, end = partition_uniform(depth, pipeline_size, 1)[pipeline_rank][0]
    layers = []
    for i in range(depth):
        if start <= i < end:
            layers.append(model[i])
        else:
            layers.append(nn.Identity())
    return nn.Sequential(*layers)


def check_equal(A, B):
    assert torch.allclose(A, B, rtol=1e-3, atol=1e-2)


def check_checkpoint_2p5d(rank, world_size, port):
    config = dict(parallel=dict(pipeline=dict(size=2), tensor=dict(size=4, depth=1, mode="2.5d")))

    disable_existing_loggers()
    launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")

    m1 = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 4))
    sd1 = m1.state_dict()
    if gpc.get_global_rank() == 0:
        print(f"Rank {gpc.get_global_rank()}:\n{pprint.pformat(sd1)}\n")
    save_checkpoint("test.pt", 0, m1)

    m2 = nn.Sequential(col_nn.Linear(4, 8), col_nn.Linear(8, 4))
    if is_using_pp():
        m2 = build_pipeline(m2)

    load_checkpoint("test.pt", m2)
    sd2 = m2.state_dict()
    if is_using_pp() and gpc.get_local_rank(ParallelMode.TENSOR) == 0:
        sd2 = gather_pipeline_parallel_state_dict(sd2)
    print(f"Rank {gpc.get_global_rank()}:\n{pprint.pformat(sd2)}\n")

    if gpc.get_global_rank() == 0:
        for k, v in sd1.items():
            assert k in sd2
            check_equal(v, sd2[k].to(torch.device("cpu")))


@pytest.mark.dist
@pytest.mark.skip("takes too long")
@skip_if_not_enough_gpus(min_gpus=8)
@rerun_if_address_is_in_use()
def test_checkpoint_2p5d():
    spawn(check_checkpoint_2p5d, 8)


if __name__ == "__main__":
    test_checkpoint_2p5d()
@@ -1,77 +0,0 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import pprint

import pytest
import torch
import torch.nn as nn

import colossalai.legacy.nn as col_nn
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import rerun_if_address_is_in_use, skip_if_not_enough_gpus, spawn
from colossalai.utils import is_using_pp
from colossalai.utils.checkpointing import gather_pipeline_parallel_state_dict, load_checkpoint, save_checkpoint


def build_pipeline(model):
    from colossalai.pipeline.utils import partition_uniform

    pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE)
    pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
    depth = len(model)
    start, end = partition_uniform(depth, pipeline_size, 1)[pipeline_rank][0]
    layers = []
    for i in range(depth):
        if start <= i < end:
            layers.append(model[i])
        else:
            layers.append(nn.Identity())
    return nn.Sequential(*layers)


def check_equal(A, B):
    assert torch.allclose(A, B, rtol=1e-3, atol=1e-2)


def check_checkpoint_3d(rank, world_size, port):
    config = dict(parallel=dict(pipeline=dict(size=1), tensor=dict(size=8, mode="3d")))

    disable_existing_loggers()
    launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")

    m1 = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 4))
    sd1 = m1.state_dict()
    if gpc.get_global_rank() == 0:
        print(f"Rank {gpc.get_global_rank()}:\n{pprint.pformat(sd1)}\n")
    save_checkpoint("test.pt", 0, m1)

    m2 = nn.Sequential(col_nn.Linear(4, 8), col_nn.Linear(8, 4))
    if is_using_pp():
        m2 = build_pipeline(m2)

    load_checkpoint("test.pt", m2)
    sd2 = m2.state_dict()
    if is_using_pp() and gpc.get_local_rank(ParallelMode.TENSOR) == 0:
        sd2 = gather_pipeline_parallel_state_dict(sd2)
    print(f"Rank {gpc.get_global_rank()}:\n{pprint.pformat(sd2)}\n")

    if gpc.get_global_rank() == 0:
        for k, v in sd1.items():
            assert k in sd2
            check_equal(v, sd2[k].to(torch.device("cpu")))


@pytest.mark.dist
@pytest.mark.skip("takes too long")
@skip_if_not_enough_gpus(min_gpus=8)
@rerun_if_address_is_in_use()
def test_checkpoint_3d():
    spawn(check_checkpoint_3d, 8)


if __name__ == "__main__":
    test_checkpoint_3d()
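The four deleted checkpoint tests (1D, 2D, 2.5D, 3D) all follow the same save/load/compare round trip. A single-process analogue with plain PyTorch state dicts, using the same tolerances as check_equal above:

import os

import torch
import torch.nn as nn

# save a reference model's checkpoint, load it into a fresh model, and verify
# every entry of the state dict round-trips within the test's tolerances
m1 = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 4))
torch.save(m1.state_dict(), "test.pt")

m2 = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 4))
m2.load_state_dict(torch.load("test.pt"))

for k, v in m1.state_dict().items():
    assert torch.allclose(v, m2.state_dict()[k], rtol=1e-3, atol=1e-2)
os.remove("test.pt")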
@@ -1,41 +0,0 @@
import torch

import colossalai
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.zero.legacy.gemini.tensor_utils import colo_model_data_tensor_move, colo_model_data_tensor_move_inline
from colossalai.zero.legacy.sharded_param import ShardedTensor


def run_tensor_move(rank, world_size, port):
    colossalai.launch(config={}, rank=0, world_size=world_size, host='localhost', port=port, backend='nccl')

    src_t = torch.ones(2, 3).cuda()
    tgt_t = torch.zeros(2, 3)

    colo_model_data_tensor_move(src_t, tgt_t)
    assert torch.sum(tgt_t) == 6.0, f"{torch.sum(tgt_t)} vs. 6.0"

    src_t = torch.ones(2, 3)
    tgt_t = torch.zeros(2, 3).cuda().half()
    colo_model_data_tensor_move(src_t, tgt_t)
    # the payload of src_t has been released by the move
    assert src_t.numel() == 0
    assert torch.sum(tgt_t) == 6.0, f"{torch.sum(tgt_t)} vs. 6.0"

    src_t = ShardedTensor(torch.ones(2, 3))
    tgt_t = ShardedTensor(torch.zeros(2, 3).cuda().half())
    colo_model_data_tensor_move(src_t, tgt_t)
    assert torch.sum(tgt_t.payload) == 6.0, f"{torch.sum(tgt_t.payload)} vs. 6.0"

    assert tgt_t.device.type == 'cuda'
    colo_model_data_tensor_move_inline(tgt_t, torch.device('cpu'))
    assert tgt_t.device.type == 'cpu'


@rerun_if_address_is_in_use()
def test_tensor_move():
    spawn(run_tensor_move, 1)


if __name__ == '__main__':
    test_tensor_move()
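As the test suggests, colo_model_data_tensor_move copies a tensor's payload onto the target's device and dtype and then frees the source. The following is only a guess at that core behavior for plain tensors (the real helper also handles ShardedTensor payloads and CUDA placement):

import torch


def tensor_move_sketch(src: torch.Tensor, tgt: torch.Tensor) -> None:
    # copy src's payload into tgt, casting to tgt's device and dtype,
    # then release src's storage so it no longer holds memory
    tgt.data.copy_(src.data.to(device=tgt.device, dtype=tgt.dtype))
    src.data = torch.empty(0, device=src.device, dtype=src.dtype)


src_t = torch.ones(2, 3)
tgt_t = torch.zeros(2, 3)
tensor_move_sketch(src_t, tgt_t)
assert torch.sum(tgt_t) == 6.0
assert src_t.numel() == 0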
@@ -1,28 +0,0 @@
import pytest

import colossalai
from colossalai.testing import spawn
from colossalai.utils.cuda import get_current_device
from colossalai.utils.memory import colo_device_memory_capacity, colo_set_process_memory_fraction


def _run_colo_set_process_memory_fraction_and_colo_device_memory_capacity():
    frac1 = colo_device_memory_capacity(get_current_device())
    colo_set_process_memory_fraction(0.5)
    frac2 = colo_device_memory_capacity(get_current_device())
    assert frac2 * 2 == frac1


def run_dist(rank, world_size, port):
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    _run_colo_set_process_memory_fraction_and_colo_device_memory_capacity()


@pytest.mark.dist
@pytest.mark.parametrize("world_size", [3, 4])
def test_memory_utils(world_size):
    spawn(run_dist, world_size)


if __name__ == '__main__':
    test_memory_utils(world_size=2)
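The deleted memory test asserts that halving the process memory fraction halves the reported device capacity. The equivalent knobs in stock PyTorch are shown below; the assumption that the removed helpers wrap these calls is the author's guess, not confirmed by the diff:

import torch

if torch.cuda.is_available():
    total = torch.cuda.get_device_properties(0).total_memory
    # cap this process at half of the device memory; allocations beyond the
    # cap raise an out-of-memory error even if the GPU itself has free memory
    torch.cuda.set_per_process_memory_fraction(0.5, device=0)
    print(f"capacity cap: {0.5 * total / 1024**3:.2f} GiB of {total / 1024**3:.2f} GiB")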
@@ -1,78 +0,0 @@
import pytest
import torch
from torch.nn.parameter import Parameter
from torch.nn.utils import clip_grad_norm_

import colossalai
from colossalai.logging import disable_existing_loggers
from colossalai.tensor import ColoTensorSpec, ProcessGroup, distspec
from colossalai.tensor.colo_parameter import ColoParameter
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.utils import get_current_device
from colossalai.utils.common import clip_grad_norm


def close(num: float, other: float, rtol: float = 1e-5, atol: float = 1e-8):
    return abs(num - other) <= atol + rtol * other


def shard_param(p: ColoParameter) -> None:
    pg = p.get_process_group()
    p._redistribute(distspec.ShardSpec([0], [pg.tp_world_size()]))
    p.grad = p.grad.chunk(pg.tp_world_size(), 0)[pg.tp_local_rank()].clone().detach()


def check_grad_equal(p: Parameter, colo_p: ColoParameter) -> None:
    pg = colo_p.get_process_group()
    if p.shape != colo_p.shape:
        grad = p.grad.chunk(pg.tp_world_size(), 0)[pg.tp_local_rank()]
    else:
        grad = p.grad
    assert torch.allclose(grad, colo_p.grad), f'diff: {torch.abs(grad - colo_p.grad)}'


@parameterize('dtype', [torch.float])
@parameterize('device', ['mixed', 'cuda', 'cpu'])
@parameterize('norm_type', [2.0, 3.0, float('inf')])
def run_grad_clip_norm(world_size: int, dtype: torch.dtype, device: str, norm_type: float):
    print(f'{world_size}, {dtype}, {device}, {norm_type}')
    cuda_device = get_current_device()
    devices = [cuda_device] * 4
    if device == 'cpu':
        devices = [torch.device('cpu')] * 4
    elif device == 'mixed':
        devices = [cuda_device] * 2 + [torch.device('cpu')] * 2
    pg = ProcessGroup(tp_degree=world_size)
    params = [Parameter(torch.empty(4, 4, dtype=dtype, device=devices[i])) for i in range(4)]
    colo_params = [
        ColoParameter(torch.empty(4, 4, dtype=dtype, device=devices[i]), spec=ColoTensorSpec(pg)) for i in range(4)
    ]
    for p, colo_p in zip(params, colo_params):
        grad = torch.rand_like(p)
        p.grad = grad
        colo_p.grad = grad.clone().detach()
    shard_param(colo_params[0])
    shard_param(colo_params[2])
    torch_norm = clip_grad_norm_(params, 1.0, norm_type=norm_type)
    colo_norm = clip_grad_norm(colo_params, 1.0, norm_type=norm_type)
    assert close(torch_norm, colo_norm), f'diff: {abs(torch_norm - colo_norm)}'
    for p, colo_p in zip(params, colo_params):
        check_grad_equal(p, colo_p)


def run_dist(rank, world_size, port):
    disable_existing_loggers()
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    run_grad_clip_norm(world_size=world_size)


@pytest.mark.skip("this needs to be updated")
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 2])
@rerun_if_address_is_in_use()
def test_zero_clip_grad(world_size: int):
    spawn(run_dist, world_size)


if __name__ == '__main__':
    test_zero_clip_grad(2)
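The test above checks that clip_grad_norm over sharded ColoParameters matches torch.nn.utils.clip_grad_norm_. For reference, the semantics it is compared against, verified numerically in plain PyTorch:

import torch
from torch.nn.utils import clip_grad_norm_

# clip_grad_norm_ computes one total norm over all gradients and, when it
# exceeds max_norm, rescales every gradient in place by max_norm / total_norm
p = torch.nn.Parameter(torch.ones(4, 4))
p.grad = torch.full((4, 4), 2.0)

total = clip_grad_norm_([p], max_norm=1.0, norm_type=2.0)
assert torch.isclose(total, torch.tensor(8.0))    # sqrt(16 * 2.0**2) = 8
assert torch.allclose(p.grad, torch.full((4, 4), 2.0) / total)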
@@ -1,111 +0,0 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

from functools import partial

import pytest
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils import clip_grad_norm_

import colossalai
from colossalai.logging import disable_existing_loggers
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.utils import checkpoint, clip_grad_norm_fp32
from colossalai.zero.legacy.shard_utils.tensor_shard_strategy import TensorShardStrategy
from colossalai.zero.legacy.sharded_model.sharded_model_v2 import ShardedModelV2


def checkpoint_wrapper(module, enable=True):
    if enable:
        module.forward = partial(checkpoint, module.forward, False)
    return module


class Net(nn.Module):

    def __init__(self, checkpoint=False) -> None:
        super().__init__()
        self.fc1 = nn.Linear(5, 5)
        self.fc2 = nn.Linear(5, 5)
        self.fc3 = nn.Linear(5, 1)
        if checkpoint:
            self.fc1 = checkpoint_wrapper(self.fc1)
        self.layers = [self.fc1, self.fc2, self.fc1, self.fc2, self.fc3]

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


def run_step(model, optimizer, x, enable_autocast=False, norm_type=2.0):
    model.train()
    optimizer.zero_grad()
    with torch.cuda.amp.autocast(enabled=enable_autocast):
        y = model(x)
        loss = y.sum()
    loss = loss.float()
    loss.backward()
    clip_grad(model, norm_type)
    optimizer.step()


def clip_grad(model, norm_type):
    if isinstance(model, DDP):
        clip_grad_norm_(model.parameters(), max_norm=1.0, norm_type=norm_type)
    else:
        clip_grad_norm_fp32(model.parameters(), max_norm=1.0, norm_type=norm_type)


def allclose(tensor_a: torch.Tensor, tensor_b: torch.Tensor, loose=False) -> bool:
    if loose:
        return torch.allclose(tensor_a, tensor_b, atol=1e-3, rtol=1e-3)
    return torch.allclose(tensor_a, tensor_b)


def check_grads(model, zero_model, loose=False):
    rank = dist.get_rank()
    for p, zero_p in zip(model.parameters(), zero_model.parameters()):
        zero_grad = zero_p.grad.clone().to(p.device)
        chunks = torch.flatten(p.grad).chunk(4)
        if rank >= len(chunks):
            continue
        grad = chunks[rank]
        if zero_p.zero_shard_padding > 0:
            zero_grad = zero_grad[:-zero_p.zero_shard_padding]
        assert grad.dtype == zero_grad.dtype
        assert allclose(grad, zero_grad, loose=loose)


def check_params(model, zero_model, loose=False):
    rank = dist.get_rank()
    for p, zero_p in zip(model.parameters(), zero_model.parameters()):
        zero_shard_padding = zero_p.zero_shard_padding
        zero_p = zero_p.clone().to(p.device)
        chunks = torch.flatten(p).chunk(4)
        if rank >= len(chunks):
            continue
        p = chunks[rank]
        if zero_shard_padding > 0:
            zero_p = zero_p[:-zero_shard_padding]
        assert p.dtype == zero_p.dtype
        assert allclose(p, zero_p, loose=loose)


def run_dist(rank, world_size, port):
    disable_existing_loggers()
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')


@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_zero_clip_grad():
    world_size = 4
    spawn(run_dist, world_size)


if __name__ == '__main__':
    test_zero_clip_grad()
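run_step above performs an autocast forward, backward, gradient clipping, and an optimizer step without a gradient scaler. For reference, the same step written with stock PyTorch AMP; the GradScaler is an addition not present in the deleted code:

import torch
import torch.nn as nn

device = "cuda" if torch.cuda.is_available() else "cpu"
model = nn.Linear(5, 1).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler(enabled=device == "cuda")

x = torch.rand(8, 5, device=device)
optimizer.zero_grad()
with torch.cuda.amp.autocast(enabled=device == "cuda"):
    loss = model(x).sum()
scaler.scale(loss).backward()
scaler.unscale_(optimizer)    # clip the true, unscaled gradients
nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
scaler.step(optimizer)
scaler.update()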