Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-16 06:30:41 +00:00)
[zero]fix zero ckptIO with offload (#4529)
* fix zero ckptio with offload
* fix load device
* saved tensors in ckpt should be on CPU
* fix unit test
* fix unit test
* add clear cache
* save memory for CI
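The heart of the fix: with cpu_offload enabled, the sharded optimizer states live on the CPU, but NCCL collectives such as dist.all_gather require CUDA tensors. The hunks below therefore stage each shard on the GPU for the gather and move the gathered, unpadded result back to the CPU before it goes into the checkpoint. A minimal sketch of that pattern, with illustrative names (gather_state_to_cpu, shard, working_param) that are not the exact identifiers used in the class:

    import torch
    import torch.distributed as dist

    def gather_state_to_cpu(shard, working_param, group=None):
        """Gather one CPU-offloaded optimizer-state shard from every rank and
        return the full state on CPU, shaped like the working parameter.
        Illustrative sketch only; it mirrors the pattern in the hunks below."""
        world_size = dist.get_world_size(group)
        # NCCL collectives need CUDA tensors, so stage the shard on the GPU.
        shard_cuda = shard.cuda()
        buckets = [torch.zeros_like(shard_cuda) for _ in range(world_size)]
        dist.all_gather(buckets, shard_cuda, group=group)
        # Flatten, drop the padding added for even sharding, reshape to the
        # working parameter, and move back to CPU so the checkpoint holds no
        # GPU memory.
        flat = torch.stack(buckets).view(-1)[:working_param.numel()]
        return flat.reshape_as(working_param).cpu()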
@@ -80,9 +80,6 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
                  tp_process_group: Optional[ProcessGroup] = None,    # if using tp
                  forced_dtype: Optional[torch.dtype] = None):
 
-        # TODO:
-        # 1. state_dict for checkpoint IO
-
         super(LowLevelZeroOptimizer, self).__init__(optim=optimizer)
         self._dtype = self.optim.param_groups[0]['params'][0].dtype
         self._logger = get_dist_logger()
@@ -528,9 +525,12 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
             for k, v in state.items():
                 if isinstance(v, torch.Tensor) and k != 'step':
                     working_param = self._param_store.master_to_working_param[id(param)]
-                    gather_tensor = [torch.zeros_like(v) for _ in range(self._world_size)]
-                    dist.all_gather(gather_tensor, v, group=self.dp_pg)
-                    param_state = torch.stack(gather_tensor).view(-1)[:working_param.numel()].reshape_as(working_param)
+                    gather_tensor = [
+                        torch.zeros(v.shape, device='cuda', dtype=v.dtype) for _ in range(self._world_size)
+                    ]
+                    dist.all_gather(gather_tensor, v.cuda(), group=self.dp_pg)
+                    param_state = torch.stack(gather_tensor).view(-1)[:working_param.numel()].reshape_as(
+                        working_param).cpu()
                     zero_state[param][k] = param_state
 
         states_dict = self._pack_state(zero_state)
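In state_dict(), the rewritten gather builds the buckets directly on the GPU, gathers v.cuda(), and only then trims the padding, reshapes, and calls .cpu(), so every tensor placed into the returned dict is a CPU tensor regardless of where the shard lived. A hedged usage sketch, assuming an already-constructed LowLevelZeroOptimizer instance named optimizer (rank-0 saving is an assumption here; in practice the booster/plugin checkpoint IO drives this):

    import torch
    import torch.distributed as dist

    # state_dict() stages each shard on the GPU for the all_gather but returns
    # CPU tensors, so saving it does not retain extra GPU memory even when
    # cpu_offload is enabled.
    full_osd = optimizer.state_dict()
    if dist.get_rank() == 0:
        torch.save(full_osd, "optimizer.pt")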
@@ -553,7 +553,8 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
                     if padding_size > 0:
                         v = torch.nn.functional.pad(v, [0, padding_size])
                     v_list = v.split(v.numel() // self._world_size)
-                    zero_state_dict['state'][param_idx][k] = v_list[self._local_rank].detach()
+                    device = 'cpu' if self._cpu_offload else 'cuda'
+                    zero_state_dict['state'][param_idx][k] = v_list[self._local_rank].to(device).detach()
 
         self.optim.load_state_dict(zero_state_dict)
         zero_state_dict = dict()
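On the load path, each flattened state tensor is padded so it divides evenly across ranks, split, and only the local rank's slice is kept. The new device line places that slice on the CPU when offload is enabled, matching where the master parameters (and thus the optimizer states) actually live. A self-contained sketch of that load-time sharding, using a helper name (shard_for_local_rank) that does not exist in the codebase:

    import torch
    import torch.nn.functional as F

    def shard_for_local_rank(v, world_size, local_rank, cpu_offload):
        """Illustrative sketch of the load-time sharding above: pad the
        flattened state so it splits evenly, keep this rank's slice, and
        place it on CPU when offload is enabled, otherwise on the GPU."""
        v = v.flatten()
        padding_size = (world_size - v.numel() % world_size) % world_size
        if padding_size > 0:
            v = F.pad(v, [0, padding_size])
        shard = v.split(v.numel() // world_size)[local_rank]
        device = 'cpu' if cpu_offload else 'cuda'
        return shard.to(device).detach()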
@@ -585,9 +586,10 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
 
             for k, v in states.items():
                 if isinstance(v, torch.Tensor) and k != 'step':
-                    state_tensor = [torch.zeros_like(v) for _ in range(self._world_size)]
-                    dist.all_gather(state_tensor, v, group=self.dp_pg)
-                    state_tensor = torch.stack(state_tensor).view(-1)[:working_param.numel()].reshape_as(working_param)
+                    state_tensor = [torch.zeros(v.shape, device='cuda', dtype=v.dtype) for _ in range(self._world_size)]
+                    dist.all_gather(state_tensor, v.cuda(), group=self.dp_pg)
+                    state_tensor = torch.stack(state_tensor).view(-1)[:working_param.numel()].reshape_as(
+                        working_param).cpu()
                     current_block_size += state_tensor.numel()
                     current_block[k] = state_tensor
 
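The same GPU-staged gather is applied in the sharded state-dict path used for block-wise saving, again returning CPU tensors. The "add clear cache" and "save memory for CI" items in the commit message refer to releasing the temporary CUDA buffers these gathers create; a hedged sketch of that pattern (exactly where the real code calls it is not visible in this hunk):

    import torch

    # After a GPU-staged all_gather, the gathered result lives on the CPU and
    # the temporary CUDA copies are unreferenced; emptying the allocator cache
    # hands that memory back to the device, which matters on memory-tight CI
    # runners.
    torch.cuda.empty_cache()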