[checkpointio] optimize zero optim checkpoint io (#4591)

* [zero] update checkpoint io to save memory

* [checkpointio] add device map to save memory
Author: Hongxin Liu
Date: 2023-09-04 11:26:45 +08:00
Committed by: GitHub
Parent: cfa607080f
Commit: 63ecafb1fb
4 changed files with 43 additions and 22 deletions
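The second part of the change ("add device map to save memory") is not visible in the hunks below, but the general idea can be sketched: load the checkpoint onto CPU first and let each rank move only the slices it actually needs. The snippet below is a minimal illustration in plain PyTorch with a hypothetical file name and optimizer; it is not the actual ColossalAI checkpoint IO code.

```python
import torch

# Minimal sketch of the memory-saving idea (hypothetical path and optimizer):
# map the checkpoint to CPU so the full state dict never has to be
# materialized on a single GPU all at once.
full_state = torch.load("optim_checkpoint.bin", map_location="cpu")

# In the actual flow, LowLevelZeroOptimizer.load_state_dict() then splits each
# state tensor across ranks and keeps only the local slice (see the second hunk).
# optimizer.load_state_dict(full_state)
```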

View File

@@ -307,7 +307,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
        # or got a grad of param from another group
        # after reduction, the bucket will be empty
        if self._bucket_store.num_elements_in_bucket() + param_size > self._reduce_bucket_size or \
-            group_id != self._bucket_store.current_group_id:
+                group_id != self._bucket_store.current_group_id:
            self._run_reduction()

        padding_size = self._param_store.get_param_padding_size(param)
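Only the formatting of the condition changes in this hunk, but the rule it encodes is worth restating: a reduction is flushed either when the next parameter would overflow the bucket, or when the parameter belongs to a different group than the one currently buffered. A standalone sketch of the same rule, with made-up names rather than the ColossalAI bucket store API:

```python
def should_flush_bucket(bucket_numel: int, param_numel: int, reduce_bucket_size: int,
                        bucket_group_id: int, param_group_id: int) -> bool:
    # Flush when adding this parameter would exceed the bucket budget, or when
    # the parameter comes from a different param group than the buffered one.
    return (bucket_numel + param_numel > reduce_bucket_size
            or param_group_id != bucket_group_id)
```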
@@ -553,11 +553,9 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
                        if padding_size > 0:
                            v = torch.nn.functional.pad(v, [0, padding_size])
                        v_list = v.split(v.numel() // self._world_size)
-                        device = 'cpu' if self._cpu_offload else 'cuda'
-                        zero_state_dict['state'][param_idx][k] = v_list[self._local_rank].to(device).detach()
+                        zero_state_dict['state'][param_idx][k] = v_list[self._local_rank].detach().clone()

        self.optim.load_state_dict(zero_state_dict)
-        zero_state_dict = dict()

    def state_dict_shard(self, max_shard_size: int = 1024) -> Iterator[Tuple[Dict, int]]:
        """Returns dictionaries containing a whole state of the module one by one. The max size of dictionary shard is specified by ``max_shard_size``.