mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2026-01-29 21:49:54 +00:00
[Gemini] add unitests to check gemini correctness (#2015)
This commit is contained in:
@@ -37,9 +37,9 @@ def check_param(model: ZeroDDP, torch_model: torch.nn.Module):
|
||||
assert torch.allclose(value, temp_zero_value, rtol=1e-3, atol=1e-2), "parameter '{}' has problem.".format(key)
|
||||
|
||||
|
||||
def run_fwd_bwd(model, criterion, optimizer, input_ids, attn_mask):
|
||||
def run_fwd_bwd(model, criterion, optimizer, input_ids):
|
||||
optimizer.zero_grad()
|
||||
logits = model(input_ids, attn_mask)
|
||||
logits = model(input_ids)
|
||||
logits = logits.float()
|
||||
loss = criterion(logits, input_ids)
|
||||
optimizer.backward(loss)
|
||||
@@ -83,12 +83,12 @@ def exam_gpt_fwd_bwd(placement_policy):
|
||||
torch_model.eval()
|
||||
|
||||
set_seed(dist.get_rank() * 3 + 128)
|
||||
for i, (input_ids, attn_mask) in enumerate(train_dataloader):
|
||||
for i, (input_ids, label) in enumerate(train_dataloader):
|
||||
if i > 2:
|
||||
break
|
||||
|
||||
zero_logits = run_fwd_bwd(model, criterion, zero_optim, input_ids, attn_mask)
|
||||
torch_logits = run_fwd_bwd(torch_model, criterion, torch_optim, input_ids, attn_mask)
|
||||
zero_logits = run_fwd_bwd(model, criterion, zero_optim, input_ids)
|
||||
torch_logits = run_fwd_bwd(torch_model, criterion, torch_optim, input_ids)
|
||||
assert torch.allclose(zero_logits, torch_logits, rtol=1e-3, atol=1e-2)
|
||||
# debug_print([0], zero_logits, torch_logits)
|
||||
|
||||
@@ -127,12 +127,12 @@ def exam_tiny_example(placement_policy):
|
||||
torch_model.eval()
|
||||
|
||||
set_seed(dist.get_rank() * 3 + 128)
|
||||
for i, (input_ids, attn_mask) in enumerate(train_dataloader):
|
||||
for i, (input_ids, label) in enumerate(train_dataloader):
|
||||
if i > 2:
|
||||
break
|
||||
|
||||
zero_logits = run_fwd_bwd(model, criterion, zero_optim, input_ids, attn_mask)
|
||||
torch_logits = run_fwd_bwd(torch_model, criterion, torch_optim, input_ids, attn_mask)
|
||||
zero_logits = run_fwd_bwd(model, criterion, zero_optim, input_ids)
|
||||
torch_logits = run_fwd_bwd(torch_model, criterion, torch_optim, input_ids)
|
||||
assert torch.allclose(zero_logits, torch_logits, rtol=1e-3, atol=1e-2)
|
||||
# debug_print([0], zero_logits, torch_logits)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user