Mirror of https://github.com/hpcaitech/ColossalAI.git
[hotfix] fix zero's incompatibility with checkpoint in torch-1.12 (#1786)
* [hotfix] fix zero's incompatibility with checkpoint in torch-1.12
* [zero] add cpu shard init
* [zero] add tiny example test
* [colo_tensor] fix bugs for torch-1.11
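The headline fix concerns checkpointing: ZeroDDP customizes `state_dict()`/`load_state_dict()` on top of private `torch.nn.modules.module` helpers, and that interaction broke under torch-1.12. A hedged sketch of the round trip the hotfix protects, assuming `model` is an already-constructed ZeroDDP instance (setup elided):

import torch

# hedged sketch: the state-dict round trip the hotfix protects; `model` is
# assumed to be an already-initialized ZeroDDP instance
state = model.state_dict()                      # ZeroDDP overrides this
torch.save(state, 'checkpoint.pt')
model.load_state_dict(torch.load('checkpoint.pt'))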
@@ -1,19 +1,22 @@
-import torch
-import itertools
-import torch.distributed as dist
-from functools import partial
-from colossalai.zero.utils.zero_hook_v2 import ZeROHookV2
-from colossalai.tensor.param_op_hook import ParamOpHookManager
-from colossalai.gemini.gemini_mgr import GeminiManager
-from typing import Dict, Iterable, List, Optional, Set
-from colossalai.logging import get_dist_logger
-from collections import OrderedDict
-from colossalai.tensor.colo_parameter import ColoParameter, ColoTensor, ColoTensorSpec
-from colossalai.tensor import ProcessGroup as ColoProcessGroup
-from .reducer import Reducer
-from functools import partial
-from typing import Dict, Iterable, List, Optional, Set
-
-from colossalai.gemini.chunk import TensorState, Chunk, ChunkManager
+import itertools
+from collections import OrderedDict
+from functools import partial
+from typing import Dict, Iterable, List, Optional, Set
+
+import torch
+import torch.distributed as dist
+
+from colossalai.gemini.chunk import Chunk, ChunkManager, TensorState
+from colossalai.gemini.gemini_mgr import GeminiManager
+from colossalai.logging import get_dist_logger
+from colossalai.nn.parallel.utils import get_temp_total_chunk_on_cuda
+from colossalai.tensor import ProcessGroup as ColoProcessGroup
+from colossalai.tensor.colo_parameter import ColoParameter, ColoTensor, ColoTensorSpec
+from colossalai.tensor.param_op_hook import ParamOpHookManager
+from colossalai.utils import get_current_device
+from colossalai.zero.utils.zero_hook_v2 import ZeROHookV2
+
+from .reducer import Reducer
 
 try:
     from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, _IncompatibleKeys
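The hunk closes on an unchanged compatibility guard around private `torch.nn.modules.module` names. The `except` branch lies outside this hunk; a hedged sketch of how such a guard is typically completed (the fallback values below are assumptions that mirror what torch itself defines):

try:
    from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, _IncompatibleKeys
except ImportError:
    # assumed fallback for torch versions lacking these private names;
    # '_extra_state' matches the value torch assigns to _EXTRA_STATE_KEY_SUFFIX
    from collections import namedtuple
    _EXTRA_STATE_KEY_SUFFIX = '_extra_state'
    _IncompatibleKeys = namedtuple('IncompatibleKeys', ['missing_keys', 'unexpected_keys'])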
@@ -221,6 +224,7 @@ class ZeroDDP(ColoDDP):
         self.overflow_counter = 0
         self.grads_device: Dict[torch.Tensor, torch.device] = {}
 
+        cpu_offload = self.gemini_manager.policy_name != 'cuda'
         # TODO: get param order and filter unused params
         for p in module.parameters():
             assert isinstance(p, ColoParameter)
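The one added line derives a `cpu_offload` flag from the Gemini placement policy: chunk payloads become eligible for CPU offload whenever the policy is anything other than pure 'cuda'. A tiny self-contained sketch of that rule ('cuda', 'cpu', and 'auto' are the placement policies GeminiManager accepted at this point):

# sketch of the rule the new flag encodes
for policy_name in ('cuda', 'cpu', 'auto'):
    cpu_offload = policy_name != 'cuda'
    print(f"{policy_name}: cpu_offload={cpu_offload}")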
@@ -232,10 +236,17 @@ class ZeroDDP(ColoDDP):
             fp32_data = p.data.float()
             fp32_p = ColoTensor(fp32_data, spec=ColoTensorSpec(p.process_group))
             p.data = p.data.half()
 
             dp_world_size = p.process_group.dp_world_size()
-            self.chunk_manager.append_tensor(p, 'fp16_param', dp_world_size, pin_memory)
-            self.chunk_manager.append_tensor(fp32_p, 'fp32_param', dp_world_size, pin_memory)
+            self.chunk_manager.append_tensor(tensor=p,
+                                             group_type='fp16_param',
+                                             config_key=dp_world_size,
+                                             cpu_offload=cpu_offload,
+                                             pin_memory=pin_memory)
+            self.chunk_manager.append_tensor(tensor=fp32_p,
+                                             group_type='fp32_param',
+                                             config_key=dp_world_size,
+                                             cpu_offload=cpu_offload,
+                                             pin_memory=pin_memory)
             self.fp32_params.append(fp32_p)
             self.grads_device[p] = self.gemini_manager.default_device
         self.chunk_manager.close_all_groups()
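Both call sites now use keyword arguments and thread the new `cpu_offload` flag through to the chunk manager. A sketch of the signature these call sites imply (an assumption for illustration, not the verbatim ChunkManager code of this commit):

import torch

class ChunkManager:
    """Stub showing only the assumed append_tensor signature."""

    def append_tensor(self,
                      tensor: torch.Tensor,
                      group_type: str,
                      config_key: int,
                      cpu_offload: bool = False,
                      pin_memory: bool = False) -> None:
        # register `tensor` in the chunk group keyed by (group_type, config_key);
        # cpu_offload marks the chunk's payload as eligible to live in host
        # memory, pinned if pin_memory is True
        ...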
@@ -247,6 +258,10 @@ class ZeroDDP(ColoDDP):
             chunk_32 = self.chunk_manager.get_chunk(fp32_p)
             chunk_32.init_pair(chunk_16)
+
+            # keep gathered chunks in CUDA
+            if chunk_16.keep_gathered:
+                self.grads_device[p] = get_current_device()
 
         self._logger = get_dist_logger()
 
     def forward(self, *args, **kwargs):
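The added branch overrides the gradient placement chosen earlier in `__init__`: a parameter whose fp16 chunk is flagged `keep_gathered` is never scattered after reduction, so its gradients stay on the current CUDA device instead of following `gemini_manager.default_device`. A runnable toy sketch of the decision, with stand-in classes replacing ZeroDDP internals:

import torch

# toy stand-in for colossalai.gemini.chunk.Chunk (illustrative only)
class Chunk:
    def __init__(self, keep_gathered: bool):
        self.keep_gathered = keep_gathered

default_device = torch.device('cpu')       # gemini_manager.default_device under 'cpu'/'auto'
current_device = torch.device('cuda:0')    # what colossalai.utils.get_current_device() returns

grads_device = {}
for name, chunk in {'w1': Chunk(False), 'w2': Chunk(True)}.items():
    # kept-gathered chunks never leave CUDA, so their grads must stay there
    grads_device[name] = current_device if chunk.keep_gathered else default_device

print(grads_device)   # {'w1': device(type='cpu'), 'w2': device(type='cuda', index=0)}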