diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py index 24ebae1c7..db26269b4 100644 --- a/colossalai/zero/low_level/low_level_optim.py +++ b/colossalai/zero/low_level/low_level_optim.py @@ -780,19 +780,19 @@ class LowLevelZeroOptimizer(OptimizerWrapper): zero_state = dict() device = get_accelerator().get_current_device() for param, state in self.optim.state.items(): - if pinned_state_dicts and param not in pinned_state_dicts: + if pinned_state_dicts is not None and param not in pinned_state_dicts: pinned_state_dicts[param] = {} zero_state[param] = copy.deepcopy(state) for k, v in state.items(): if isinstance(v, torch.Tensor) and k != "step": - if pinned_state_dicts and k not in pinned_state_dicts[param]: - pinned_state_dicts[param][k] = torch.empty_like(working_param, pin_memory=True, device="cpu") working_param = self.master_to_working_param[id(param)] pg = self.param_to_pg[working_param] gathered_tensor = torch.empty(v.numel() * get_nd_world_size(pg), device=device, dtype=v.dtype) all_gather_into_flat_tensor_nd(gathered_tensor, v.to(device).flatten(), pg) param_state = gathered_tensor[: working_param.numel()].reshape_as(working_param) - if pinned_state_dicts: + if pinned_state_dicts is not None and k not in pinned_state_dicts[param]: + pinned_state_dicts[param][k] = torch.empty_like(param_state, pin_memory=True, device="cpu") + if pinned_state_dicts is not None: pinned_state_dicts[param][k].copy_(param_state) zero_state[param][k] = pinned_state_dicts[param][k] else: @@ -858,7 +858,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper): for param_idx, states in local_states.items(): current_block_size = 0 current_block = copy.deepcopy(states) - if pinned_state_dicts and param_idx not in pinned_state_dicts: + if pinned_state_dicts is not None and param_idx not in pinned_state_dicts: pinned_state_dicts[param_idx] = {} master_param = idx2master[param_idx] working_param = self.master_to_working_param[id(master_param)] @@ -869,9 +869,9 @@ class LowLevelZeroOptimizer(OptimizerWrapper): state_tensor = torch.empty(v.numel() * get_nd_world_size(pg), device=device, dtype=v.dtype) all_gather_into_flat_tensor_nd(state_tensor, v.to(device).flatten(), pg) state_tensor = state_tensor[: working_param.numel()].reshape_as(working_param) - if pinned_state_dicts and k not in pinned_state_dicts[param_idx]: + if pinned_state_dicts is not None and k not in pinned_state_dicts[param_idx]: pinned_state_dicts[param_idx][k] = torch.empty_like(state_tensor, pin_memory=True, device="cpu") - if pinned_state_dicts: + if pinned_state_dicts is not None: pinned_state_dicts[param_idx][k].copy_(state_tensor) current_block[k] = pinned_state_dicts[param_idx][k] else: