[zero]fix zero ckptIO with offload (#4529)

* fix zero ckptio with offload

* fix load device

* saved tensors in ckpt should be on CPU

* fix unit test

* fix unit test

* add clear cache

* save memory for CI
This commit is contained in:
LuGY
2023-09-01 17:41:19 +08:00
committed by GitHub
parent c7b60f7547
commit cbac782254
3 changed files with 22 additions and 16 deletions

View File

@@ -80,9 +80,6 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
tp_process_group: Optional[ProcessGroup] = None, # if using tp
forced_dtype: Optional[torch.dtype] = None):
# TODO:
# 1. state_dict for checkpoint IO
super(LowLevelZeroOptimizer, self).__init__(optim=optimizer)
self._dtype = self.optim.param_groups[0]['params'][0].dtype
self._logger = get_dist_logger()
@@ -528,9 +525,12 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
for k, v in state.items():
if isinstance(v, torch.Tensor) and k != 'step':
working_param = self._param_store.master_to_working_param[id(param)]
gather_tensor = [torch.zeros_like(v) for _ in range(self._world_size)]
dist.all_gather(gather_tensor, v, group=self.dp_pg)
param_state = torch.stack(gather_tensor).view(-1)[:working_param.numel()].reshape_as(working_param)
gather_tensor = [
torch.zeros(v.shape, device='cuda', dtype=v.dtype) for _ in range(self._world_size)
]
dist.all_gather(gather_tensor, v.cuda(), group=self.dp_pg)
param_state = torch.stack(gather_tensor).view(-1)[:working_param.numel()].reshape_as(
working_param).cpu()
zero_state[param][k] = param_state
states_dict = self._pack_state(zero_state)
@@ -553,7 +553,8 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
if padding_size > 0:
v = torch.nn.functional.pad(v, [0, padding_size])
v_list = v.split(v.numel() // self._world_size)
zero_state_dict['state'][param_idx][k] = v_list[self._local_rank].detach()
device = 'cpu' if self._cpu_offload else 'cuda'
zero_state_dict['state'][param_idx][k] = v_list[self._local_rank].to(device).detach()
self.optim.load_state_dict(zero_state_dict)
zero_state_dict = dict()
@@ -585,9 +586,10 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
for k, v in states.items():
if isinstance(v, torch.Tensor) and k != 'step':
state_tensor = [torch.zeros_like(v) for _ in range(self._world_size)]
dist.all_gather(state_tensor, v, group=self.dp_pg)
state_tensor = torch.stack(state_tensor).view(-1)[:working_param.numel()].reshape_as(working_param)
state_tensor = [torch.zeros(v.shape, device='cuda', dtype=v.dtype) for _ in range(self._world_size)]
dist.all_gather(state_tensor, v.cuda(), group=self.dp_pg)
state_tensor = torch.stack(state_tensor).view(-1)[:working_param.numel()].reshape_as(
working_param).cpu()
current_block_size += state_tensor.numel()
current_block[k] = state_tensor