[gemini] support amp o3 for gemini (#4872)

* [gemini] support no reuse fp16 chunk

* [gemini] support no master weight for optim

* [gemini] support no master weight for gemini ddp

* [test] update gemini tests

* [test] update gemini tests

* [plugin] update gemini plugin

* [test] fix gemini checkpointio test

* [test] fix gemini checkpoint io
This commit is contained in:
Hongxin Liu
2023-10-12 10:39:08 +08:00
committed by GitHub
parent c1fab951e7
commit df63564184
15 changed files with 222 additions and 114 deletions

View File

@@ -40,7 +40,7 @@ def assert_equal_in_group(tensor: Tensor, process_group: ProcessGroup = None):
assert torch.all(a == b), f"expected tensors on rank {i} and {i + 1} to be equal but they are not, {a} vs {b}"
def check_state_dict_equal(d1: OrderedDict, d2: OrderedDict, ignore_device: bool = True):
def check_state_dict_equal(d1: OrderedDict, d2: OrderedDict, ignore_device: bool = True, ignore_dtype: bool = False):
assert len(list(d1.keys())) == len(
list(d2.keys())
), f"Number of keys unequal: {len(list(d1.keys()))} vs {len(list(d2.keys()))}"
@@ -58,6 +58,8 @@ def check_state_dict_equal(d1: OrderedDict, d2: OrderedDict, ignore_device: bool
if not ignore_device:
v1_i = v1_i.to("cpu")
v2_i = v2_i.to("cpu")
if ignore_dtype:
v1_i = v1_i.to(v2_i.dtype)
assert_close_loose(v1_i, v2_i)
elif isinstance(v1_i, dict):
assert isinstance(v2_i, dict)
@@ -69,6 +71,8 @@ def check_state_dict_equal(d1: OrderedDict, d2: OrderedDict, ignore_device: bool
if not ignore_device:
v1 = v1.to("cpu")
v2 = v2.to("cpu")
if ignore_dtype:
v1 = v1.to(v2.dtype)
assert_close_loose(v1, v2)
else:
assert v1 == v2, f"{v1} not equals to {v2}"