Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-06 11:32:10 +00:00)
[booster] add tests for ddp and low level zero's checkpointio (#3715)
* [booster] update tests for booster
* [booster] update tests for booster
* [booster] update tests for booster
* [booster] update tests for booster
* [booster] update tests for booster
* [booster] update booster tutorials#3717, fix recursive check
@@ -1,3 +1,5 @@
+from typing import OrderedDict
+
 import torch
 import torch.distributed as dist
 from torch import Tensor
@@ -28,3 +30,25 @@ def assert_equal_in_group(tensor: Tensor, process_group: ProcessGroup = None):
         a = tensor_list[i]
         b = tensor_list[i + 1]
         assert torch.all(a == b), f'expected tensors on rank {i} and {i + 1} to be equal but they are not, {a} vs {b}'
+
+
+def check_state_dict_equal(d1: OrderedDict, d2: OrderedDict, ignore_device: bool = True):
+    for k, v in d1.items():
+        if isinstance(v, dict):
+            check_state_dict_equal(v, d2[k])
+        elif isinstance(v, list):
+            for i in range(len(v)):
+                if isinstance(v[i], torch.Tensor):
+                    if not ignore_device:
+                        v[i] = v[i].to("cpu")
+                        d2[k][i] = d2[k][i].to("cpu")
+                    assert torch.equal(v[i], d2[k][i])
+                else:
+                    assert v[i] == d2[k][i]
+        elif isinstance(v, torch.Tensor):
+            if not ignore_device:
+                v = v.to("cpu")
+                d2[k] = d2[k].to("cpu")
+            assert torch.equal(v, d2[k])
+        else:
+            assert v == d2[k]
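
For context, check_state_dict_equal recursively walks two state dicts and asserts element-wise equality, descending into nested dicts and lists and using torch.equal for tensors; passing ignore_device=False moves both tensors to CPU before comparing. Below is a minimal sketch of how such a helper could be exercised in an in-memory checkpoint round-trip check; the toy model, the BytesIO buffer, and the import path are illustrative assumptions, not part of this commit.

import io

import torch
import torch.nn as nn

# Import path is an assumption; the helper is the one added in the diff above.
from colossalai.testing import check_state_dict_equal


def run_roundtrip_check():
    # A toy model whose state dict we serialize and reload in memory.
    model = nn.Linear(4, 2)
    buffer = io.BytesIO()
    torch.save(model.state_dict(), buffer)
    buffer.seek(0)
    reloaded = torch.load(buffer)
    # Walk both state dicts and assert tensor-wise equality;
    # ignore_device=False moves tensors to CPU before comparing.
    check_state_dict_equal(model.state_dict(), reloaded, ignore_device=False)


if __name__ == "__main__":
    run_roundtrip_check()
    print("state dicts match after save/load round trip")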