mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-07 20:10:17 +00:00
Next commit [checkpointio] Unsharded Optimizer Checkpoint for Gemini Plugin (#4141)
* [checkpointio] unsharded optimizer checkpoint for Gemini plugin * [checkpointio] unsharded optimizer checkpoint for Gemini using all_gather
This commit is contained in:
@@ -5,6 +5,7 @@ import torch.distributed as dist
|
||||
from torch import Tensor
|
||||
from torch.distributed import ProcessGroup
|
||||
from torch.testing import assert_close
|
||||
from torch.utils._pytree import tree_flatten
|
||||
|
||||
|
||||
def assert_equal(a: Tensor, b: Tensor):
|
||||
@@ -16,7 +17,12 @@ def assert_not_equal(a: Tensor, b: Tensor):
|
||||
|
||||
|
||||
def assert_close_loose(a: Tensor, b: Tensor, rtol: float = 1e-3, atol: float = 1e-3):
|
||||
assert_close(a, b, rtol=rtol, atol=atol)
|
||||
assert_close(a,
|
||||
b,
|
||||
rtol=rtol,
|
||||
atol=atol,
|
||||
msg=f"Tensor not close, shape: {a.shape} vs {b.shape}, \
|
||||
dtype: {a.dtype} vs {b.dtype}")
|
||||
|
||||
|
||||
def assert_equal_in_group(tensor: Tensor, process_group: ProcessGroup = None):
|
||||
@@ -33,25 +39,51 @@ def assert_equal_in_group(tensor: Tensor, process_group: ProcessGroup = None):
|
||||
|
||||
|
||||
def check_state_dict_equal(d1: OrderedDict, d2: OrderedDict, ignore_device: bool = True):
|
||||
for k, v in d1.items():
|
||||
if isinstance(v, dict):
|
||||
check_state_dict_equal(v, d2[k])
|
||||
elif isinstance(v, list):
|
||||
for i in range(len(v)):
|
||||
if isinstance(v[i], torch.Tensor):
|
||||
assert len(list(d1.keys())) == len(list(d2.keys())), \
|
||||
f"Number of keys unequal: {len(list(d1.keys()))} vs {len(list(d2.keys()))}"
|
||||
for k, v1 in d1.items():
|
||||
assert k in d2
|
||||
v2 = d2[k]
|
||||
if isinstance(v1, dict):
|
||||
assert isinstance(v2, dict)
|
||||
check_state_dict_equal(v1, v2, ignore_device)
|
||||
elif isinstance(v1, list):
|
||||
assert isinstance(v2, list)
|
||||
for v1_i, v2_i in zip(v1, v2):
|
||||
if isinstance(v1_i, torch.Tensor):
|
||||
assert isinstance(v2_i, torch.Tensor)
|
||||
if not ignore_device:
|
||||
v[i] = v[i].to("cpu")
|
||||
d2[k][i] = d2[k][i].to("cpu")
|
||||
assert torch.equal(v[i], d2[k][i])
|
||||
v1_i = v1_i.to("cpu")
|
||||
v2_i = v2_i.to("cpu")
|
||||
assert_close_loose(v1_i, v2_i)
|
||||
elif isinstance(v1_i, dict):
|
||||
assert isinstance(v2_i, dict)
|
||||
check_state_dict_equal(v1_i, v2_i, ignore_device)
|
||||
else:
|
||||
assert v[i] == d2[k][i]
|
||||
elif isinstance(v, torch.Tensor):
|
||||
assert v1_i == v2_i, f"{v1_i} not equals to {v2_i}"
|
||||
elif isinstance(v1, torch.Tensor):
|
||||
assert isinstance(v2, torch.Tensor)
|
||||
if not ignore_device:
|
||||
v = v.to("cpu")
|
||||
d2[k] = d2[k].to("cpu")
|
||||
assert torch.equal(v, d2[k])
|
||||
v1 = v1.to("cpu")
|
||||
v2 = v2.to("cpu")
|
||||
assert_close_loose(v1, v2)
|
||||
else:
|
||||
assert v == d2[k]
|
||||
assert v1 == v2, f"{v1} not equals to {v2}"
|
||||
|
||||
|
||||
def check_state_dict_equal_pytree(d1: OrderedDict, d2: OrderedDict, ignore_device: bool = True):
|
||||
flat_d1, _ = tree_flatten(d1)
|
||||
flat_d2, _ = tree_flatten(d2)
|
||||
assert len(flat_d1) == len(flat_d2)
|
||||
for v1, v2 in zip(flat_d1, flat_d2):
|
||||
if isinstance(v1, torch.Tensor):
|
||||
assert isinstance(v2, torch.Tensor)
|
||||
if not ignore_device:
|
||||
v1 = v1.to("cpu")
|
||||
v2 = v2.to("cpu")
|
||||
assert_close_loose(v1, v2)
|
||||
else:
|
||||
assert v1 == v2, f"{v1} not equals to {v2}"
|
||||
|
||||
|
||||
def assert_hf_output_close(out1: Any,
|
||||
|
Reference in New Issue
Block a user