mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-12 20:54:35 +00:00
[legacy] clean up legacy code (#4743)
* [legacy] remove outdated codes of pipeline (#4692) * [legacy] remove cli of benchmark and update optim (#4690) * [legacy] remove cli of benchmark and update optim * [doc] fix cli doc test * [legacy] fix engine clip grad norm * [legacy] remove outdated colo tensor (#4694) * [legacy] remove outdated colo tensor * [test] fix test import * [legacy] move outdated zero to legacy (#4696) * [legacy] clean up utils (#4700) * [legacy] clean up utils * [example] update examples * [legacy] clean up amp * [legacy] fix amp module * [legacy] clean up gpc (#4742) * [legacy] clean up context * [legacy] clean core, constants and global vars * [legacy] refactor initialize * [example] fix examples ci * [example] fix examples ci * [legacy] fix tests * [example] fix gpt example * [example] fix examples ci * [devops] fix ci installation * [example] fix examples ci
This commit is contained in:
5
colossalai/legacy/zero/shard_utils/__init__.py
Normal file
5
colossalai/legacy/zero/shard_utils/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
from .base_shard_strategy import BaseShardStrategy
|
||||
from .bucket_tensor_shard_strategy import BucketTensorShardStrategy
|
||||
from .tensor_shard_strategy import TensorShardStrategy
|
||||
|
||||
__all__ = ['BaseShardStrategy', 'TensorShardStrategy', 'BucketTensorShardStrategy']
|
22
colossalai/legacy/zero/shard_utils/base_shard_strategy.py
Normal file
22
colossalai/legacy/zero/shard_utils/base_shard_strategy.py
Normal file
@@ -0,0 +1,22 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Optional
|
||||
|
||||
import torch.distributed as dist
|
||||
|
||||
from colossalai.legacy.zero.sharded_param.sharded_tensor import ShardedTensor
|
||||
|
||||
|
||||
class BaseShardStrategy(ABC):
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Abstract Shard Strategy. Use to shard a tensors on multiple GPUs.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
@abstractmethod
|
||||
def shard(self, tensor_list: List[ShardedTensor], process_group: Optional[dist.ProcessGroup] = None):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def gather(self, tensor_list: List[ShardedTensor], process_group: Optional[dist.ProcessGroup] = None):
|
||||
pass
|
@@ -0,0 +1,47 @@
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch._utils import _flatten_dense_tensors as flatten
|
||||
|
||||
from colossalai.legacy.zero.sharded_param.sharded_tensor import ShardedTensor
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
from .tensor_shard_strategy import TensorShardStrategy
|
||||
|
||||
|
||||
class BucketTensorShardStrategy(TensorShardStrategy):
|
||||
"""Use the same shard scheme as `TensorShardStrategy`'s, but it gathers tensors of a sub-module together,
|
||||
which will fully utilize network bandwidth.
|
||||
It is especially useful when sub-module contains bias,
|
||||
since we cannot utilize network bandwidth well if we only gather a bias tensor (bias is usually small).
|
||||
"""
|
||||
|
||||
def gather(self, tensor_list: List[ShardedTensor], process_group: Optional[dist.ProcessGroup] = None):
|
||||
|
||||
tensor_list: List[ShardedTensor] = [t for t in tensor_list if t.is_sharded]
|
||||
if len(tensor_list) == 0:
|
||||
return
|
||||
target_device = tensor_list[0].device
|
||||
dtype = tensor_list[0].dtype
|
||||
buffer_list: List[torch.Tensor] = []
|
||||
tensor_numels = [t.payload.numel() for t in tensor_list]
|
||||
buffer_size = sum(tensor_numels)
|
||||
world_size = dist.get_world_size(process_group)
|
||||
rank = dist.get_rank(process_group)
|
||||
for i in range(world_size):
|
||||
if i == rank:
|
||||
buffer_list.append(flatten([t.payload for t in tensor_list]).cuda(get_current_device()))
|
||||
else:
|
||||
buffer_list.append(torch.zeros(buffer_size, dtype=dtype, device=get_current_device()))
|
||||
dist.all_gather(buffer_list, buffer_list[rank], group=process_group)
|
||||
# Move to target device before splitting buffer
|
||||
# Ensure we utilize maximum PCIE bandwidth
|
||||
buffer_list = [buffer.to(target_device) for buffer in buffer_list]
|
||||
offset = 0
|
||||
for i, t in enumerate(tensor_list):
|
||||
gathered_payload = [buffer[offset:offset + tensor_numels[i]] for buffer in buffer_list]
|
||||
gathered_payload = torch.cat(gathered_payload)[:t.origin_numel].view(t.origin_shape)
|
||||
t.payload_reset(gathered_payload)
|
||||
t.is_sharded = False
|
||||
offset += tensor_numels[i]
|
22
colossalai/legacy/zero/shard_utils/commons.py
Normal file
22
colossalai/legacy/zero/shard_utils/commons.py
Normal file
@@ -0,0 +1,22 @@
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def get_shard(tensor: torch.Tensor, rank: int, world_size: int) -> Tuple[torch.Tensor, int]:
|
||||
"""Return the local shard of a full tensor."""
|
||||
# Shard using torch.chunk to match all-gather/reduce-scatter.
|
||||
chunks = list(torch.flatten(tensor).chunk(world_size))
|
||||
while len(chunks) < world_size:
|
||||
chunks.append(chunks[0].new_empty(0))
|
||||
|
||||
# Determine number of padding elements.
|
||||
num_to_pad = chunks[0].numel() - chunks[rank].numel()
|
||||
assert num_to_pad >= 0, num_to_pad
|
||||
|
||||
shard = torch.zeros_like(chunks[0])
|
||||
length = chunks[rank].size(0)
|
||||
shard_temp = shard[:length]
|
||||
shard_temp.copy_(chunks[rank])
|
||||
|
||||
return shard, num_to_pad
|
59
colossalai/legacy/zero/shard_utils/tensor_shard_strategy.py
Normal file
59
colossalai/legacy/zero/shard_utils/tensor_shard_strategy.py
Normal file
@@ -0,0 +1,59 @@
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from colossalai.legacy.zero.gemini.tensor_utils import colo_model_data_tensor_move_inline
|
||||
from colossalai.legacy.zero.shard_utils import BaseShardStrategy
|
||||
from colossalai.legacy.zero.shard_utils.commons import get_shard
|
||||
from colossalai.legacy.zero.sharded_param.sharded_tensor import ShardedTensor
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
|
||||
class TensorShardStrategy(BaseShardStrategy):
|
||||
"""
|
||||
A naive implementation which shard each tensor evenly over all ranks
|
||||
"""
|
||||
|
||||
def shard(self, tensor_list: List[ShardedTensor], process_group: Optional[dist.ProcessGroup] = None):
|
||||
for t in tensor_list:
|
||||
self._shard_tensor(t, process_group)
|
||||
|
||||
def gather(self, tensor_list: List[ShardedTensor], process_group: Optional[dist.ProcessGroup] = None):
|
||||
for t in tensor_list:
|
||||
self._gather_tensor(t, process_group)
|
||||
|
||||
def _shard_tensor(self, t: ShardedTensor, process_group: Optional[dist.ProcessGroup] = None):
|
||||
""" Shard tensor among processes.
|
||||
|
||||
Args:
|
||||
t (ShardedTensor): a tensor to be sharded.
|
||||
process_group (Optional[dist.ProcessGroup], optional): the process group among which tensor shards.
|
||||
Defaults to None.
|
||||
"""
|
||||
if t.is_sharded:
|
||||
return
|
||||
if t.payload.device.type == 'cuda':
|
||||
assert t.payload.device == get_current_device(), f"shard tensor on cuda device index {t.payload.device.index},"\
|
||||
f" but current cuda device is {get_current_device()}"
|
||||
sharded_payload, _ = get_shard(t.payload, dist.get_rank(process_group), dist.get_world_size(process_group))
|
||||
t.payload_reset(sharded_payload)
|
||||
t.is_sharded = True
|
||||
|
||||
def _gather_tensor(self, t: ShardedTensor, process_group: Optional[dist.ProcessGroup] = None):
|
||||
if not t.is_sharded:
|
||||
return
|
||||
target_device = t.device
|
||||
payload_numel = t.payload.numel()
|
||||
world_size = dist.get_world_size(process_group)
|
||||
rank = dist.get_rank(process_group)
|
||||
|
||||
buffer = torch.empty(payload_numel * world_size, dtype=t.payload.dtype, device=get_current_device())
|
||||
buffer_list = list(torch.chunk(buffer, chunks=world_size, dim=0))
|
||||
buffer_list[rank].copy_(t.payload)
|
||||
|
||||
dist.all_gather(buffer_list, buffer_list[rank], group=process_group, async_op=False)
|
||||
gathered_payload = torch.narrow(buffer, 0, 0, t.origin_numel).reshape(t.origin_shape)
|
||||
t.payload_reset(gathered_payload)
|
||||
colo_model_data_tensor_move_inline(t, target_device)
|
||||
t.is_sharded = False
|
Reference in New Issue
Block a user