[booster] add tests for ddp and low level zero's checkpointio (#3715)

* [booster] update tests for booster

* [booster] update tests for booster

* [booster] update tests for booster

* [booster] update tests for booster

* [booster] update tests for booster

* [booster] update booster tutorials (#3717), fix recursive check
This commit is contained in:
jiangmingyan
2023-05-10 12:17:02 +08:00
committed by GitHub
parent 6552cbf8e1
commit 20068ba188
6 changed files with 261 additions and 125 deletions

View File

@@ -1,3 +1,5 @@
from typing import OrderedDict
import torch
import torch.distributed as dist
from torch import Tensor
@@ -28,3 +30,25 @@ def assert_equal_in_group(tensor: Tensor, process_group: ProcessGroup = None):
a = tensor_list[i]
b = tensor_list[i + 1]
assert torch.all(a == b), f'expected tensors on rank {i} and {i + 1} to be equal but they are not, {a} vs {b}'
def check_state_dict_equal(d1: OrderedDict, d2: OrderedDict, ignore_device: bool = True):
    """Assert that two (possibly nested) state dicts hold equal contents.

    Both mappings must have exactly the same keys. Values are compared
    recursively: nested dicts key by key, lists/tuples element by element,
    tensors with ``torch.equal`` and everything else with ``==``.

    Neither ``d1`` nor ``d2`` is modified.

    Args:
        d1: Reference state dict.
        d2: State dict expected to match ``d1``.
        ignore_device: When ``True`` (default) tensors are compared on CPU so
            that the device they live on does not matter. When ``False`` the
            two tensors must also be on the same device.

    Raises:
        AssertionError: On the first mismatch in keys, container length,
            value type, tensor device (only if ``ignore_device`` is
            ``False``) or value.
    """
    missing = [k for k in d1 if k not in d2]
    unexpected = [k for k in d2 if k not in d1]
    assert not missing and not unexpected, \
        f'state dict keys differ: missing from d2 {missing}, unexpected in d2 {unexpected}'

    for k, v in d1.items():
        _assert_state_value_equal(v, d2[k], ignore_device, k)


def _assert_state_value_equal(v1, v2, ignore_device: bool, key) -> None:
    """Recursively assert that one state-dict value equals its counterpart."""
    if isinstance(v1, dict):
        assert isinstance(v2, dict), f'expected a dict for key {key!r} but got {type(v2).__name__}'
        # Propagate the flag so nested tensors follow the same device policy.
        check_state_dict_equal(v1, v2, ignore_device)
    elif isinstance(v1, (list, tuple)):
        assert isinstance(v2, (list, tuple)), \
            f'expected a list/tuple for key {key!r} but got {type(v2).__name__}'
        assert len(v1) == len(v2), \
            f'sequence length mismatch for key {key!r}: {len(v1)} vs {len(v2)}'
        for a, b in zip(v1, v2):
            _assert_state_value_equal(a, b, ignore_device, key)
    elif isinstance(v1, torch.Tensor):
        assert isinstance(v2, torch.Tensor), \
            f'expected a tensor for key {key!r} but got {type(v2).__name__}'
        if ignore_device:
            # torch.equal requires both operands on one device; compare CPU
            # copies held in locals so the caller's dicts are left untouched.
            v1, v2 = v1.to("cpu"), v2.to("cpu")
        else:
            assert v1.device == v2.device, \
                f'device mismatch for key {key!r}: {v1.device} vs {v2.device}'
        assert torch.equal(v1, v2), f'tensor mismatch for key {key!r}'
    else:
        assert v1 == v2, f'value mismatch for key {key!r}: {v1} vs {v2}'