Mirror of https://github.com/hpcaitech/ColossalAI.git
[tensor] ColoTensor supports ZeRO (#1015)
* impl chunk manager
* impl param op hook
* add reduce_chunk
* add zero hook v2
* add zero dp
* fix TensorInfo
* impl load balancing when using zero without chunk
* fix zero hook
* polish chunk
* fix bugs
* ddp ok
* zero ok
* polish code
* fix bugs about load balancing
* polish code
* polish code
* add end-to-end test
* polish code
* polish code
* polish code
* fix typo
* add test_chunk
* fix bugs
* fix bugs
* polish code
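How the pieces fit together: ChunkManager packs parameters into fixed-size Chunks (each chunk owned by a src rank when distributed storage is enabled), and a ParamOpHook installed via use_param_op_hooks can gather and release those chunks around every operator that touches a ColoParameter. The sketch below is illustrative only — GatherReleaseHook is a hypothetical hook, not the ZeRO hook shipped in this PR — and it assumes the ColossalAI distributed context (ParallelMode.DATA) has already been launched.

# Hypothetical ZeRO-style hook built on the APIs added in this PR (sketch, not the shipped zero hook).
import torch
from colossalai.tensor import ChunkManager, TensorState, ParamOpHook, use_param_op_hooks


class GatherReleaseHook(ParamOpHook):

    def __init__(self, chunk_manager: ChunkManager):
        self.chunk_manager = chunk_manager

    def pre_forward(self, params):
        self.chunk_manager.exec_lazy_release()            # free chunks queued by earlier ops
        for p in params:
            self.chunk_manager.access_chunk(p)            # broadcast the chunk from its src rank
            self.chunk_manager.trans_tensor_state(p, TensorState.COMPUTE)

    def post_forward(self, params):
        for p in params:
            self.chunk_manager.trans_tensor_state(p, TensorState.HOLD)
        self.chunk_manager.add_lazy_release_tensors(params)

    def pre_backward(self, params):
        self.pre_forward(params)

    def post_backward(self, params):
        for p in params:
            self.chunk_manager.trans_tensor_state(p, TensorState.HOLD_AFTER_BWD)
        self.chunk_manager.add_lazy_release_tensors(params)


# Usage sketch: pack the model's parameters into chunks, then run the forward pass under the hook.
# chunk_manager = ChunkManager(chunk_size=1024 * 1024, enable_distributed_storage=True)
# for param in model.parameters():
#     chunk_manager.append_tensor(param, group_name='fp16_param')
# with use_param_op_hooks(GatherReleaseHook(chunk_manager)):
#     out = model(data)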
@@ -8,12 +8,14 @@ from ._ops import *
from .optim.colo_optimizer import ColoOptimizer
from . import distspec
from .dist_spec_mgr import DistSpecManager
from .param_op_hook import ParamOpHook, use_param_op_hooks
from .chunk import ChunkManager, TensorState
from .module_utils import register_colo_module, is_colo_module, get_colo_module, init_colo_module, check_colo_module
from .modules import ColoLinear, ColoEmbedding

__all__ = [
    'ColoTensor', 'convert_parameter', 'colo_op_impl', 'ComputePattern', 'TensorSpec', 'ParallelAction',
    'named_params_with_colotensor', 'ColoOptimizer', 'ColoParameter', 'distspec', 'DistSpecManager',
    'register_colo_module', 'is_colo_module', 'get_colo_module', 'init_colo_module', 'check_colo_module', 'ColoLinear',
    'ColoEmbedding', 'ParamOpHook', 'use_param_op_hooks', 'ChunkManager', 'TensorState'
]
colossalai/tensor/chunk.py (new file, 264 lines)
@@ -0,0 +1,264 @@
import torch
import torch.distributed as dist
from dataclasses import dataclass
from enum import Enum
from typing import Optional, Dict, Deque, Set, List
from collections import deque
from colossalai.core import global_context as gpc
from colossalai.context import ParallelMode
from colossalai.utils import get_current_device


class TensorState(Enum):
    FREE = 0
    COMPUTE = 1
    HOLD = 2
    HOLD_AFTER_BWD = 3
    READY_FOR_REDUCE = 4


STATE_TRANS = ((TensorState.FREE, TensorState.HOLD), (TensorState.FREE, TensorState.COMPUTE),
               (TensorState.HOLD, TensorState.FREE), (TensorState.HOLD, TensorState.COMPUTE),
               (TensorState.COMPUTE, TensorState.HOLD), (TensorState.COMPUTE, TensorState.HOLD_AFTER_BWD),
               (TensorState.COMPUTE, TensorState.READY_FOR_REDUCE), (TensorState.HOLD_AFTER_BWD, TensorState.COMPUTE),
               (TensorState.HOLD_AFTER_BWD, TensorState.READY_FOR_REDUCE), (TensorState.READY_FOR_REDUCE, TensorState.HOLD))


@dataclass
class TensorInfo:
    state: TensorState
    offset: int
    end: int


class ChunkFullError(Exception):
    pass
class Chunk:

    def __init__(self,
                 chunk_size: int,
                 src_rank: int,
                 dtype: torch.dtype,
                 init_device: Optional[torch.device] = None) -> None:
        self.size = chunk_size
        self.utilized_size = 0
        self.src_rank = src_rank
        self.is_src_rank = gpc.get_local_rank(ParallelMode.DATA) == src_rank
        self.dtype = dtype
        self.device = init_device or get_current_device()
        self.data = torch.empty(chunk_size, dtype=dtype, device=self.device)
        if not self.is_src_rank:
            self.data.storage().resize_(0)
        self.tensors_info: Dict[torch.Tensor, TensorInfo] = {}

    def append(self, tensor: torch.Tensor) -> None:
        assert tensor.dtype == self.dtype
        new_utilized_size = self.utilized_size + tensor.numel()
        if new_utilized_size > self.size:
            raise ChunkFullError
        tensor_state = TensorState.FREE
        if self.is_src_rank:
            self.data[self.utilized_size:new_utilized_size].copy_(tensor.view(-1))
            tensor_state = TensorState.HOLD
            tensor.data = self.data[self.utilized_size:new_utilized_size].view(tensor.shape)
        else:
            tensor.storage().resize_(0)
        self.tensors_info[tensor] = TensorInfo(tensor_state, self.utilized_size, new_utilized_size)
        self.utilized_size = new_utilized_size

    def release(self) -> None:
        if not self.is_src_rank:
            self.data.storage().resize_(0)
            self._update_tensors_state(TensorState.FREE)

    def _update_tensors_ptr(self) -> None:
        for tensor, tensor_info in self.tensors_info.items():
            tensor.data = self.data[tensor_info.offset:tensor_info.end].view(tensor.shape)

    def _update_tensors_state(self, next_state: TensorState, prev_state: Optional[TensorState] = None):
        for tensor_info in self.tensors_info.values():
            if prev_state is None or tensor_info.state == prev_state:
                tensor_info.state = next_state

    def access(self) -> None:
        if not self.is_src_rank:
            self.data.storage().resize_(self.size)
        self.data.data = self.data.to(get_current_device())
        dist.broadcast(self.data, self.src_rank, group=gpc.get_group(ParallelMode.DATA))
        self._update_tensors_ptr()
        if not self.is_src_rank:
            self._update_tensors_state(TensorState.HOLD, prev_state=TensorState.FREE)

    def move_device(self, device: torch.device) -> None:
        self.data.data = self.data.to(device)
        self._update_tensors_ptr()

    def reduce(self, is_all_reduce: bool = False) -> None:
        self.data.data = self.data.to(get_current_device())
        if is_all_reduce:
            dist.all_reduce(self.data, group=gpc.get_group(ParallelMode.DATA))
        else:
            dist.reduce(self.data, self.src_rank, group=gpc.get_group(ParallelMode.DATA))
        self._update_tensors_ptr()
        self._update_tensors_state(TensorState.HOLD)
    def tensor_trans_state(self, tensor: torch.Tensor, tensor_state: TensorState) -> None:
        assert tensor_state != TensorState.FREE, 'Can only set a chunk of tensors to FREE'
        # As the gradient hook can be triggered either before or after post-backward,
        # a tensor's state can go compute -> hold_after_bwd -> ready_for_reduce
        # or compute -> ready_for_reduce -> hold_after_bwd.
        # The second path is invalid, so the ready_for_reduce -> hold_after_bwd step is ignored.
        # This method only applies valid state transitions;
        # invalid calls are ignored and nothing changes.
        if (self.tensors_info[tensor].state, tensor_state) not in STATE_TRANS:
            # print(
            #     f'WARNING: Rank{gpc.get_global_rank()} apply invalid state trans: {self.tensors_info[tensor].state} to {tensor_state}'
            # )
            return
        self.tensors_info[tensor].state = tensor_state
    def copy_tensor_to_chunk_slice(self, tensor: torch.Tensor, data_slice: torch.Tensor) -> None:
        tensor_info = self.tensors_info[tensor]
        self.data[tensor_info.offset:tensor_info.end].copy_(data_slice.view(-1))
        tensor.data = self.data[tensor_info.offset:tensor_info.end].view(tensor.shape)

    @property
    def can_release(self) -> bool:
        for tensor_info in self.tensors_info.values():
            if tensor_info.state != TensorState.HOLD:
                return False
        return True

    @property
    def can_move_device(self) -> bool:
        for tensor_info in self.tensors_info.values():
            if tensor_info.state in (TensorState.COMPUTE, TensorState.READY_FOR_REDUCE):
                return False
        return True

    @property
    def can_reduce(self) -> bool:
        for tensor_info in self.tensors_info.values():
            if tensor_info.state != TensorState.READY_FOR_REDUCE:
                return False
        return True

    @property
    def is_free(self) -> bool:
        return self.data.storage().size() == 0

    def __repr__(self) -> str:
        return f'Chunk: src rank={self.src_rank}, size={self.size}, utilization={self.utilized_size / self.size * 100:.2f}%, freed={self.is_free}, tensor states={[info.state.name for info in self.tensors_info.values()]}'
class ChunkManager:

    def __init__(self,
                 chunk_size: Optional[int],
                 enable_distributed_storage: bool = False,
                 init_device: Optional[torch.device] = None) -> None:
        assert chunk_size is None or chunk_size > 0
        self.chunk_size = chunk_size
        self.enable_distributed_storage = enable_distributed_storage
        self.device = init_device or get_current_device()
        self.chunk_groups: Dict[str, Deque[Chunk]] = {}
        self.tensor_chunk_map: Dict[torch.Tensor, Chunk] = {}
        self.accessed_chunks: Set[Chunk] = set()
        self.lazy_release_tensors: List[torch.Tensor] = []
        if enable_distributed_storage and chunk_size is None:
            self.rank_load: Dict[str, torch.Tensor] = {}

    def append_tensor(self, tensor: torch.Tensor, group_name: str) -> None:
        assert tensor not in self.tensor_chunk_map
        if self.chunk_size is not None and tensor.numel() > self.chunk_size:
            raise ValueError(
                f'Cannot create chunk, got tensor numel ({tensor.numel()}) > chunk size ({self.chunk_size})')
        if group_name not in self.chunk_groups:
            self.chunk_groups[group_name] = deque()
        try:
            self.chunk_groups[group_name][-1].append(tensor)
        except (IndexError, ChunkFullError):
            chunk_size = self.chunk_size or tensor.numel()
            src_rank = self._get_next_src_rank(group_name)
            chunk = Chunk(chunk_size, src_rank, tensor.dtype, self.device)
            if self.enable_distributed_storage and self.chunk_size is None:
                self.rank_load[group_name][src_rank] += chunk_size
            self.chunk_groups[group_name].append(chunk)
            chunk.append(tensor)
        self.tensor_chunk_map[tensor] = self.chunk_groups[group_name][-1]
        if not self.enable_distributed_storage:
            self.accessed_chunks.add(self.chunk_groups[group_name][-1])

    def _get_next_src_rank(self, group_name: str) -> int:
        if not self.enable_distributed_storage:
            return gpc.get_local_rank(ParallelMode.DATA)
        if self.chunk_size is None:
            if group_name not in self.rank_load:
                self.rank_load[group_name] = torch.zeros(gpc.get_world_size(ParallelMode.DATA), dtype=torch.int64)
            src_rank = torch.argmin(self.rank_load[group_name]).item()
        else:
            chunk_idx = len(self.chunk_groups[group_name])
            src_rank = chunk_idx % gpc.get_world_size(ParallelMode.DATA)
        return src_rank

    def access_chunk(self, tensor: torch.Tensor) -> None:
        chunk = self.tensor_chunk_map[tensor]
        if chunk in self.accessed_chunks:
            return
        chunk.access()
        self.accessed_chunks.add(chunk)

    def release_chunk(self, tensor: torch.Tensor) -> None:
        if not self.enable_distributed_storage:
            return
        chunk = self.tensor_chunk_map[tensor]
        if chunk not in self.accessed_chunks:
            return
        if chunk.can_release:
            chunk.release()
            self.accessed_chunks.remove(chunk)

    def move_chunk(self, tensor: torch.Tensor, device: torch.device) -> None:
        chunk = self.tensor_chunk_map[tensor]
        if chunk.can_move_device:
            chunk.move_device(device)

    def trans_tensor_state(self, tensor: torch.Tensor, state: TensorState) -> None:
        chunk = self.tensor_chunk_map[tensor]
        chunk.tensor_trans_state(tensor, state)

    def reduce_chunk(self, tensor: torch.Tensor) -> None:
        chunk = self.tensor_chunk_map[tensor]
        if not chunk.can_reduce:
            return
        chunk.reduce(is_all_reduce=not self.enable_distributed_storage)

    def copy_tensor_to_chunk_slice(self, tensor: torch.Tensor, data: torch.Tensor) -> None:
        chunk = self.tensor_chunk_map[tensor]
        chunk.copy_tensor_to_chunk_slice(tensor, data)

    def is_chunk_free(self, tensor: torch.Tensor) -> bool:
        chunk = self.tensor_chunk_map[tensor]
        return chunk.is_free

    def get_chunk(self, tensor: torch.Tensor) -> Chunk:
        return self.tensor_chunk_map[tensor]

    def add_lazy_release_tensors(self, tensors: List[torch.Tensor]) -> None:
        self.lazy_release_tensors.extend(tensors)

    def exec_lazy_release(self) -> None:
        for tensor in self.lazy_release_tensors:
            self.release_chunk(tensor)
        self.lazy_release_tensors.clear()

    def __repr__(self) -> str:
        msg = f'Rank {gpc.get_local_rank(ParallelMode.DATA)}:\n'
        for group_name, group in self.chunk_groups.items():
            msg += f'Group {group_name}:\n'
            for i, chunk in enumerate(group):
                msg += f'[{i}] {chunk}\n'
        return msg
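A quick illustrative walk-through of ChunkManager's packing and of the tensor state machine may help here. This is not part of the PR: it assumes colossalai.launch can bootstrap a single-process gloo group (so gpc and ParallelMode.DATA are usable) and keeps the chunks on CPU; the port and backend are arbitrary choices.

# Single-process sketch of chunk packing and state transitions (illustrative, not from the PR).
import torch
import colossalai
from colossalai.tensor import ChunkManager, TensorState

# Assumed launch call for a 1-process group.
colossalai.launch(config={}, rank=0, world_size=1, host='localhost', port=29500, backend='gloo')

manager = ChunkManager(chunk_size=8, enable_distributed_storage=False, init_device=torch.device('cpu'))
a, b, c = torch.ones(4), torch.ones(4), torch.ones(4)
for t in (a, b, c):
    manager.append_tensor(t, group_name='params')

# a and b share the first 8-element chunk; c does not fit and starts a second chunk.
assert manager.get_chunk(a) is manager.get_chunk(b)
assert manager.get_chunk(a) is not manager.get_chunk(c)

# Appended tensors start in HOLD on the src rank; walk one of them through a backward-like lifecycle.
chunk = manager.get_chunk(a)
manager.trans_tensor_state(a, TensorState.COMPUTE)
print(chunk.can_release)                                    # False: a is COMPUTE
manager.trans_tensor_state(a, TensorState.READY_FOR_REDUCE)
manager.trans_tensor_state(a, TensorState.HOLD_AFTER_BWD)   # not in STATE_TRANS: silently ignored
print(chunk.can_reduce)                                     # False: b is still HOLD
manager.trans_tensor_state(a, TensorState.HOLD)             # READY_FOR_REDUCE -> HOLD is valid
print(chunk.can_release)                                    # True: every tensor is HOLD again
print(manager)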
@@ -3,8 +3,10 @@ from .const import TensorType
import torch
from colossalai.tensor import TensorSpec, distspec
from copy import copy
from .param_op_hook import _ParamOpHookWrapper, PreFwdPostBwd, PostFwdPreBwd
from typing import Optional


class ColoParameter(ColoTensor, torch.nn.Parameter):
    r"""A kind of ColoTensor to be considered as a module parameter.
@@ -44,6 +46,22 @@ class ColoParameter(ColoTensor, torch.nn.Parameter):
    def __repr__(self):
        return f'ColoParameter: {torch.Tensor.__repr__(self)}'

    @classmethod
    def __torch_function__(cls, func, types, args=..., kwargs=None):
        if len(_ParamOpHookWrapper.hooks) > 0:
            if not func.__name__.startswith('__'):
                params = list(filter(lambda arg: isinstance(arg, ColoParameter), args))
                if kwargs is not None:
                    params.extend(list(filter(lambda arg: isinstance(arg, ColoParameter), kwargs.values())))
                if len(params) > 0:
                    with torch._C.DisableTorchFunction():
                        args = PreFwdPostBwd.apply(params, *args)
                    ret = super().__torch_function__(func, types, args, kwargs)
                    with torch._C.DisableTorchFunction():
                        ret = PostFwdPreBwd.apply(params, ret)
                    return ret
        return super().__torch_function__(func, types, args, kwargs)

    def __deepcopy__(self, memo):
        if id(self) in memo:
            return memo[id(self)]

@@ -69,4 +87,3 @@ class ColoParameter(ColoTensor, torch.nn.Parameter):
        # TODO(jzy) we don't support object reflection now.
        # distspec cannot be pickled or rebuilt because it's tightly connected to runtime attribute `process_group`.
        raise NotImplementedError
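The __torch_function__ override above leans on a pair of autograd Functions defined in param_op_hook.py (next file): PreFwdPostBwd fires the pre-forward hooks when it runs and the post-backward hooks when autograd replays it, while PostFwdPreBwd does the mirror image. A stripped-down, self-contained illustration of that pattern in plain PyTorch (the names below are mine, not part of the PR):

# Minimal demonstration of the 'callback on both sides of an op' autograd trick.
import torch


class PreFwdPostBwdDemo(torch.autograd.Function):

    @staticmethod
    def forward(ctx, tag, *args):
        ctx.tag = tag
        print(f'pre-forward: {tag}')
        return args if len(args) > 1 else args[0]

    @staticmethod
    def backward(ctx, *grads):
        print(f'post-backward: {ctx.tag}')
        return (None,) + grads    # no gradient for the non-tensor tag argument


x = torch.randn(3, requires_grad=True)
y = PreFwdPostBwdDemo.apply('my-param', x)    # prints "pre-forward: my-param"
(y * 2).sum().backward()                      # prints "post-backward: my-param" during backward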
colossalai/tensor/param_op_hook.py (new file, 71 lines)
@@ -0,0 +1,71 @@
import torch
from contextlib import contextmanager
from abc import ABC, abstractmethod
from typing import List, Tuple


class ParamOpHook(ABC):

    @abstractmethod
    def pre_forward(self, params: List[torch.Tensor]) -> None:
        pass

    @abstractmethod
    def post_forward(self, params: List[torch.Tensor]) -> None:
        pass

    @abstractmethod
    def pre_backward(self, params: List[torch.Tensor]) -> None:
        pass

    @abstractmethod
    def post_backward(self, params: List[torch.Tensor]) -> None:
        pass


class _ParamOpHookWrapper:
    hooks: Tuple[ParamOpHook, ...] = tuple()


class PreFwdPostBwd(torch.autograd.Function):

    @staticmethod
    def forward(ctx, params, *args):
        ctx.params = params
        for hook in _ParamOpHookWrapper.hooks:
            hook.pre_forward(ctx.params)
        if len(args) == 1:
            return args[0]
        return args

    @staticmethod
    def backward(ctx, *grads):
        for hook in _ParamOpHookWrapper.hooks:
            hook.post_backward(ctx.params)
        return (None,) + grads


class PostFwdPreBwd(torch.autograd.Function):

    @staticmethod
    def forward(ctx, params, args):
        ctx.params = params
        for hook in _ParamOpHookWrapper.hooks:
            hook.post_forward(params)
        return args

    @staticmethod
    def backward(ctx, *grads):
        for hook in _ParamOpHookWrapper.hooks:
            hook.pre_backward(ctx.params)
        return (None,) + grads


@contextmanager
def use_param_op_hooks(*hooks: ParamOpHook):
    try:
        old_param_op_hooks = _ParamOpHookWrapper.hooks
        _ParamOpHookWrapper.hooks = hooks
        yield
    finally:
        _ParamOpHookWrapper.hooks = old_param_op_hooks
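A hedged usage sketch for the hook plumbing above: in real code ColoParameter.__torch_function__ inserts PreFwdPostBwd/PostFwdPreBwd around each op automatically, but calling them directly on plain tensors shows when each ParamOpHook method fires. LoggingHook is illustrative, not part of the PR; note that backward must still run inside the context, since leaving it restores the previous hook tuple.

# Illustrative LoggingHook showing hook firing order (not part of the PR).
import torch
from colossalai.tensor.param_op_hook import ParamOpHook, PreFwdPostBwd, PostFwdPreBwd, use_param_op_hooks


class LoggingHook(ParamOpHook):

    def pre_forward(self, params):
        print('pre_forward ', [tuple(p.shape) for p in params])

    def post_forward(self, params):
        print('post_forward', [tuple(p.shape) for p in params])

    def pre_backward(self, params):
        print('pre_backward', [tuple(p.shape) for p in params])

    def post_backward(self, params):
        print('post_backward', [tuple(p.shape) for p in params])


weight = torch.randn(4, 4, requires_grad=True)
inp = torch.randn(2, 4)

with use_param_op_hooks(LoggingHook()):
    inp, weight = PreFwdPostBwd.apply([weight], inp, weight)    # fires pre_forward
    out = torch.nn.functional.linear(inp, weight)
    out = PostFwdPreBwd.apply([weight], out)                    # fires post_forward
    out.sum().backward()    # fires pre_backward (PostFwdPreBwd), then post_backward (PreFwdPostBwd)
# Outside the with-block the previous (empty) hook tuple is restored, so nothing fires any more.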