[checkpointio] optimize zero optim checkpoint io (#4591)

* [zero] update checkpoint io to save memory
* [checkpointio] add device map to save memory
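The second bullet refers to a change on the checkpoint-io side that is not shown in the hunks below. As a rough sketch of the general technique only (plain PyTorch, not this repository's code; the path and the loop are made up for illustration), mapping checkpoint tensors to CPU at load time avoids materializing the whole optimizer state on the GPU before it is sharded:

import torch

# Hypothetical checkpoint path, for illustration only.
ckpt_path = "optimizer_shard_0.bin"

# map_location="cpu" deserializes every tensor on the host, so the GPU only
# ever holds the entries that this rank actually moves over afterwards.
state = torch.load(ckpt_path, map_location="cpu")

# Move just the needed entries to the training device.
device = torch.device("cuda", 0)
for group_state in state.get("state", {}).values():
    for k, v in group_state.items():
        if torch.is_tensor(v):
            group_state[k] = v.to(device)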
@@ -307,7 +307,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         # or got a grad of param from another group
         # after reduction, the bucket will be empty
         if self._bucket_store.num_elements_in_bucket() + param_size > self._reduce_bucket_size or \
-                group_id != self._bucket_store.current_group_id:
+                group_id != self._bucket_store.current_group_id:
             self._run_reduction()
 
         padding_size = self._param_store.get_param_padding_size(param)
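For context, the condition in this hunk flushes the gradient bucket whenever adding the next parameter would overflow the bucket size limit or the parameter belongs to a different group. A minimal stand-alone sketch of that bucketing pattern (illustrative only; `Bucket`, `add`, and `flush` are invented names, not ColossalAI's `BucketStore` API):

from typing import List, Optional
import torch

class Bucket:
    """Toy gradient bucket: collects tensors until a size or group boundary."""

    def __init__(self, size_limit: int) -> None:
        self.size_limit = size_limit
        self.tensors: List[torch.Tensor] = []
        self.num_elements = 0
        self.group_id: Optional[int] = None

    def add(self, grad: torch.Tensor, group_id: int) -> None:
        # Flush first if this grad would overflow the bucket
        # or comes from a different parameter group.
        if self.num_elements + grad.numel() > self.size_limit or \
                (self.group_id is not None and group_id != self.group_id):
            self.flush()
        self.tensors.append(grad)
        self.num_elements += grad.numel()
        self.group_id = group_id

    def flush(self) -> None:
        if not self.tensors:
            return
        # A real implementation would all-reduce the flattened bucket here.
        flat = torch.cat([t.flatten() for t in self.tensors])
        _ = flat  # placeholder for the communication step
        self.tensors.clear()
        self.num_elements = 0
        self.group_id = None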
@@ -553,11 +553,9 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
                         if padding_size > 0:
                             v = torch.nn.functional.pad(v, [0, padding_size])
                         v_list = v.split(v.numel() // self._world_size)
-                        device = 'cpu' if self._cpu_offload else 'cuda'
-                        zero_state_dict['state'][param_idx][k] = v_list[self._local_rank].to(device).detach()
+                        zero_state_dict['state'][param_idx][k] = v_list[self._local_rank].detach().clone()
 
         self.optim.load_state_dict(zero_state_dict)
-        zero_state_dict = dict()
 
     def state_dict_shard(self, max_shard_size: int = 1024) -> Iterator[Tuple[Dict, int]]:
         """Returns dictionaries containing a whole state of the module one by one. The max size of dictionary shard is specified by ``max_shard_size``.
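One plausible reading of the memory saving in the hunk above: `torch.Tensor.split` returns views, and `.to(device)` returns the same tensor when it is already on the target device, so the old `.to(device).detach()` could leave each state-dict entry aliasing the full padded buffer. `.detach().clone()` copies only the local shard into fresh storage, letting the padded tensor be freed. A small stand-alone sketch (plain PyTorch 2.x for `untyped_storage()`, not repository code; the sizes are arbitrary):

import torch

world_size, local_rank = 4, 1

# Full optimizer state tensor (e.g. exp_avg) as it appears in the checkpoint.
v = torch.randn(1_000_003)
padding = (world_size - v.numel() % world_size) % world_size
if padding > 0:
    v = torch.nn.functional.pad(v, [0, padding])

v_list = v.split(v.numel() // world_size)  # views into the padded buffer

# Old style: on a matching device, .to() returns the same tensor, so the
# detached view still pins the whole padded buffer in memory.
shard_view = v_list[local_rank].to("cpu").detach()
assert shard_view.untyped_storage().size() == v.untyped_storage().size()

# New style: .clone() materializes just this rank's shard in its own storage.
shard_copy = v_list[local_rank].detach().clone()
assert shard_copy.untyped_storage().size() < v.untyped_storage().size()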