[hotfix] add bert test for gemini fwd bwd (#2035)

This commit is contained in:
Jiarui Fang
2022-11-29 11:19:52 +08:00
committed by GitHub
parent 0dbcd4a6f5
commit 96134e7be3
3 changed files with 11 additions and 13 deletions

View File

@@ -33,7 +33,7 @@ def check_grad(model: ZeroDDP, torch_model: torch.nn.Module):
@parameterize('placement_policy', ['cuda', 'cpu', 'auto', 'const'])
@parameterize('keep_gather', [False, True])
@parameterize('model_name', ['gpt2', 'bert', 'resnet18'])
@parameterize('model_name', ['gpt2', 'bert'])
@parameterize('use_grad_checkpoint', [False, True])
def exam_gpt_fwd_bwd(placement_policy, keep_gather, model_name: str, use_grad_checkpoint: bool = False):
set_seed(42)
@@ -78,7 +78,7 @@ def exam_gpt_fwd_bwd(placement_policy, keep_gather, model_name: str, use_grad_ch
torch.max(torch.abs(loss - torch_loss)).item(), loss, torch_loss)
# FIXME(1SAA) bert and resnet18 can not pass the check_grad
# check_grad(model, torch_model)
check_grad(model, torch_model)
def run_dist(rank, world_size, port):