Mirror of https://github.com/hpcaitech/ColossalAI.git
[test] Hotfix/fix some model test and refactor check util api (#4369)
* fix llama test
* fix test bug of bert, blip2, bloom, gpt2
* fix llama test
* fix opt test
* fix sam test
* fix sam test
* fix t5 test
* fix vit test
* fix whisper test
* fix whisper test
* polish code
* adjust allclose parameter
* Add mistakenly deleted code
* addjust allclose
* change loss function for some base model
@@ -2,10 +2,13 @@ import copy
from contextlib import nullcontext

import torch
import torch.distributed as dist
from torch.nn import Module

from colossalai.lazy import LazyInitContext
from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.shardformer._utils import getattr_
from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor


def build_model(model_fn, enable_fused_normalization=True, enable_tensor_parallelism=True, use_lazy_init: bool = False):
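The hunk above shows only the updated imports and the new signature of build_model; its body is not part of this excerpt. As a rough, hedged sketch of what the new use_lazy_init flag implies (the helper name and return value below are assumptions, not taken from the diff), the flag presumably switches model construction between LazyInitContext and a plain nullcontext:

# Minimal sketch, not the actual build_model body (which this hunk omits).
# Assumption: use_lazy_init selects LazyInitContext over nullcontext when the
# test model is instantiated, deferring parameter allocation until sharding.
from contextlib import nullcontext

from colossalai.lazy import LazyInitContext


def build_with_optional_lazy_init(model_fn, use_lazy_init: bool = False):
    ctx = LazyInitContext() if use_lazy_init else nullcontext()
    with ctx:
        # Inside LazyInitContext, parameter tensors are created lazily, so a
        # large test model can be constructed cheaply before it is sharded.
        model = model_fn()
    return model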
@@ -74,3 +77,22 @@ def check_state_dict(org_model: Module, sharded_model: Module, name: str = ''):
        assert v.shape == shard_v.shape, f'{name} {k} shape mismatch, {v.shape} vs {shard_v.shape}'
        assert v.dtype == shard_v.dtype, f'{name} {k} dtype mismatch, {v.dtype} vs {shard_v.dtype}'
        assert torch.equal(v, shard_v), f'{name} {k} value mismatch'


def check_grad(original_model, sharded_model, layer_suffix, atol=1e-5, rtol=1e-5, dim=0, verbose=False):
    for suffix in layer_suffix:
        org_grad = getattr_(original_model, suffix).weight.grad
        shard_grad = getattr_(sharded_model, suffix).weight.grad
        shard_weight = getattr_(sharded_model, suffix).weight

        if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
            shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(dist.get_world_size())]
            shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad)
            all_shard_grad = torch.cat(shard_grad_list, dim=dim)
        else:
            all_shard_grad = shard_grad
        if verbose and dist.get_rank() == 0:
            print(f"'{suffix}' grad: {org_grad}, {all_shard_grad}")
        assert torch.allclose(
            org_grad, all_shard_grad, rtol=rtol, atol=atol
        ), f"error attribute '{suffix}', orgin model grad is not equal to shard model grad\n{org_grad}\n{all_shard_grad}"
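Note that torch.distributed.all_gather returns None when called synchronously (the default async_op=False), so the reassignment of shard_grad above is cosmetic: the gathered per-rank gradient shards live in shard_grad_list and are concatenated along dim. A hedged usage sketch of the new helper follows; the layer suffixes and the build_model return convention are hypothetical, not taken from this diff:

# Hypothetical call into the new check_grad helper from a shardformer test.
# Column-parallel weights are typically gathered along dim=0 and row-parallel
# weights along dim=1; the suffix strings are placeholders, not real paths.
org_model, sharded_model = build_model(model_fn)        # assumed return convention
col_layer_suffix = ['encoder.layers[0].attn.q_proj']    # hypothetical column-parallel layer
row_layer_suffix = ['encoder.layers[0].attn.out_proj']  # hypothetical row-parallel layer

check_grad(org_model, sharded_model, col_layer_suffix, atol=1e-5, rtol=1e-5, dim=0)
check_grad(org_model, sharded_model, row_layer_suffix, atol=1e-5, rtol=1e-5, dim=1)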