[hotfix] fix zero's incompatibility with checkpoint in torch-1.12 (#1786)

* [hotfix] fix zero's incompatibility with checkpoint in torch-1.12

* [zero] add cpu shard init

* [zero] add tiny example test

* [colo_tensor] fix bugs for torch-1.11
Author: HELSON
Date: 2022-11-02 16:11:34 +08:00 (committed by GitHub)
Parent: 32c1b843a9
Commit: c6a1a62636
9 changed files with 1041 additions and 951 deletions


@@ -1,15 +1,17 @@
+from enum import Enum
+from typing import Dict, Set, Tuple
+
 import torch
 import torch.distributed as dist
-from enum import Enum
-from torch.optim import Optimizer
 from torch.nn import Parameter
-from colossalai.nn.parallel.data_parallel import ZeroDDP
-from typing import Dict, Tuple, Set
+from torch.optim import Optimizer
+
 from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
+from colossalai.gemini.chunk import Chunk, ChunkManager
 from colossalai.logging import get_dist_logger
 from colossalai.nn.optimizer import ColossalaiOptimizer
-from colossalai.utils import get_current_device, disposable
-from colossalai.gemini.chunk import Chunk, ChunkManager
+from colossalai.nn.parallel.data_parallel import ZeroDDP
+from colossalai.utils import disposable, get_current_device


 class OptimState(Enum):
@@ -219,6 +221,8 @@ class ZeroOptimizer(ColossalaiOptimizer):
         def get_range_pair(local_chunk: Chunk, local_param: Parameter):
             param_info = local_chunk.tensors_info[local_param]
+            if local_chunk.keep_gathered:
+                return param_info.offset, param_info.end
             begin = max(0, param_info.offset - local_chunk.shard_begin)
             end = min(local_chunk.shard_size, param_info.end - local_chunk.shard_begin)
             return begin, end
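
The new keep_gathered branch matters because a gathered chunk is never split into shards, so a parameter's offsets inside the chunk can be returned as-is; for a sharded chunk the range has to be clipped to the slice held by the local rank. A minimal sketch of that clipping arithmetic, using hypothetical shard numbers rather than the real Chunk API:

# Hypothetical example: a 1024-element chunk sharded over 4 ranks; this rank
# holds elements [512, 768), i.e. shard_begin = 512, shard_size = 256.
shard_begin, shard_size = 512, 256

# A parameter occupying chunk offsets [400, 600) only partially overlaps the shard.
param_offset, param_end = 400, 600

begin = max(0, param_offset - shard_begin)        # 0: the part before the shard is clipped away
end = min(shard_size, param_end - shard_begin)    # 88: only 88 of the param's elements live in this shard
print(begin, end)                                 # 0 88

# If the chunk were keep_gathered, every rank would hold the whole chunk and the
# unclipped offsets (400, 600) would be returned instead.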