[shardformer] fix opt test hanging (#4521)

* [shardformer] fix opt test hanging

* fix

* test

* test

* test

* fix test

* fix test

* remove print

* add fix
This commit is contained in:
flybird11111
2023-08-30 14:50:34 +08:00
committed by GitHub
parent e241b74f24
commit d367b88785
13 changed files with 472 additions and 249 deletions

View File

@@ -237,6 +237,43 @@ def check_weight(org_model: Module,
f"shard model weight {suffix} is not equal to origin model weight\n{org_weight}\n{sharded_weight}"
def get_grad_tensors_for_check(org_model: Module,
sharded_model: Module,
layer_suffix: List[str],
tp_group: ProcessGroup = None,
dim: int = 0,
atol: float = 1e-5,
rtol: float = 1e-3,
verbose: bool = False,
name: str = None):
grad_to_check = {}
for suffix in layer_suffix:
org_grad = getattr_(org_model, suffix).weight.grad
shard_grad = getattr_(sharded_model, suffix).weight.grad
shard_weight = getattr_(sharded_model, suffix).weight
if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
shard_grad_list = [torch.zeros_like(shard_grad).to('cuda') for _ in range(dist.get_world_size(tp_group))]
dist.all_gather(shard_grad_list, shard_grad, tp_group)
shard_grad = torch.cat(shard_grad_list, dim=dim)
# embedding may be resized when using tensor parallel
if shard_grad.shape[0] > org_grad.shape[0]:
shard_grad = shard_grad[:org_grad.shape[0], :]
if verbose and dist.get_rank() == 0:
print(f"'{suffix}' grad: {org_grad}, {shard_grad}")
grad_to_check[suffix] = {
"org_grad": org_grad.float(),
"shard_grad": shard_grad.float(),
"rtol": rtol,
"atol": atol
}
return grad_to_check
# used by sam/blip2
def check_grad(org_model: Module,
sharded_model: Module,
layer_suffix: List[str],
@@ -275,3 +312,18 @@ def unwrap_model(module: Module,
if module.__class__.__name__ == base_model_class_name:
return module
return getattr(module, base_model_attribute_name, None)
def check_all_grad_tensors(check_tensors):
"""
"org_grad": tensor to be compared from the original model
"shard_grad": tensor to be compared from the sharded model
"""
for suffix, check_info in check_tensors.items():
org_grad = check_info["org_grad"]
shard_grad = check_info["shard_grad"]
rtol = check_info["rtol"]
atol = check_info["atol"]
assert torch.allclose(
org_grad, shard_grad, atol=atol, rtol=rtol
), f"error attribute '{suffix}', orgin model grad is not equal to shard model grad\n{org_grad}\n{shard_grad}"