Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-05 19:13:01 +00:00
[shardformer] fix opt test hanging (#4521)
* [shardformer] fix opt test hanging
* fix
* test
* test
* test
* fix test
* fix test
* remove print
* add fix
@@ -237,6 +237,43 @@ def check_weight(org_model: Module,
             f"shard model weight {suffix} is not equal to origin model weight\n{org_weight}\n{sharded_weight}"
 
 
+def get_grad_tensors_for_check(org_model: Module,
+                               sharded_model: Module,
+                               layer_suffix: List[str],
+                               tp_group: ProcessGroup = None,
+                               dim: int = 0,
+                               atol: float = 1e-5,
+                               rtol: float = 1e-3,
+                               verbose: bool = False,
+                               name: str = None):
+
+    grad_to_check = {}
+    for suffix in layer_suffix:
+        org_grad = getattr_(org_model, suffix).weight.grad
+        shard_grad = getattr_(sharded_model, suffix).weight.grad
+        shard_weight = getattr_(sharded_model, suffix).weight
+        if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
+            shard_grad_list = [torch.zeros_like(shard_grad).to('cuda') for _ in range(dist.get_world_size(tp_group))]
+            dist.all_gather(shard_grad_list, shard_grad, tp_group)
+            shard_grad = torch.cat(shard_grad_list, dim=dim)
+
+        # embedding may be resized when using tensor parallel
+        if shard_grad.shape[0] > org_grad.shape[0]:
+            shard_grad = shard_grad[:org_grad.shape[0], :]
+        if verbose and dist.get_rank() == 0:
+            print(f"'{suffix}' grad: {org_grad}, {shard_grad}")
+
+        grad_to_check[suffix] = {
+            "org_grad": org_grad.float(),
+            "shard_grad": shard_grad.float(),
+            "rtol": rtol,
+            "atol": atol
+        }
+
+    return grad_to_check
+
+
+# used by sam/blip2
 def check_grad(org_model: Module,
                sharded_model: Module,
                layer_suffix: List[str],
@@ -275,3 +312,18 @@ def unwrap_model(module: Module,
     if module.__class__.__name__ == base_model_class_name:
         return module
     return getattr(module, base_model_attribute_name, None)
+
+
+def check_all_grad_tensors(check_tensors):
+    """
+    "org_grad": tensor to be compared from the original model
+    "shard_grad": tensor to be compared from the sharded model
+    """
+    for suffix, check_info in check_tensors.items():
+        org_grad = check_info["org_grad"]
+        shard_grad = check_info["shard_grad"]
+        rtol = check_info["rtol"]
+        atol = check_info["atol"]
+        assert torch.allclose(
+            org_grad, shard_grad, atol=atol, rtol=rtol
+        ), f"error attribute '{suffix}', origin model grad is not equal to shard model grad\n{org_grad}\n{shard_grad}"
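For context, the diff splits gradient comparison into two phases: get_grad_tensors_for_check gathers the (possibly tensor-parallel-sharded) gradients, which involves a collective dist.all_gather over tp_group, while check_all_grad_tensors only runs local torch.allclose assertions. Presumably this lets the collective run on every rank before any rank-dependent checks, which is the usual way to avoid the kind of hang named in the commit title. Below is a minimal usage sketch, not part of the commit: the run_grad_check harness, the booster/optimizer objects, and the OPT layer suffix are hypothetical placeholders, and the two helpers are assumed to be in scope from the patched test utility module.

# Usage sketch only (not part of the commit). Assumes torch.distributed is
# already initialized and that get_grad_tensors_for_check / check_all_grad_tensors
# from the patched test utilities are importable; the models, losses, booster,
# optimizers, tp_group and the layer suffix below are hypothetical placeholders.
def run_grad_check(org_model, sharded_model, org_loss, sharded_loss,
                   sharded_optimizer, booster, tp_group):
    # Backward on both models; the booster drives the sharded one.
    org_loss.backward()
    booster.backward(sharded_loss, sharded_optimizer)

    # Phase 1: gather the gradients to compare. For sharded weights this calls
    # dist.all_gather over tp_group, so every tensor-parallel rank must reach it.
    grads_to_check = get_grad_tensors_for_check(
        org_model,
        sharded_model,
        layer_suffix=['model.decoder.embed_tokens'],  # hypothetical OPT suffix
        tp_group=tp_group,
        dim=0,
        atol=1e-5,
        rtol=1e-3)

    # Phase 2: purely local torch.allclose assertions; safe to run (or skip)
    # per rank without risking a hang on a collective.
    check_all_grad_tensors(grads_to_check)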