mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-06 11:32:10 +00:00
[gemini] support amp o3 for gemini (#4872)
* [gemini] support no reuse fp16 chunk
* [gemini] support no master weight for optim
* [gemini] support no master weight for gemini ddp
* [test] update gemini tests
* [test] update gemini tests
* [plugin] update gemini plugin
* [test] fix gemini checkpointio test
* [test] fix gemini checkpoint io
This commit is contained in:
@@ -40,7 +40,7 @@ def assert_equal_in_group(tensor: Tensor, process_group: ProcessGroup = None):
         assert torch.all(a == b), f"expected tensors on rank {i} and {i + 1} to be equal but they are not, {a} vs {b}"


-def check_state_dict_equal(d1: OrderedDict, d2: OrderedDict, ignore_device: bool = True):
+def check_state_dict_equal(d1: OrderedDict, d2: OrderedDict, ignore_device: bool = True, ignore_dtype: bool = False):
     assert len(list(d1.keys())) == len(
         list(d2.keys())
     ), f"Number of keys unequal: {len(list(d1.keys()))} vs {len(list(d2.keys()))}"
@@ -58,6 +58,8 @@ def check_state_dict_equal(d1: OrderedDict, d2: OrderedDict, ignore_device: bool
                 if not ignore_device:
                     v1_i = v1_i.to("cpu")
                     v2_i = v2_i.to("cpu")
+                if ignore_dtype:
+                    v1_i = v1_i.to(v2_i.dtype)
                 assert_close_loose(v1_i, v2_i)
             elif isinstance(v1_i, dict):
                 assert isinstance(v2_i, dict)
@@ -69,6 +71,8 @@ def check_state_dict_equal(d1: OrderedDict, d2: OrderedDict, ignore_device: bool
             if not ignore_device:
                 v1 = v1.to("cpu")
                 v2 = v2.to("cpu")
+            if ignore_dtype:
+                v1 = v1.to(v2.dtype)
             assert_close_loose(v1, v2)
         else:
             assert v1 == v2, f"{v1} not equals to {v2}"
|
Reference in New Issue
Block a user