Mirror of https://github.com/hpcaitech/ColossalAI.git
[zero] add ZeroTensorShardStrategy (#793)
colossalai/zero/comm/__init__.py (new file)
@@ -0,0 +1 @@
from .zero_comm import ZeroDist
colossalai/zero/comm/zero_comm.py (new file)
@@ -0,0 +1,46 @@
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
from colossalai.context.singleton_meta import SingletonMeta
from colossalai.utils import get_current_device
from typing import Optional

# ZERO_USE_NCCL marks whether the colossal_zero_comm NCCL extension is available.
ZERO_USE_NCCL = False
try:
    import colossal_zero_comm
    ZERO_USE_NCCL = True
except ImportError:
    print("colossal_zero_comm is not available. Please reinstall Colossal-AI.")


class ZeroCommWorld(metaclass=SingletonMeta):
    """Zero communicator, used for communications in zero parallel.
    """

    def __init__(self):
        super().__init__()
        self.zero_pg: Optional[ProcessGroup] = None

    @property
    def is_initialized(self):
        return self.zero_pg is not None

    def zero_comm_init(self, comm_group: ProcessGroup):
        if not ZERO_USE_NCCL:
            return

        if self.is_initialized:
            assert self.zero_pg == comm_group, "Cannot initialize zero group twice"
            return

        self.zero_pg = comm_group
        colossal_zero_comm.create_comm(self.zero_pg, get_current_device())    # set up the communicator used by zero_all_gather

    def zero_all_gather(self, input_tensor: torch.Tensor):
        assert self.zero_pg is not None, "Please initialize zero communication world first"
        rank = dist.get_rank(self.zero_pg)
        world_size = self.zero_pg.size()
        # gather shards from all ranks directly into input_tensor (in place)
        colossal_zero_comm.inplace_all_gather(input_tensor, rank, world_size)


ZeroDist = ZeroCommWorld()
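As a point of reference, here is a minimal sketch of driving the singleton directly, assuming torch.distributed is already initialized and the colossal_zero_comm extension is built; the group dp_pg and the shard size are illustrative, and in practice ZeroInitContext performs the initialization (see the hunk below).

import torch
import torch.distributed as dist
from colossalai.zero.comm import ZeroDist

# Illustrative process group covering all ranks; normally this is the
# data-parallel group that ZeroInitContext passes in.
dp_pg = dist.new_group(ranks=list(range(dist.get_world_size())))
ZeroDist.zero_comm_init(dp_pg)    # binds the singleton to this group exactly once

# Each rank fills its own slot of one flat buffer, then the in-place
# all-gather fills the slots of every other rank; no per-rank output list is needed.
shard_numel = 1024
world_size = dist.get_world_size(dp_pg)
rank = dist.get_rank(dp_pg)
buffer = torch.empty(shard_numel * world_size, device=torch.cuda.current_device())
buffer.narrow(0, rank * shard_numel, shard_numel).normal_()
ZeroDist.zero_all_gather(buffer)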
colossalai/zero/init_ctx/init_context.py
@@ -12,6 +12,7 @@ from colossalai.logging import get_dist_logger
 from colossalai.zero.shard_utils import BaseShardStrategy
 from colossalai.zero.sharded_model._utils import cast_tensor_to_fp16
 from colossalai.zero.sharded_param import ShardedParamV2
+from colossalai.zero.comm import ZeroDist
 from contextlib import AbstractContextManager


@@ -191,6 +192,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
         The Callback function when entering the context
         """
         self.logger = get_dist_logger("ZeroInitContext")
+        ZeroDist.zero_comm_init(self.dp_process_group)    # initialize zero communication world

         # substitute fan-in and fan-out calculation
         self.nn_fanin_fanout = nn.init._calculate_fan_in_and_fan_out
colossalai/zero/shard_utils/__init__.py
@@ -1,5 +1,6 @@
 from .base_shard_strategy import BaseShardStrategy
 from .bucket_tensor_shard_strategy import BucketTensorShardStrategy
 from .tensor_shard_strategy import TensorShardStrategy
+from .zero_tensor_shard_strategy import ZeroTensorShardStrategy

-__all__ = ['BaseShardStrategy', 'TensorShardStrategy', 'BucketTensorShardStrategy']
+__all__ = ['BaseShardStrategy', 'TensorShardStrategy', 'BucketTensorShardStrategy', 'ZeroTensorShardStrategy']
colossalai/zero/shard_utils/zero_tensor_shard_strategy.py (new file)
@@ -0,0 +1,38 @@
from typing import Optional

import torch
import torch.distributed as dist
from colossalai.utils import get_current_device
from colossalai.zero.sharded_param.tensor_utils import colo_model_data_tensor_move_inline
from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
from colossalai.zero.comm import ZeroDist

from .tensor_shard_strategy import TensorShardStrategy


class ZeroTensorShardStrategy(TensorShardStrategy):
    """Uses the same shard scheme as `TensorShardStrategy`, but its all-gather
    operation is in-place, so it avoids the extra buffer that
    `torch.distributed.all_gather` allocates. This reduces peak memory usage
    in zero-offload.
    Note that this strategy is tightly coupled with zero: you cannot change its
    communication group, and you must create your model with `ZeroInitContext`.
    """

    def _gather_tensor(self, t: ShardedTensor, process_group: Optional[dist.ProcessGroup] = None):
        if not t.is_sharded:
            return
        target_device = t.device
        payload_numel = t.payload.numel()
        world_size = dist.get_world_size(process_group)
        rank = dist.get_rank(process_group)

        # Allocate the full-size buffer on the current device and copy the local
        # shard into this rank's slot; the in-place all-gather fills in the rest.
        buffer = torch.empty(payload_numel * world_size, dtype=t.payload.dtype, device=get_current_device())
        buffer_list = list(torch.chunk(buffer, chunks=world_size, dim=0))
        buffer_list[rank].copy_(t.payload)

        ZeroDist.zero_all_gather(buffer)    # note: process_group is ignored here; ZeroDist uses its own group
        gathered_payload = torch.narrow(buffer, 0, 0, t.origin_numel).reshape(t.origin_shape)
        t.reset_payload(gathered_payload)
        colo_model_data_tensor_move_inline(t, target_device)
        t.is_sharded = False
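Putting the pieces together, a hedged sketch of how the new strategy would be selected, following the shard-strategy usage pattern in ColossalAI around this version; module paths and keyword arguments may differ across releases, and the model below is a stand-in.

import torch.nn as nn
from colossalai.utils import get_current_device
from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.shard_utils import ZeroTensorShardStrategy
from colossalai.zero.sharded_model import ShardedModelV2

shard_strategy = ZeroTensorShardStrategy()

# ZeroInitContext calls ZeroDist.zero_comm_init(...) on entry, so the model must be
# built inside the context for this strategy's in-place all-gather to work.
with ZeroInitContext(target_device=get_current_device(),
                     shard_strategy=shard_strategy,
                     shard_param=True):
    model = nn.Sequential(nn.Linear(1024, 4096), nn.GELU(), nn.Linear(4096, 1024))

model = ShardedModelV2(model, shard_strategy)    # parameters are gathered in place on demand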