[zero] refactor low level zero for shard evenly (#4030)

* refactor low level zero

* fix zero2 and support cpu offload

* avg gradient and modify unit test

* refactor grad store, support layer drop

* refactor bucket store, support grad accumulation

* fix and update unit test of zero and ddp

* compatible with tp, ga and unit test

* fix memory leak and polish

* add zero layer drop unittest

* polish code

* fix import err in unit test

* support different comm dtype, modify docstring style

* polish code

* test padding and fix

* fix unit test of low level zero

* fix pad recording in bucket store (see the even-sharding sketch after this list)

* support some models

* polish
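
For illustration, here is a minimal sketch of the even-sharding idea named in the commit title and the padding bullets above: the flat buffer is padded up to a multiple of the data-parallel world size so every rank receives an equal shard, and the pad length is recorded (as the bucket store now does) so it can be stripped when the buffer is gathered back. `shard_evenly` is a hypothetical helper name for this sketch, not the API this commit adds.

```python
import torch

def shard_evenly(flat_tensor: torch.Tensor, world_size: int):
    """Pad a 1-D buffer to a multiple of world_size, then split it into
    equal shards (one per rank). The pad length is returned so it can be
    recorded and stripped later."""
    pad = (world_size - flat_tensor.numel() % world_size) % world_size
    if pad > 0:
        flat_tensor = torch.nn.functional.pad(flat_tensor, [0, pad])
    return list(flat_tensor.chunk(world_size)), pad

# Example: 10 elements across 4 ranks -> padded to 12, shards of 3 each
shards, pad = shard_evenly(torch.arange(10, dtype=torch.float32), world_size=4)
assert [s.numel() for s in shards] == [3, 3, 3, 3] and pad == 2
```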
Authored by LuGY on 2023-06-30 15:30:50 +08:00; committed by Hongxin Liu
parent 5187c96b7c
commit c6ab96983a
8 changed files with 424 additions and 470 deletions


@@ -253,7 +253,7 @@ def compute_norm(gradients, params, dp_group, mp_group, norm_type=2):
     return total_norm
 
 
-def sync_param(flat_tensor, tensor_list):
+def sync_tensor(flat_tensor, tensor_list):
     """
     Synchronize the flattened tensor and unflattened tensor list. When
     a list of tensor are flattened with `torch._utils._unflatten_dense_tensors`,
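
The rename from `sync_param` to `sync_tensor` reflects that the helper syncs arbitrary tensor lists, not only parameters. As context for the truncated docstring above, a minimal sketch of what such a helper typically does, assuming the standard `torch._utils` flatten/unflatten pair (this mirrors the common pattern, not necessarily this commit's exact body): flattening copies the data, so the flat buffer and the original tensors stop sharing memory, and writes to the flat buffer must be pointed back into the originals.

```python
import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

def sync_tensor(flat_tensor, tensor_list):
    # Re-view the flat buffer with the original shapes, then point each
    # tensor's .data at its slice so both sides share storage again.
    for orig, view in zip(tensor_list, _unflatten_dense_tensors(flat_tensor, tensor_list)):
        orig.data = view.data

# Usage: flatten, mutate the flat buffer, then sync back
tensors = [torch.zeros(2), torch.zeros(3)]
flat = _flatten_dense_tensors(tensors)  # copies: flat and tensors are now detached
flat.add_(1.0)                          # mutating flat does not touch tensors yet
sync_tensor(flat, tensors)
assert all((t == 1.0).all() for t in tensors)
```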