From 5ff5323538ccb977ba18c161560af82011a9480a Mon Sep 17 00:00:00 2001
From: Hongxin Liu
Date: Fri, 14 Feb 2025 15:09:50 +0800
Subject: [PATCH] [hotfix] fix zero optim save (#6191)

---
 colossalai/zero/low_level/low_level_optim.py | 124 ++++++++++---------
 1 file changed, 67 insertions(+), 57 deletions(-)

diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py
index 8abaf8fc6..c530ff009 100644
--- a/colossalai/zero/low_level/low_level_optim.py
+++ b/colossalai/zero/low_level/low_level_optim.py
@@ -786,30 +786,36 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         """
         zero_state = dict()
         device = get_accelerator().get_current_device()
-        for param, state in self.optim.state.items():
-            working_param = self.master_to_working_param[id(param)]
-            pg = self.param_to_pg[working_param]
-            if not only_on_master or get_nd_rank(pg) == 0:
-                zero_state[param] = copy.deepcopy(state)
-            else:
-                zero_state[param] = {}
+        for param_group in self.optim.param_groups:
+            for param in param_group["params"]:
+                if param not in self.optim.state:
+                    continue
+                state = self.optim.state[param]
+                working_param = self.master_to_working_param[id(param)]
+                pg = self.param_to_pg[working_param]
+                if not only_on_master or get_nd_rank(pg) == 0:
+                    zero_state[param] = copy.deepcopy(state)
+                else:
+                    zero_state[param] = {}
 
-            if pinned_state_dicts is not None and param not in pinned_state_dicts:
-                pinned_state_dicts[param] = {}
+                if pinned_state_dicts is not None and param not in pinned_state_dicts:
+                    pinned_state_dicts[param] = {}
 
-            for k, v in state.items():
-                if isinstance(v, torch.Tensor) and k != "step":
-                    gathered_tensor = torch.empty(v.numel() * get_nd_world_size(pg), device=device, dtype=v.dtype)
-                    all_gather_into_flat_tensor_nd(gathered_tensor, v.to(device).flatten(), pg)
-                    param_state = gathered_tensor[: working_param.numel()].reshape_as(working_param)
-                    if not only_on_master or get_nd_rank(pg) == 0:
-                        if pinned_state_dicts is not None and k not in pinned_state_dicts[param]:
-                            pinned_state_dicts[param][k] = torch.empty_like(param_state, pin_memory=True, device="cpu")
-                        if pinned_state_dicts is not None:
-                            pinned_state_dicts[param][k].copy_(param_state)
-                        zero_state[param][k] = pinned_state_dicts[param][k]
-                    else:
-                        zero_state[param][k] = param_state.cpu()
+                for k, v in state.items():
+                    if isinstance(v, torch.Tensor) and k != "step":
+                        gathered_tensor = torch.empty(v.numel() * get_nd_world_size(pg), device=device, dtype=v.dtype)
+                        all_gather_into_flat_tensor_nd(gathered_tensor, v.to(device).flatten(), pg)
+                        param_state = gathered_tensor[: working_param.numel()].reshape_as(working_param)
+                        if not only_on_master or get_nd_rank(pg) == 0:
+                            if pinned_state_dicts is not None and k not in pinned_state_dicts[param]:
+                                pinned_state_dicts[param][k] = torch.empty_like(
+                                    param_state, pin_memory=True, device="cpu"
+                                )
+                            if pinned_state_dicts is not None:
+                                pinned_state_dicts[param][k].copy_(param_state)
+                            zero_state[param][k] = pinned_state_dicts[param][k]
+                        else:
+                            zero_state[param][k] = param_state.cpu()
 
         states_dict = self._pack_state(zero_state)
 
@@ -865,48 +871,52 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
 
         device = get_accelerator().get_current_device()
         local_states = self.optim.state_dict()["state"]
-        idx2master = {}
+        master2idx = {}
         cnt = 0
         for param_group in self.optim.param_groups:
             for param in param_group["params"]:
-                idx2master[cnt] = param
+                master2idx[param] = cnt
                 cnt += 1
 
-        for param_idx, states in local_states.items():
-            current_block_size = 0
-            if pinned_state_dicts is not None and param_idx not in pinned_state_dicts:
-                pinned_state_dicts[param_idx] = {}
-            master_param = idx2master[param_idx]
-            working_param = self.master_to_working_param[id(master_param)]
-            pg = self.param_to_pg[working_param]
-            if not only_on_master or get_nd_rank(pg) == 0:
-                current_block = copy.deepcopy(states)
-            else:
-                current_block = {}
-            for k, v in states.items():
-                if isinstance(v, torch.Tensor) and k != "step":
-                    state_tensor = torch.empty(v.numel() * get_nd_world_size(pg), device=device, dtype=v.dtype)
-                    all_gather_into_flat_tensor_nd(state_tensor, v.to(device).flatten(), pg)
-                    state_tensor = state_tensor[: working_param.numel()].reshape_as(working_param)
-                    if not only_on_master or get_nd_rank(pg) == 0:
-                        if pinned_state_dicts is not None and k not in pinned_state_dicts[param_idx]:
-                            pinned_state_dicts[param_idx][k] = torch.empty_like(
-                                state_tensor, pin_memory=True, device="cpu"
-                            )
-                        if pinned_state_dicts is not None:
-                            pinned_state_dicts[param_idx][k].copy_(state_tensor)
-                        current_block[k] = pinned_state_dicts[param_idx][k]
-                    else:
-                        current_block[k] = state_tensor.cpu()
-                    current_block_size += calculate_tensor_size(state_tensor)
+        for param_group in self.optim.param_groups:
+            for master_param in param_group["params"]:
+                param_idx = master2idx[master_param]
+                states = local_states[param_idx]
 
-            if ret_block_size + current_block_size > max_shard_size and len(ret_block) > 0:
-                yield ret_block, ret_block_size
-                ret_block = dict()
-                ret_block_size = 0
+                current_block_size = 0
+                if pinned_state_dicts is not None and param_idx not in pinned_state_dicts:
+                    pinned_state_dicts[param_idx] = {}
+                working_param = self.master_to_working_param[id(master_param)]
+                pg = self.param_to_pg[working_param]
+                if not only_on_master or get_nd_rank(pg) == 0:
+                    current_block = copy.deepcopy(states)
+                else:
+                    current_block = {}
 
-            ret_block[param_idx] = current_block
-            ret_block_size += current_block_size
+                for k, v in states.items():
+                    if isinstance(v, torch.Tensor) and k != "step":
+                        state_tensor = torch.empty(v.numel() * get_nd_world_size(pg), device=device, dtype=v.dtype)
+                        all_gather_into_flat_tensor_nd(state_tensor, v.to(device).flatten(), pg)
+                        state_tensor = state_tensor[: working_param.numel()].reshape_as(working_param)
+                        if not only_on_master or get_nd_rank(pg) == 0:
+                            if pinned_state_dicts is not None and k not in pinned_state_dicts[param_idx]:
+                                pinned_state_dicts[param_idx][k] = torch.empty_like(
+                                    state_tensor, pin_memory=True, device="cpu"
+                                )
+                            if pinned_state_dicts is not None:
+                                pinned_state_dicts[param_idx][k].copy_(state_tensor)
+                            current_block[k] = pinned_state_dicts[param_idx][k]
+                        else:
+                            current_block[k] = state_tensor.cpu()
+                        current_block_size += calculate_tensor_size(state_tensor)
+
+                if ret_block_size + current_block_size > max_shard_size and len(ret_block) > 0:
+                    yield ret_block, ret_block_size
+                    ret_block = dict()
+                    ret_block_size = 0
+
+                ret_block[param_idx] = current_block
+                ret_block_size += current_block_size
 
         yield ret_block, ret_block_size
 
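Note on the indexing convention behind this fix (not part of the patch itself): PyTorch's
Optimizer.state_dict() keys its "state" entries by each parameter's position across all
param_groups, and parameters that have not received any optimizer state yet are simply
absent. That is why the new code derives param_idx by walking param_groups, and why
state_dict() now skips params missing from self.optim.state. A minimal standalone sketch
of that behavior (plain PyTorch, parameter names w1/w2/w3 chosen only for illustration):

    import torch

    w1 = torch.nn.Parameter(torch.randn(4))
    w2 = torch.nn.Parameter(torch.randn(4))
    w3 = torch.nn.Parameter(torch.randn(4))
    # Two param groups; the flattened order across groups is w1 -> 0, w2 -> 1, w3 -> 2.
    opt = torch.optim.Adam([{"params": [w1, w2]}, {"params": [w3], "lr": 1e-4}])

    # Give gradients only to w1 and w3, so w2 never acquires optimizer state.
    (w1.sum() + w3.sum()).backward()
    opt.step()

    # States are keyed by position across param_groups; stateless params are absent.
    print(sorted(opt.state_dict()["state"].keys()))  # [0, 2]
    print(w2 in opt.state)  # False -- the patched state_dict() skips such params

Building the param index from the param_groups order keeps the shard indices consistent
with the indices PyTorch itself would produce when the full optimizer state is saved.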