Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-08-11 21:01:54 +00:00)
[checkpoint] save sharded optimizer states (#1237)

parent 4a76084dc9
commit 20da6e48c8
@@ -93,20 +93,17 @@ class ProcessGroup:
                 if idx // self._tp_degree == self._rank_idx // self._tp_degree:
                     self._tp_rank_list.append(rank_id)

-        self._tp_process_group = PYTORCHPGDICT_.get(self._tp_rank_list, 'nccl')
-        self._dp_process_group = PYTORCHPGDICT_.get(self._dp_rank_list, 'nccl')
-
         self._has_cpu_groups = False
-        self._cpu_dp_process_group = None
-        self._cpu_tp_process_group = None
+        PYTORCHPGDICT_.get(self._tp_rank_list, 'nccl')
+        PYTORCHPGDICT_.get(self._dp_rank_list, 'nccl')

     def set_cpu_groups(self):
         if self.has_cpu_groups:
             return
         self.logger.info(
             f'{self._rank} Gloo initialize TP group on {self._tp_rank_list}, DP group on {self._dp_rank_list}')
-        self._cpu_tp_process_group = PYTORCHPGDICT_.get(self._tp_rank_list, 'gloo')
-        self._cpu_dp_process_group = PYTORCHPGDICT_.get(self._dp_rank_list, 'gloo')
+        PYTORCHPGDICT_.get(self._tp_rank_list, 'gloo')
+        PYTORCHPGDICT_.get(self._dp_rank_list, 'gloo')

     @property
     def has_cpu_groups(self):
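The hunk above drops the cached _tp_process_group / _dp_process_group attributes and instead registers the NCCL and Gloo groups in the global PYTORCHPGDICT_ cache, so every later lookup resolves the same underlying torch.distributed group by rank list and backend. Below is a minimal sketch of such a cache, assuming it is keyed by the rank list and the backend name; this is an illustration of the pattern, not the real PYTORCHPGDICT_ implementation.

# Illustrative sketch only: PYTORCHPGDICT_ is assumed to behave roughly like this
# cache, keyed by (rank list, backend). Not the actual ColossalAI implementation.
from typing import Dict, List, Tuple

import torch.distributed as dist


class DistGroupCache:

    def __init__(self):
        # (sorted rank tuple, backend) -> torch.distributed process group
        self._groups: Dict[Tuple[Tuple[int, ...], str], dist.ProcessGroup] = {}

    def get(self, rank_list: List[int], backend: str) -> dist.ProcessGroup:
        key = (tuple(sorted(rank_list)), backend)
        if key not in self._groups:
            # dist.new_group is a collective call that must run on every rank of the
            # default group, which is why the diff registers the groups eagerly in __init__.
            self._groups[key] = dist.new_group(ranks=rank_list, backend=backend)
        return self._groups[key]

Registering the groups eagerly in __init__ keeps that collective new_group call on every rank, while the accessors below become cheap dictionary lookups.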
@@ -152,13 +149,15 @@
         return len(self._tp_rank_list)

     def dp_process_group(self):
-        return self._dp_process_group
+        # return self._dp_process_group
+        return PYTORCHPGDICT_.get(self._dp_rank_list, 'nccl')

     def tp_process_group(self):
-        return self._tp_process_group
+        # return self._tp_process_group
+        return PYTORCHPGDICT_.get(self._tp_rank_list, 'nccl')

     def cpu_dp_process_group(self):
-        return self._cpu_dp_process_group
+        return PYTORCHPGDICT_.get(self._dp_rank_list, 'gloo')

     def cpu_tp_process_group(self):
-        return self._cpu_tp_process_group
+        return PYTORCHPGDICT_.get(self._tp_rank_list, 'gloo')
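With these accessors, callers always fetch the current group from the cache instead of from a possibly unset attribute. A hedged usage sketch follows; the gradient all-reduce is illustrative and not taken from the commit, and dp_world_size() is assumed to return the data-parallel degree, mirroring the len(self._tp_rank_list) pattern shown above.

# Illustrative use of the accessors above; assumes a ProcessGroup instance pg built
# as in the diff and an already-initialized torch.distributed runtime.
import torch
import torch.distributed as dist


def allreduce_over_dp(pg, grad: torch.Tensor) -> torch.Tensor:
    # Sum the gradient across the data-parallel replicas, then average it.
    dist.all_reduce(grad, op=dist.ReduceOp.SUM, group=pg.dp_process_group())
    grad /= pg.dp_world_size()  # assumed helper returning the DP group size
    return grad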
@@ -32,10 +32,15 @@ def save_checkpoint(dire: str,
         optimizer (torch.optim.Optimizer, optional): optimizers. Defaults to None.
         lr_scheduler (torch.optim.lr_scheduler._LRScheduler, optional): lr schedule. Defaults to None.
     """
-    model_state = {'epoch': epoch, 'model': colo_state_dict(model, state_dict_func=nn.Module.state_dict)}
+    model_state = {'epoch': epoch, 'model': model.state_dict()}
     if dist.get_rank() == 0:
         torch.save(model_state, dire + '/epoch_{}_model.pth'.format(epoch))

+    # TODO() If use tensor parallelism, optim_states contain SHARD ColoTensors.
+    # 1. convert SHARD ColoTensor to REPLICATE
+    # only rank 0 saves the REPLICATE tensors.
     optim_state = {'epoch': epoch, 'optimizer': optimizer.state_dict(), 'lr_scheduler': lr_scheduler.state_dict()}

     torch.save(optim_state, dire + '/epoch_{}_optim_rank_{}.pth'.format(epoch, dist.get_rank()))
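Because every rank now writes its own epoch_{}_optim_rank_{}.pth file, the loading side has to read the file that matches the caller's rank. Below is a minimal loading counterpart that mirrors the naming scheme above; load_checkpoint_sketch is a hypothetical helper written for illustration, not ColossalAI's actual loader.

# Sketch of the per-rank loading side, mirroring the file names used in the diff.
import torch
import torch.distributed as dist


def load_checkpoint_sketch(dire: str, epoch: int, model, optimizer=None, lr_scheduler=None):
    # The model state is rank-agnostic: rank 0 wrote it, every rank can read it.
    model_state = torch.load(dire + '/epoch_{}_model.pth'.format(epoch))
    model.load_state_dict(model_state['model'])

    # The optimizer state is sharded per rank, so each rank reads its own file.
    if optimizer is not None:
        optim_state = torch.load(dire + '/epoch_{}_optim_rank_{}.pth'.format(epoch, dist.get_rank()))
        optimizer.load_state_dict(optim_state['optimizer'])
        if lr_scheduler is not None:
            lr_scheduler.load_state_dict(optim_state['lr_scheduler'])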
@@ -126,6 +126,9 @@ def run_checkpoint(init_spec_func, use_ddp, test_epoch, test_scheduler, pg):
         model_reload = ColoDDP(model_reload, pg)
         model_ref = ColoDDP(model_ref, pg)

+    init_spec_func(model, pg)
+    init_spec_func(model_ref, pg)
+
     criterion = torch.nn.CrossEntropyLoss()
     optimizer = torch.optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
     optimizer_reload = torch.optim.Adam(model_reload.parameters(),
@@ -135,23 +138,21 @@
                                          weight_decay=0)
     optimizer_ref = torch.optim.Adam(model_ref.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)

+    lr_scheduler = None
     if test_scheduler == 'colossalai_cosine_warmup':
         lr_scheduler = CosineAnnealingWarmupLR(optimizer=optimizer, total_steps=num_epoch, warmup_steps=warmup_epoch)
         lr_scheduler_reload = CosineAnnealingWarmupLR(optimizer=optimizer_reload,
                                                       total_steps=num_epoch,
                                                       warmup_steps=warmup_epoch)

     elif test_scheduler == 'torch_cosine':
         lr_scheduler = CosineAnnealingLR(optimizer=optimizer, T_max=num_epoch)
         lr_scheduler_reload = CosineAnnealingLR(optimizer=optimizer_reload, T_max=num_epoch)

     elif test_scheduler == 'torch_lambda':
         lr_lambda = lambda epoch: 0.95
         lr_scheduler = MultiplicativeLR(optimizer=optimizer, lr_lambda=lr_lambda)
         lr_scheduler_reload = MultiplicativeLR(optimizer=optimizer_reload, lr_lambda=lr_lambda)
-
-    init_spec_func(model, pg)
-    init_spec_func(model_ref, pg)
+    else:
+        raise TypeError(f"{test_scheduler} is invalid")

     for epoch in range(0, num_epoch):
         if epoch <= test_epoch:
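A round-trip test over these schedulers typically saves after test_epoch, reloads into the *_reload objects, and compares their states. A small illustrative check follows; it is not the exact set of assertions used in the test file.

# Illustrative round-trip check comparing reloaded optimizer/scheduler state with the original.
def check_roundtrip(optimizer, optimizer_reload, lr_scheduler, lr_scheduler_reload):
    # Hyperparameters such as lr, betas, eps, and weight_decay must survive the reload.
    for group, group_reload in zip(optimizer.param_groups, optimizer_reload.param_groups):
        for key in ('lr', 'betas', 'eps', 'weight_decay'):
            assert group[key] == group_reload[key]
    # The scheduler should resume from the same step and learning rate.
    assert lr_scheduler.state_dict()['last_epoch'] == lr_scheduler_reload.state_dict()['last_epoch']
    assert lr_scheduler.get_last_lr() == lr_scheduler_reload.get_last_lr()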
@@ -212,7 +213,11 @@ def run_dist(rank, world_size, port, use_ddp, test_epoch, test_scheduler):
     config = dict(parallel=dict(tensor=dict(mode="1d", size=tp_world_size),))
     colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     pg = ProcessGroup(tp_degree=world_size)
-    run_checkpoint(init_1d_row_for_linear_weight_spec, use_ddp, test_epoch, test_scheduler, pg)
+    run_checkpoint(init_1d_row_for_linear_weight_spec,
+                   use_ddp,
+                   test_epoch=test_epoch,
+                   test_scheduler=test_scheduler,
+                   pg=pg)


@pytest.mark.skip
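run_dist is the per-process entry point; the test body that drives it is not shown in this diff, but it is usually spawned once per rank. A hedged sketch of that launcher, assuming the common torch.multiprocessing pattern and that free_port comes from colossalai.utils as in other ColossalAI tests:

# Sketch of how run_dist is typically driven, one process per rank.
from functools import partial

import torch.multiprocessing as mp
from colossalai.utils import free_port  # assumed import, as used elsewhere in the test suite


def launch_checkpoint_test(world_size, use_ddp, test_epoch, test_scheduler):
    # run_dist is the function shown in the hunk above; mp.spawn passes the process
    # index as the first positional argument, which becomes its rank.
    run_func = partial(run_dist,
                       world_size=world_size,
                       port=free_port(),
                       use_ddp=use_ddp,
                       test_epoch=test_epoch,
                       test_scheduler=test_scheduler)
    mp.spawn(run_func, nprocs=world_size)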
@@ -236,4 +241,4 @@ def test_checkpoint(world_size, use_ddp, test_epoch, test_scheduler):


 if __name__ == '__main__':
-    test_checkpoint(4, True, 1, 1)
+    test_checkpoint(4, True, 1, "colossalai_cosine_warmup")