diff --git a/tests/test_tensor/test_model.py b/tests/test_tensor/test_model.py
index 2b5b120b1..585f6f565 100644
--- a/tests/test_tensor/test_model.py
+++ b/tests/test_tensor/test_model.py
@@ -95,6 +95,15 @@ def run_1d_hybrid_tp(model_name):
     set_seed(1)
     with ColoInitContext(device=get_current_device()):
         model = model_builder(checkpoint=True)
+
+    if rank == 0:
+        model_torch = model_builder(checkpoint=True)
+        model_torch = model_torch.cuda()
+        colo_optimizer_torch = ColoOptimizer(dict(model_torch.named_parameters()), torch.optim.SGD, lr=0.1)
+
+        # Make two models have the same init params
+        for p1, p2 in zip(model.parameters(), model_torch.parameters()):
+            p2.data.copy_(p1.data)
 
     if 'bert' == model_name:
         parallel_action_list_row = [
@@ -176,14 +185,15 @@ def run_1d_hybrid_tp(model_name):
             if 'classifier' in name and ('weight' in name or 'bias' in name):
                 p.set_spec(spec_classifier_col)
 
-    set_seed(1)
-    if rank == 0:
-        model_torch = model_builder(checkpoint=True)
-        model_torch = model_torch.cuda()
-
     model = model.cuda()
-
+    colo_optimizer = ColoOptimizer(dict(model.named_parameters()), torch.optim.SGD, lr=0.1)
     for i, (data, label) in enumerate(train_dataloader):
+        model.eval()
+        colo_optimizer.zero_grad()
+        if rank == 0:
+            model_torch.eval()
+            colo_optimizer_torch.zero_grad()
+
         data = data.to(get_current_device())
         label = label.to(get_current_device())
 
@@ -210,12 +220,33 @@ def run_1d_hybrid_tp(model_name):
         if rank == 0:
             # print(loss.torch_tensor().item())
             # print('loss torch', loss_torch.item())
-            assert torch.allclose(loss.torch_tensor(), loss_torch, rtol=1e-2)
+            with torch.no_grad():
+                assert torch.allclose(loss.torch_tensor(), loss_torch, rtol=1e-2)
 
         loss.backward()
+        colo_optimizer.step()
 
         if rank == 0:
             loss_torch.backward()
+            colo_optimizer_torch.step()
+
+            with torch.no_grad():
+                # check param
+                for p1, p2 in zip(model.parameters(), model_torch.parameters()):
+                    if p1.size() == p2.size():
+                        assert torch.allclose(p1, p2)
+                    else:
+                        # TODO(jzy) Only check 1D spec. Need to be replaced by new DistSpec.
+                        if p1.size(-1) < p2.size(-1):    # col
+                            world_size = p2.size(-1) // p1.size(-1)
+                            split_p2 = torch.chunk(p2, world_size, dim=-1)[0]
+
+                        elif p1.size(0) < p2.size(0):    # row
+                            world_size = p2.size(0) // p1.size(0)
+                            split_p2 = torch.chunk(p2, world_size, dim=0)[0]
+
+                        assert torch.allclose(p1, split_p2)
+
         if i > 5:
             break
 
@@ -428,5 +459,5 @@ def _test_pretrain_load(world_size):
 if __name__ == '__main__':
     # test_model_parameters()
     # test_colo_optimizer()
-    # test_model()
-    _test_pretrain_load(4)
+    test_model(4)
+    # _test_pretrain_load(4)