[checkpointio] Unsharded Optimizer Checkpoint for Gemini Plugin (#4141)

* [checkpointio] unsharded optimizer checkpoint for Gemini plugin

* [checkpointio] unsharded optimizer checkpoint for Gemini using all_gather
Author: Baizhou Zhang
Date: 2023-07-07 16:33:06 +08:00
Committed by: GitHub
Parent: fee32a3b78
Commit: 58913441a1

9 changed files with 684 additions and 83 deletions
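The second commit-message bullet describes gathering each rank's optimizer-state shard with all_gather so that a single, unsharded state dict can be saved. As a rough sketch only (not code from this commit; gather_state_shard and the evenly padded 1-D shard layout are assumptions, not Gemini's actual internals):

import torch
import torch.distributed as dist

def gather_state_shard(local_shard: torch.Tensor, full_numel: int,
                       group: dist.ProcessGroup = None) -> torch.Tensor:
    # Hypothetical helper: each rank holds an equally sized 1-D shard of an
    # optimizer state tensor; all_gather collects every shard on every rank.
    world_size = dist.get_world_size(group)
    buffers = [torch.empty_like(local_shard) for _ in range(world_size)]
    dist.all_gather(buffers, local_shard, group=group)
    # Concatenate and drop the tail padding added to make the shards even.
    return torch.cat(buffers)[:full_numel]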


@@ -5,6 +5,7 @@ import torch.distributed as dist
 from torch import Tensor
 from torch.distributed import ProcessGroup
 from torch.testing import assert_close
+from torch.utils._pytree import tree_flatten


 def assert_equal(a: Tensor, b: Tensor):
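For context on the new import (an illustration, not part of the diff): tree_flatten turns an arbitrarily nested container into a flat list of leaf values plus a structure spec, which is what the new check_state_dict_equal_pytree below relies on.

import torch
from torch.utils._pytree import tree_flatten

state = {"state": {0: {"step": torch.tensor(1)}}, "param_groups": [{"lr": 0.1}]}
leaves, spec = tree_flatten(state)
# leaves is a flat list such as [0.1, tensor(1)]; spec records the nesting,
# so two state dicts can be compared leaf by leaf with a single zip().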
@@ -16,7 +17,12 @@ def assert_not_equal(a: Tensor, b: Tensor):


 def assert_close_loose(a: Tensor, b: Tensor, rtol: float = 1e-3, atol: float = 1e-3):
-    assert_close(a, b, rtol=rtol, atol=atol)
+    assert_close(a,
+                 b,
+                 rtol=rtol,
+                 atol=atol,
+                 msg=f"Tensor not close, shape: {a.shape} vs {b.shape}, \
+                        dtype: {a.dtype} vs {b.dtype}")


 def assert_equal_in_group(tensor: Tensor, process_group: ProcessGroup = None):
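The reworked assertion keeps the loose tolerances but now reports shape and dtype when the tensors differ. A quick usage illustration (not from the commit):

import torch

a = torch.zeros(2, 2)
b = torch.ones(2, 2)
assert_close_loose(a, b)
# AssertionError: Tensor not close, shape: torch.Size([2, 2]) vs
# torch.Size([2, 2]), dtype: torch.float32 vs torch.float32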
@@ -33,25 +39,51 @@ def assert_equal_in_group(tensor: Tensor, process_group: ProcessGroup = None):


 def check_state_dict_equal(d1: OrderedDict, d2: OrderedDict, ignore_device: bool = True):
-    for k, v in d1.items():
-        if isinstance(v, dict):
-            check_state_dict_equal(v, d2[k])
-        elif isinstance(v, list):
-            for i in range(len(v)):
-                if isinstance(v[i], torch.Tensor):
-                    if not ignore_device:
-                        v[i] = v[i].to("cpu")
-                        d2[k][i] = d2[k][i].to("cpu")
-                    assert torch.equal(v[i], d2[k][i])
-                else:
-                    assert v[i] == d2[k][i]
-        elif isinstance(v, torch.Tensor):
-            if not ignore_device:
-                v = v.to("cpu")
-                d2[k] = d2[k].to("cpu")
-            assert torch.equal(v, d2[k])
-        else:
-            assert v == d2[k]
+    assert len(list(d1.keys())) == len(list(d2.keys())), \
+        f"Number of keys unequal: {len(list(d1.keys()))} vs {len(list(d2.keys()))}"
+    for k, v1 in d1.items():
+        assert k in d2
+        v2 = d2[k]
+        if isinstance(v1, dict):
+            assert isinstance(v2, dict)
+            check_state_dict_equal(v1, v2, ignore_device)
+        elif isinstance(v1, list):
+            assert isinstance(v2, list)
+            for v1_i, v2_i in zip(v1, v2):
+                if isinstance(v1_i, torch.Tensor):
+                    assert isinstance(v2_i, torch.Tensor)
+                    if not ignore_device:
+                        v1_i = v1_i.to("cpu")
+                        v2_i = v2_i.to("cpu")
+                    assert_close_loose(v1_i, v2_i)
+                elif isinstance(v1_i, dict):
+                    assert isinstance(v2_i, dict)
+                    check_state_dict_equal(v1_i, v2_i, ignore_device)
+                else:
+                    assert v1_i == v2_i, f"{v1_i} not equals to {v2_i}"
+        elif isinstance(v1, torch.Tensor):
+            assert isinstance(v2, torch.Tensor)
+            if not ignore_device:
+                v1 = v1.to("cpu")
+                v2 = v2.to("cpu")
+            assert_close_loose(v1, v2)
+        else:
+            assert v1 == v2, f"{v1} not equals to {v2}"
+
+
+def check_state_dict_equal_pytree(d1: OrderedDict, d2: OrderedDict, ignore_device: bool = True):
+    flat_d1, _ = tree_flatten(d1)
+    flat_d2, _ = tree_flatten(d2)
+    assert len(flat_d1) == len(flat_d2)
+    for v1, v2 in zip(flat_d1, flat_d2):
+        if isinstance(v1, torch.Tensor):
+            assert isinstance(v2, torch.Tensor)
+            if not ignore_device:
+                v1 = v1.to("cpu")
+                v2 = v2.to("cpu")
+            assert_close_loose(v1, v2)
+        else:
+            assert v1 == v2, f"{v1} not equals to {v2}"


 def assert_hf_output_close(out1: Any,
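To see how the new helpers fit checkpoint round-trip tests (a sketch under assumed setup, not test code from this commit; "ckpt.pt" and the Adam setup are illustrative), compare an optimizer's state dict against a reloaded copy:

import torch

# Assumes check_state_dict_equal / check_state_dict_equal_pytree from the
# diff above are in scope.
model = torch.nn.Linear(4, 4)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
model(torch.randn(2, 4)).sum().backward()
optimizer.step()

torch.save(optimizer.state_dict(), "ckpt.pt")
loaded = torch.load("ckpt.pt")

check_state_dict_equal(optimizer.state_dict(), loaded)           # recursive walk
check_state_dict_equal_pytree(optimizer.state_dict(), loaded)    # flat leaf-by-leaf compare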