[zero] add ZeroTensorShardStrategy (#793)

HELSON
2022-04-19 14:32:45 +08:00
committed by GitHub
parent 681addb512
commit 88759e289e
12 changed files with 214 additions and 11 deletions

@@ -0,0 +1 @@
from .zero_comm import ZeroDist

@@ -0,0 +1,46 @@
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
from colossalai.context.singleton_meta import SingletonMeta
from colossalai.utils import get_current_device
from typing import Optional

ZERO_USE_NCCL = False
try:
    import colossal_zero_comm
    ZERO_USE_NCCL = True
except ImportError:
    print("The colossal_zero_comm extension is not available. Please reinstall Colossal-AI with pip.")


class ZeroCommWorld(metaclass=SingletonMeta):
    """Zero communicator, used for communications in zero parallelism."""

    def __init__(self):
        super().__init__()
        self.zero_pg: Optional[ProcessGroup] = None

    @property
    def is_initialized(self):
        return self.zero_pg is not None

    def zero_comm_init(self, comm_group: ProcessGroup):
        if not ZERO_USE_NCCL:
            return
        if self.is_initialized:
            assert self.zero_pg == comm_group, "Cannot initialize the zero communication group twice"
            return

        self.zero_pg = comm_group
        colossal_zero_comm.create_comm(self.zero_pg, get_current_device())

    def zero_all_gather(self, input_tensor: torch.Tensor):
        assert self.zero_pg is not None, "Please initialize the zero communication world first"
        rank = dist.get_rank(self.zero_pg)
        world_size = self.zero_pg.size()
        colossal_zero_comm.inplace_all_gather(input_tensor, rank, world_size)


ZeroDist = ZeroCommWorld()
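
For comparison, the stock collective in `torch.distributed` materializes a separate output tensor per rank, which is exactly the extra buffer that the in-place `colossal_zero_comm.inplace_all_gather` call above avoids. A minimal sketch of that conventional pattern (plain PyTorch API, not part of this commit; the helper name is made up for illustration):

import torch
import torch.distributed as dist

def all_gather_with_buffers(shard: torch.Tensor, group=None) -> torch.Tensor:
    """Gather a 1-D shard from every rank with the stock collective."""
    world_size = dist.get_world_size(group)
    # dist.all_gather writes into a list of pre-allocated tensors, so peak
    # memory grows by roughly world_size extra copies of the shard.
    gather_list = [torch.empty_like(shard) for _ in range(world_size)]
    dist.all_gather(gather_list, shard, group=group)
    return torch.cat(gather_list, dim=0)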

@@ -12,6 +12,7 @@ from colossalai.logging import get_dist_logger
from colossalai.zero.shard_utils import BaseShardStrategy
from colossalai.zero.sharded_model._utils import cast_tensor_to_fp16
from colossalai.zero.sharded_param import ShardedParamV2
from colossalai.zero.comm import ZeroDist
from contextlib import AbstractContextManager
@@ -191,6 +192,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
        The Callback function when entering the context
        """
        self.logger = get_dist_logger("ZeroInitContext")
        ZeroDist.zero_comm_init(self.dp_process_group)    # initialize zero communication world

        # substitute fan-in and fan-out calculation
        self.nn_fanin_fanout = nn.init._calculate_fan_in_and_fan_out
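
Because `ZeroCommWorld` is a singleton, the `zero_comm_init` call above is effectively idempotent: a second call with the same data-parallel group returns early, and a different group trips the assertion. A hedged sketch of that behaviour (assumes `torch.distributed` has already been initialized, e.g. by the Colossal-AI launcher):

import torch.distributed as dist
from colossalai.zero.comm import ZeroDist

# build a group spanning all ranks (standard torch.distributed API)
dp_group = dist.new_group(ranks=list(range(dist.get_world_size())))

ZeroDist.zero_comm_init(dp_group)    # creates the NCCL communicator if the extension is installed
ZeroDist.zero_comm_init(dp_group)    # no-op: the same group is already registered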

@@ -1,5 +1,6 @@
from .base_shard_strategy import BaseShardStrategy
from .bucket_tensor_shard_strategy import BucketTensorShardStrategy
from .tensor_shard_strategy import TensorShardStrategy
from .zero_tensor_shard_strategy import ZeroTensorShardStrategy
-__all__ = ['BaseShardStrategy', 'TensorShardStrategy', 'BucketTensorShardStrategy']
+__all__ = ['BaseShardStrategy', 'TensorShardStrategy', 'BucketTensorShardStrategy', 'ZeroTensorShardStrategy']

@@ -0,0 +1,38 @@
from typing import Optional

import torch
import torch.distributed as dist
from colossalai.utils import get_current_device
from colossalai.zero.sharded_param.tensor_utils import colo_model_data_tensor_move_inline
from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
from colossalai.zero.comm import ZeroDist

from .tensor_shard_strategy import TensorShardStrategy


class ZeroTensorShardStrategy(TensorShardStrategy):
    """Use the same shard scheme as `TensorShardStrategy`, but perform the all-gather in place,
    so no extra buffer is created (an extra buffer is needed when using `torch.distributed.all_gather`).
    This reduces peak memory usage in zero-offload.

    Note that this strategy is tightly coupled with ZeRO: you cannot change its communication
    group, and you must use `ZeroInitContext` to create your model.
    """

    def _gather_tensor(self, t: ShardedTensor, process_group: Optional[dist.ProcessGroup] = None):
        if not t.is_sharded:
            return
        target_device = t.device
        payload_numel = t.payload.numel()
        world_size = dist.get_world_size(process_group)
        rank = dist.get_rank(process_group)

        buffer = torch.empty(payload_numel * world_size, dtype=t.payload.dtype, device=get_current_device())
        buffer_list = list(torch.chunk(buffer, chunks=world_size, dim=0))
        buffer_list[rank].copy_(t.payload)

        ZeroDist.zero_all_gather(buffer)    # note: process_group is ignored here; ZeroDist uses its own group
        gathered_payload = torch.narrow(buffer, 0, 0, t.origin_numel).reshape(t.origin_shape)
        t.reset_payload(gathered_payload)
        colo_model_data_tensor_move_inline(t, target_device)
        t.is_sharded = False
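
Putting the pieces together, the new strategy is meant to be used with model construction under `ZeroInitContext`, so that `ZeroDist` is bound to the data-parallel group before any shard is gathered. The sketch below only illustrates the intended flow; the exact `ZeroInitContext` import path and keyword arguments (e.g. `target_device`, `shard_param`) are assumptions and may differ between Colossal-AI versions:

import torch.nn as nn
from colossalai.utils import get_current_device
from colossalai.zero.init_ctx import ZeroInitContext            # assumed import path
from colossalai.zero.shard_utils import ZeroTensorShardStrategy

shard_strategy = ZeroTensorShardStrategy()
with ZeroInitContext(target_device=get_current_device(),        # assumed keyword arguments
                     shard_strategy=shard_strategy,
                     shard_param=True):
    model = nn.Linear(1024, 1024)    # parameters are sharded as the module is created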