Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-15 22:19:38 +00:00)
[hotfix] fix zero's incompatibility with checkpoint in torch-1.12 (#1786)
* [hotfix] fix zero's incompatibility with checkpoint in torch-1.12
* [zero] add cpu shard init
* [zero] add tiny example test
* [colo_tensor] fix bugs for torch-1.11
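For orientation, the "checkpoint" in the title is read here as torch's activation checkpointing (torch.utils.checkpoint). The sketch below shows that usage pattern in plain PyTorch; the module, its sizes, and the driver are illustrative assumptions, not code from this commit, and wrapping the model with ZeroDDP is omitted.

    # Minimal sketch (assumed example, not from the commit): a block that
    # recomputes its activations via torch.utils.checkpoint. Under ZeroDDP the
    # recomputation during backward touches parameters that may be sharded,
    # which is the kind of interaction this hotfix concerns on torch 1.12.
    import torch
    import torch.nn as nn
    from torch.utils.checkpoint import checkpoint


    class TinyBlock(nn.Module):

        def __init__(self, dim: int = 32):
            super().__init__()
            self.proj = nn.Linear(dim, dim)

        def forward(self, x):
            # Do not store this block's activations; recompute them on backward.
            return checkpoint(self.proj, x)


    if __name__ == "__main__":
        model = TinyBlock()
        out = model(torch.randn(4, 32, requires_grad=True))
        out.sum().backward()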
@@ -1,15 +1,17 @@
+from enum import Enum
+from typing import Dict, Set, Tuple
+
 import torch
 import torch.distributed as dist
-from enum import Enum
-from torch.optim import Optimizer
 from torch.nn import Parameter
-from colossalai.nn.parallel.data_parallel import ZeroDDP
-from typing import Dict, Tuple, Set
+from torch.optim import Optimizer
 
 from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
+from colossalai.gemini.chunk import Chunk, ChunkManager
 from colossalai.logging import get_dist_logger
 from colossalai.nn.optimizer import ColossalaiOptimizer
-from colossalai.utils import get_current_device, disposable
-from colossalai.gemini.chunk import Chunk, ChunkManager
+from colossalai.nn.parallel.data_parallel import ZeroDDP
+from colossalai.utils import disposable, get_current_device
+
 
 class OptimState(Enum):
@@ -219,6 +221,8 @@ class ZeroOptimizer(ColossalaiOptimizer):
 
         def get_range_pair(local_chunk: Chunk, local_param: Parameter):
             param_info = local_chunk.tensors_info[local_param]
+            if local_chunk.keep_gathered:
+                return param_info.offset, param_info.end
             begin = max(0, param_info.offset - local_chunk.shard_begin)
             end = min(local_chunk.shard_size, param_info.end - local_chunk.shard_begin)
             return begin, end
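The clipping arithmetic added above is easy to sanity-check in isolation. The following standalone sketch mirrors get_range_pair with plain integers; the helper name range_in_shard and the example numbers are made up for illustration and are not part of the Chunk API.

    # Standalone sketch (illustrative names/values) of the range clipping that
    # get_range_pair performs: intersect a parameter's [offset, end) span inside
    # the chunk with this rank's shard [shard_begin, shard_begin + shard_size).
    from typing import Tuple


    def range_in_shard(offset: int, end: int, shard_begin: int, shard_size: int,
                       keep_gathered: bool) -> Tuple[int, int]:
        if keep_gathered:
            # A chunk that is kept gathered holds the full tensor on every rank,
            # so the parameter's whole span is locally addressable.
            return offset, end
        begin = max(0, offset - shard_begin)
        clipped_end = min(shard_size, end - shard_begin)
        return begin, clipped_end


    # A parameter spanning [96, 160) of the chunk, on a rank whose shard covers
    # [128, 256): only the first 32 elements of the shard belong to that parameter.
    print(range_in_shard(96, 160, shard_begin=128, shard_size=128, keep_gathered=False))  # (0, 32)
    print(range_in_shard(96, 160, shard_begin=128, shard_size=128, keep_gathered=True))   # (96, 160)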