mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-03 01:55:12 +00:00
[zero] zero optim state_dict takes only_rank_0 (#1384)

* zero optim state_dict takes only_rank_0
* fix unit test
This commit is contained in:
@@ -45,7 +45,8 @@ def check_state_dict(state_dict, torch_state_dict):
 @parameterize('use_chunk', [False, True])
 @parameterize('use_zero', [False, True])
 @parameterize('placement_policy', ['cuda', 'cpu', 'auto'])
-def run_zero_optim_state_dict(use_chunk, use_zero, placement_policy):
+@parameterize('only_rank_0', [False, True])
+def run_zero_optim_state_dict(use_chunk, use_zero, placement_policy, only_rank_0):
     get_components_func = non_distributed_component_funcs.get_callable('gpt2')
     model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
@@ -76,8 +77,8 @@ def run_zero_optim_state_dict(use_chunk, use_zero, placement_policy):
     optim.load_state_dict(torch_state_dict)
     check_load_state_dict(optim, torch_optim)

-    state_dict = optim.state_dict()
-    if pg.rank() == 0:
+    state_dict = optim.state_dict(only_rank_0)
+    if not only_rank_0 or pg.rank() == 0:
         check_state_dict(state_dict, torch_state_dict)
Reference in New Issue
Block a user