[zero] fix unit-tests (#2039)

This commit is contained in:
HELSON 2022-11-30 10:40:31 +08:00 committed by GitHub
parent eb7742a4bb
commit 17a3c685b0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 44 additions and 44 deletions

View File

@ -1,7 +1,7 @@
import torch import torch
def run_fwd_bwd(model, data, label, criterion, use_init_ctx=False) -> torch.Tensor: def run_fwd_bwd(model, data, label, criterion, optimizer=None) -> torch.Tensor:
"""run_fwd_bwd """run_fwd_bwd
run fwd and bwd for the model run fwd and bwd for the model
@ -10,7 +10,6 @@ def run_fwd_bwd(model, data, label, criterion, use_init_ctx=False) -> torch.Tens
data (torch.Tensor): input data data (torch.Tensor): input data
label (torch.Tensor): label label (torch.Tensor): label
criterion (Optional[Callable]): a function of criterion criterion (Optional[Callable]): a function of criterion
use_init_ctx (bool, optional): whether the model is initialized under the contxt of ColoInitCtx. Defaults to False.
Returns: Returns:
torch.Tensor: loss of fwd torch.Tensor: loss of fwd
@ -23,8 +22,8 @@ def run_fwd_bwd(model, data, label, criterion, use_init_ctx=False) -> torch.Tens
loss = model(data, label) loss = model(data, label)
loss = loss.float() loss = loss.float()
if use_init_ctx: if optimizer:
model.backward(loss) optimizer.backward(loss)
else: else:
loss.backward() loss.backward()
return loss return loss

View File

@ -33,7 +33,7 @@ def run_tracer(rank, world_size, port, use_grad_check=True):
data = data.cuda() data = data.cuda()
label = label.cuda() label = label.cuda()
run_fwd_bwd(model, data, label, criterion, use_init_ctx=False) run_fwd_bwd(model, data, label, criterion)
model._ophook_list[0].print_non_model_data() model._ophook_list[0].print_non_model_data()

View File

@ -10,6 +10,8 @@ import colossalai
from colossalai.amp import convert_to_apex_amp from colossalai.amp import convert_to_apex_amp
from colossalai.gemini.chunk import ChunkManager, search_chunk_configuration from colossalai.gemini.chunk import ChunkManager, search_chunk_configuration
from colossalai.gemini.gemini_mgr import GeminiManager from colossalai.gemini.gemini_mgr import GeminiManager
from colossalai.nn.optimizer import HybridAdam
from colossalai.nn.optimizer.zero_optimizer import ZeroOptimizer
from colossalai.nn.parallel import ZeroDDP from colossalai.nn.parallel import ZeroDDP
from colossalai.tensor import ProcessGroup from colossalai.tensor import ProcessGroup
from colossalai.testing import parameterize, rerun_if_address_is_in_use from colossalai.testing import parameterize, rerun_if_address_is_in_use
@ -55,6 +57,8 @@ def exam_gpt_fwd_bwd(placement_policy, keep_gather, model_name: str, use_grad_ch
chunk_manager = ChunkManager(config_dict) chunk_manager = ChunkManager(config_dict)
gemini_manager = GeminiManager(placement_policy, chunk_manager) gemini_manager = GeminiManager(placement_policy, chunk_manager)
model = ZeroDDP(model, gemini_manager, pin_memory=True) model = ZeroDDP(model, gemini_manager, pin_memory=True)
optimizer = HybridAdam(model.parameters(), lr=1e-3)
zero_optim = ZeroOptimizer(optimizer, model, initial_scale=1)
pg = ProcessGroup() pg = ProcessGroup()
amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=1) amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=1)
@ -71,9 +75,9 @@ def exam_gpt_fwd_bwd(placement_policy, keep_gather, model_name: str, use_grad_ch
# after bwd param is grad for Gemini, due to the chunk reuse optimization. # after bwd param is grad for Gemini, due to the chunk reuse optimization.
if i > 0: if i > 0:
break break
input_ids, label = input_ids.cuda(), label.cuda()
torch_loss = run_fwd_bwd(torch_model, input_ids.cuda(), label.cuda(), criterion, use_init_ctx=False) torch_loss = run_fwd_bwd(torch_model, input_ids, label, criterion, torch_optim)
loss = run_fwd_bwd(model, input_ids.cuda(), label.cuda(), criterion, use_init_ctx=True) loss = run_fwd_bwd(model, input_ids, label, criterion, zero_optim)
assert torch.equal(torch_loss, loss) assert torch.equal(torch_loss, loss)

View File

@ -6,6 +6,7 @@ import torch
import torch.distributed as dist import torch.distributed as dist
import torch.multiprocessing as mp import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing import assert_close
import colossalai import colossalai
from colossalai.amp import convert_to_apex_amp from colossalai.amp import convert_to_apex_amp
@ -20,7 +21,7 @@ from colossalai.utils.cuda import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext from colossalai.utils.model.colo_init_context import ColoInitContext
from tests.components_to_test import run_fwd_bwd from tests.components_to_test import run_fwd_bwd
from tests.components_to_test.registry import non_distributed_component_funcs from tests.components_to_test.registry import non_distributed_component_funcs
from tests.test_tensor.common_utils import set_seed from tests.test_tensor.common_utils import debug_print, set_seed
def check_param(model: ZeroDDP, torch_model: torch.nn.Module): def check_param(model: ZeroDDP, torch_model: torch.nn.Module):
@ -35,27 +36,31 @@ def check_param(model: ZeroDDP, torch_model: torch.nn.Module):
assert key in zero_dict, "{} not in ZeRO dictionary.".format(key) assert key in zero_dict, "{} not in ZeRO dictionary.".format(key)
temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype) temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype)
# debug_print([0], "max range: ", key, torch.max(torch.abs(value - temp_zero_value))) # debug_print([0], "max range: ", key, torch.max(torch.abs(value - temp_zero_value)))
assert torch.allclose(value, temp_zero_value, rtol=1e-3, atol=1e-2), "parameter '{}' has problem.".format(key) assert_close(value, temp_zero_value, rtol=1e-3, atol=1e-2)
# 'gpt2', 'bert', # 'gpt2', 'bert',
TEST_MODELS = ['gpt2', 'bert'] TEST_MODELS = ['gpt2', 'bert']
# TEST_MODELS = ['simple_net'] EXAMPLE_MODELS = ['simple_net']
@parameterize('placement_policy', ['cuda', 'cpu', 'auto', 'const']) @parameterize('placement_policy', ['cuda'])
@parameterize('model_name', TEST_MODELS) @parameterize('model_name', TEST_MODELS)
def exam_model_step(placement_policy, model_name: str): def exam_model_step(placement_policy, model_name: str):
set_seed(42) set_seed(42)
get_components_func = non_distributed_component_funcs.get_callable(model_name) get_components_func = non_distributed_component_funcs.get_callable(model_name)
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
torch_model = model_builder().cuda()
amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=128)
torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3)
torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config)
torch_model = DDP(torch_model, device_ids=[dist.get_rank()])
with ColoInitContext(device=get_current_device()): with ColoInitContext(device=get_current_device()):
model = model_builder() model = model_builder()
torch_model = model_builder().cuda()
for torch_p, p in zip(torch_model.parameters(), model.parameters()): for torch_p, p in zip(torch_model.parameters(), model.parameters()):
torch_p.data.copy_(p.data) p.data.copy_(torch_p.data)
world_size = torch.distributed.get_world_size() world_size = torch.distributed.get_world_size()
config_dict, _ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100) config_dict, _ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
@ -70,12 +75,7 @@ def exam_model_step(placement_policy, model_name: str):
model = ZeroDDP(model, gemini_manager, pin_memory=True) model = ZeroDDP(model, gemini_manager, pin_memory=True)
optimizer = HybridAdam(model.parameters(), lr=1e-3) optimizer = HybridAdam(model.parameters(), lr=1e-3)
zero_optim = ZeroOptimizer(optimizer, model, initial_scale=2) zero_optim = ZeroOptimizer(optimizer, model, initial_scale=128)
amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=1)
torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3)
torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config)
torch_model = DDP(torch_model, device_ids=[dist.get_rank()])
model.eval() model.eval()
torch_model.eval() torch_model.eval()
@ -84,15 +84,13 @@ def exam_model_step(placement_policy, model_name: str):
for i, (input_ids, label) in enumerate(train_dataloader): for i, (input_ids, label) in enumerate(train_dataloader):
if i > 2: if i > 2:
break break
input_ids, label = input_ids.cuda(), label.cuda()
zero_optim.zero_grad() zero_optim.zero_grad()
torch_optim.zero_grad() torch_optim.zero_grad()
torch_loss = run_fwd_bwd(torch_model, input_ids.cuda(), label.cuda(), criterion, use_init_ctx=False) torch_loss = run_fwd_bwd(torch_model, input_ids, label, criterion, torch_optim)
loss = run_fwd_bwd(model, input_ids.cuda(), label.cuda(), criterion, use_init_ctx=True) loss = run_fwd_bwd(model, input_ids, label, criterion, zero_optim)
assert_close(torch_loss, loss)
assert torch.allclose(torch_loss, loss, rtol=1e-3, atol=1e-2), f"{torch_loss} vs {loss}"
# debug_print([0], zero_logits, torch_logits)
zero_optim.step() zero_optim.step()
torch_optim.step() torch_optim.step()
@ -101,31 +99,29 @@ def exam_model_step(placement_policy, model_name: str):
@parameterize('placement_policy', ['cuda', 'cpu']) @parameterize('placement_policy', ['cuda', 'cpu'])
@parameterize('model_name', TEST_MODELS) @parameterize('model_name', EXAMPLE_MODELS)
def exam_tiny_example(placement_policy, model_name: str): def exam_tiny_example(placement_policy, model_name: str):
set_seed(42) set_seed(2008)
get_components_func = non_distributed_component_funcs.get_callable(model_name) get_components_func = non_distributed_component_funcs.get_callable(model_name)
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
torch_model = model_builder().cuda()
amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=2)
torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3)
torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config)
torch_model = DDP(torch_model, device_ids=[dist.get_rank()])
with ColoInitContext(device=get_current_device()): with ColoInitContext(device=get_current_device()):
model = model_builder() model = model_builder()
torch_model = model_builder().cuda()
for torch_p, p in zip(torch_model.parameters(), model.parameters()): for torch_p, p in zip(torch_model.parameters(), model.parameters()):
torch_p.data.copy_(p.data) p.data.copy_(torch_p.data)
chunk_manager = init_chunk_manager(model=model, init_device=get_current_device(), search_range_mb=1) chunk_manager = init_chunk_manager(model=model, init_device=get_current_device(), search_range_mb=1)
gemini_manager = GeminiManager(placement_policy, chunk_manager) gemini_manager = GeminiManager(placement_policy, chunk_manager)
model = ZeroDDP(model, gemini_manager, pin_memory=True) model = ZeroDDP(model, gemini_manager, pin_memory=True)
optimizer = HybridAdam(model.parameters(), lr=1e-3) optimizer = HybridAdam(model.parameters(), lr=1e-3)
zero_optim = ZeroOptimizer(optimizer, model, initial_scale=2) zero_optim = ZeroOptimizer(optimizer, model, initial_scale=2)
amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=1)
torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3)
torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config)
torch_model = DDP(torch_model, device_ids=[dist.get_rank()])
model.eval() model.eval()
torch_model.eval() torch_model.eval()
@ -134,14 +130,15 @@ def exam_tiny_example(placement_policy, model_name: str):
if i > 2: if i > 2:
break break
input_ids = input_ids.cuda()
label = label.cuda()
zero_optim.zero_grad() zero_optim.zero_grad()
torch_optim.zero_grad() torch_optim.zero_grad()
torch_loss = run_fwd_bwd(torch_model, input_ids.cuda(), label.cuda(), criterion, use_init_ctx=False) torch_loss = run_fwd_bwd(torch_model, input_ids, label, criterion, torch_optim)
loss = run_fwd_bwd(model, input_ids.cuda(), label.cuda(), criterion, use_init_ctx=True) loss = run_fwd_bwd(model, input_ids, label, criterion, zero_optim)
assert_close(torch_loss, loss)
assert torch.allclose(torch_loss, loss, rtol=1e-3, atol=1e-2), f"{torch_loss} vs {loss}"
# debug_print([0], zero_logits, torch_logits)
zero_optim.step() zero_optim.step()
torch_optim.step() torch_optim.step()
@ -165,4 +162,4 @@ def test_optim(world_size):
if __name__ == '__main__': if __name__ == '__main__':
test_optim(2) test_optim(1)