[zero] zero optim state_dict takes only_rank_0 (#1384)

* zero optim state_dict takes only_rank_0

* fix unit test
This commit is contained in:
ver217
2022-07-29 13:22:50 +08:00
committed by GitHub
parent 7d5d628e07
commit 8dced41ad0
2 changed files with 18 additions and 13 deletions

View File

@@ -45,7 +45,8 @@ def check_state_dict(state_dict, torch_state_dict):
@parameterize('use_chunk', [False, True])
@parameterize('use_zero', [False, True])
@parameterize('placement_policy', ['cuda', 'cpu', 'auto'])
def run_zero_optim_state_dict(use_chunk, use_zero, placement_policy):
@parameterize('only_rank_0', [False, True])
def run_zero_optim_state_dict(use_chunk, use_zero, placement_policy, only_rank_0):
get_components_func = non_distributed_component_funcs.get_callable('gpt2')
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
@@ -76,8 +77,8 @@ def run_zero_optim_state_dict(use_chunk, use_zero, placement_policy):
optim.load_state_dict(torch_state_dict)
check_load_state_dict(optim, torch_optim)
state_dict = optim.state_dict()
if pg.rank() == 0:
state_dict = optim.state_dict(only_rank_0)
if not only_rank_0 or pg.rank() == 0:
check_state_dict(state_dict, torch_state_dict)