[test] Hotfix/fix some model test and refactor check util api (#4369)

* fix llama test

* fix test bug of bert, blip2, bloom, gpt2

* fix llama test

* fix opt test

* fix sam test

* fix sam test

* fix t5 test

* fix vit test

* fix whisper test

* fix whisper test

* polish code

* adjust allclose parameter

* Add mistakenly deleted code

* addjust allclose

* change loss function for some base model
This commit is contained in:
Bin Jia
2023-08-03 14:51:36 +08:00
committed by Hongxin Liu
parent c3ca53cf05
commit 5c6f183192
16 changed files with 135 additions and 336 deletions

View File

@@ -2,10 +2,13 @@ import copy
from contextlib import nullcontext
import torch
import torch.distributed as dist
from torch.nn import Module
from colossalai.lazy import LazyInitContext
from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.shardformer._utils import getattr_
from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor
def build_model(model_fn, enable_fused_normalization=True, enable_tensor_parallelism=True, use_lazy_init: bool = False):
@@ -74,3 +77,22 @@ def check_state_dict(org_model: Module, sharded_model: Module, name: str = ''):
assert v.shape == shard_v.shape, f'{name} {k} shape mismatch, {v.shape} vs {shard_v.shape}'
assert v.dtype == shard_v.dtype, f'{name} {k} dtype mismatch, {v.dtype} vs {shard_v.dtype}'
assert torch.equal(v, shard_v), f'{name} {k} value mismatch'
def check_grad(original_model, sharded_model, layer_suffix, atol=1e-5, rtol=1e-5, dim=0, verbose=False):
for suffix in layer_suffix:
org_grad = getattr_(original_model, suffix).weight.grad
shard_grad = getattr_(sharded_model, suffix).weight.grad
shard_weight = getattr_(sharded_model, suffix).weight
if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(dist.get_world_size())]
shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad)
all_shard_grad = torch.cat(shard_grad_list, dim=dim)
else:
all_shard_grad = shard_grad
if verbose and dist.get_rank() == 0:
print(f"'{suffix}' grad: {org_grad}, {all_shard_grad}")
assert torch.allclose(
org_grad, all_shard_grad, rtol=rtol, atol=atol
), f"error attribute '{suffix}', orgin model grad is not equal to shard model grad\n{org_grad}\n{all_shard_grad}"