diff --git a/colossalai/nn/parallel.py b/colossalai/nn/parallel.py
index 9b1fa81a2..b57fa31d5 100644
--- a/colossalai/nn/parallel.py
+++ b/colossalai/nn/parallel.py
@@ -3,8 +3,10 @@ import torch.distributed as dist
 from colossalai.core import global_context as gpc
 from colossalai.context import ParallelMode
 from functools import partial
+from colossalai.zero.utils.zero_hook_v2 import ZeROHookV2
+from colossalai.tensor import ChunkManager, use_param_op_hooks, TensorState
 
-__all__ = ['ColoDDP']
+__all__ = ['ColoDDP', 'ColoDDPV2']
 
 
 def free_storage(data: torch.Tensor) -> None:
@@ -76,3 +78,54 @@ class ColoDDP(torch.nn.Module):
             else:
                 p._saved_grad.requires_grad_(False)
                 p._saved_grad.zero_()
+
+
+class ColoDDPV2(ColoDDP):
+
+    def __init__(self, module: torch.nn.Module, chunk_manager: ChunkManager) -> None:
+        super().__init__(module)
+        self.chunk_manager = chunk_manager
+        self.param_op_hook = ZeROHookV2(chunk_manager)
+        self.fp32_params = []
+        # TODO: get param order and filter unused params
+        for p in module.parameters():
+            assert p.dtype == torch.half
+            fp32_p = p.float()
+            self.chunk_manager.append_tensor(p, 'fp16_param')
+            self.chunk_manager.append_tensor(fp32_p, 'fp32_param')
+            self.fp32_params.append(fp32_p)
+
+    def forward(self, *args, **kwargs):
+        self.module.zero_grad(set_to_none=True)
+        for p, fp32_p in zip(self.module.parameters(), self.fp32_params):
+            if not self.chunk_manager.is_chunk_free(p):
+                self.chunk_manager.copy_tensor_to_chunk_slice(p, fp32_p)
+        with use_param_op_hooks(self.param_op_hook):
+            outputs = self.module(*args, **kwargs)
+        self.chunk_manager.exec_lazy_release()
+        return outputs
+
+    def backward(self, loss: torch.Tensor):
+        with self.param_op_hook.switch_to_backward(), use_param_op_hooks(self.param_op_hook):
+            loss.backward()
+        self.chunk_manager.exec_lazy_release()
+        for p in self.module.parameters():
+            if self.chunk_manager.is_chunk_free(p) or not p.requires_grad:
+                p.grad = None
+            else:
+                p.grad = p.data
+
+    def grad_handle(self, p, grad):
+        empty_grad = torch.empty_like(grad)
+        free_storage(empty_grad)
+        with torch._C.DisableTorchFunction():
+            self.chunk_manager.trans_tensor_state(p, TensorState.READY_FOR_REDUCE)
+            if self.dp_world_size > 1:
+                grad = grad / self.dp_world_size
+            self.chunk_manager.copy_tensor_to_chunk_slice(p, grad)
+            self.chunk_manager.reduce_chunk(p)
+            self.chunk_manager.release_chunk(p)
+        return empty_grad
+
+    def zero_grad(self, set_to_none: bool = False) -> None:
+        self.module.zero_grad(set_to_none=True)
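
A minimal usage sketch of `ColoDDPV2`, condensed from `tests/test_tensor/test_zero.py` later in this patch; `model_builder`, `input_ids`, `attn_mask`, `criterion`, and the chunk size are placeholders, and an optimizer that consumes the fp32 chunks is outside the scope of this change.

```python
from colossalai.nn.parallel import ColoDDPV2
from colossalai.tensor import ChunkManager
from colossalai.utils import ColoInitContext, get_current_device

with ColoInitContext(device=get_current_device()):
    model = model_builder()                  # parameters are created as ColoParameters
model = model.cuda().half()                  # ColoDDPV2 asserts that parameters are fp16

chunk_manager = ChunkManager(chunk_size=32 * 1024**2,            # None disables chunking
                             enable_distributed_storage=True)    # shard chunks across the DATA group
model = ColoDDPV2(model, chunk_manager)

logits = model(input_ids, attn_mask)         # forward runs under the ZeRO param-op hook
loss = criterion(logits, input_ids)
model.backward(loss)                         # use model.backward(loss), not loss.backward()
```
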
diff --git a/colossalai/tensor/__init__.py b/colossalai/tensor/__init__.py
index 12f572ed7..008183280 100644
--- a/colossalai/tensor/__init__.py
+++ b/colossalai/tensor/__init__.py
@@ -8,12 +8,14 @@ from ._ops import *
 from .optim.colo_optimizer import ColoOptimizer
 from . import distspec
 from .dist_spec_mgr import DistSpecManager
+from .param_op_hook import ParamOpHook, use_param_op_hooks
+from .chunk import ChunkManager, TensorState
 from .module_utils import register_colo_module, is_colo_module, get_colo_module, init_colo_module, check_colo_module
 from .modules import ColoLinear, ColoEmbedding
 
 __all__ = [
     'ColoTensor', 'convert_parameter', 'colo_op_impl', 'ComputePattern', 'TensorSpec', 'ParallelAction',
     'named_params_with_colotensor', 'ColoOptimizer', 'ColoParameter', 'distspec', 'DistSpecManager',
-    'register_colo_module', 'is_colo_module', 'get_colo_module', 'init_colo_module', 'check_colo_module',
-    'ColoLinear', 'ColoEmbedding'
+    'register_colo_module', 'is_colo_module', 'get_colo_module', 'init_colo_module', 'check_colo_module', 'ColoLinear',
+    'ColoEmbedding', 'ParamOpHook', 'use_param_op_hooks', 'ChunkManager', 'TensorState'
 ]
diff --git a/colossalai/tensor/chunk.py b/colossalai/tensor/chunk.py
new file mode 100644
index 000000000..79a1f015d
--- /dev/null
+++ b/colossalai/tensor/chunk.py
@@ -0,0 +1,264 @@
+import torch
+import torch.distributed as dist
+from dataclasses import dataclass
+from enum import Enum
+from typing import Optional, Dict, Deque, Set, List
+from collections import deque
+from colossalai.core import global_context as gpc
+from colossalai.context import ParallelMode
+from colossalai.utils import get_current_device
+
+
+class TensorState(Enum):
+    FREE = 0
+    COMPUTE = 1
+    HOLD = 2
+    HOLD_AFTER_BWD = 3
+    READY_FOR_REDUCE = 4
+
+
+STATE_TRANS = ((TensorState.FREE, TensorState.HOLD), (TensorState.FREE, TensorState.COMPUTE),
+               (TensorState.HOLD, TensorState.FREE), (TensorState.HOLD, TensorState.COMPUTE),
+               (TensorState.COMPUTE, TensorState.HOLD), (TensorState.COMPUTE, TensorState.HOLD_AFTER_BWD),
+               (TensorState.COMPUTE, TensorState.READY_FOR_REDUCE), (TensorState.HOLD_AFTER_BWD, TensorState.COMPUTE),
+               (TensorState.HOLD_AFTER_BWD, TensorState.READY_FOR_REDUCE),
+               (TensorState.READY_FOR_REDUCE, TensorState.HOLD))
+
+
+@dataclass
+class TensorInfo:
+    state: TensorState
+    offset: int
+    end: int
+
+
+class ChunkFullError(Exception):
+    pass
+
+
+class Chunk:
+
+    def __init__(self,
+                 chunk_size: int,
+                 src_rank: int,
+                 dtype: torch.dtype,
+                 init_device: Optional[torch.device] = None) -> None:
+        self.size = chunk_size
+        self.utilized_size = 0
+        self.src_rank = src_rank
+        self.is_src_rank = gpc.get_local_rank(ParallelMode.DATA) == src_rank
+        self.dtype = dtype
+        self.device = init_device or get_current_device()
+        self.data = torch.empty(chunk_size, dtype=dtype, device=self.device)
+        if not self.is_src_rank:
+            self.data.storage().resize_(0)
+        self.tensors_info: Dict[torch.Tensor, TensorInfo] = {}
+
+    def append(self, tensor: torch.Tensor) -> None:
+        assert tensor.dtype == self.dtype
+        new_utilized_size = self.utilized_size + tensor.numel()
+        if new_utilized_size > self.size:
+            raise ChunkFullError
+        tensor_state = TensorState.FREE
+        if self.is_src_rank:
+            self.data[self.utilized_size:new_utilized_size].copy_(tensor.view(-1))
+            tensor_state = TensorState.HOLD
+            tensor.data = self.data[self.utilized_size:new_utilized_size].view(tensor.shape)
+        else:
+            tensor.storage().resize_(0)
+        self.tensors_info[tensor] = TensorInfo(tensor_state, self.utilized_size, new_utilized_size)
+        self.utilized_size = new_utilized_size
+
+    def release(self) -> None:
+        if not self.is_src_rank:
+            self.data.storage().resize_(0)
+            self._update_tensors_state(TensorState.FREE)
+
+    def _update_tensors_ptr(self) -> None:
+        for tensor, tensor_info in self.tensors_info.items():
+            tensor.data = self.data[tensor_info.offset:tensor_info.end].view(tensor.shape)
+
+    def _update_tensors_state(self, next_state: TensorState, prev_state: Optional[TensorState] = None):
+        for tensor_info in self.tensors_info.values():
+            if prev_state is None or tensor_info.state == prev_state:
+                tensor_info.state = next_state
+
+    def access(self) -> None:
+        if not self.is_src_rank:
+            self.data.storage().resize_(self.size)
+        self.data.data = self.data.to(get_current_device())
+        dist.broadcast(self.data, self.src_rank, group=gpc.get_group(ParallelMode.DATA))
+        self._update_tensors_ptr()
+        if not self.is_src_rank:
+            self._update_tensors_state(TensorState.HOLD, prev_state=TensorState.FREE)
+
+    def move_device(self, device: torch.device) -> None:
+        self.data.data = self.data.to(device)
+        self._update_tensors_ptr()
+
+    def reduce(self, is_all_reduce: bool = False) -> None:
+        self.data.data = self.data.to(get_current_device())
+        if is_all_reduce:
+            dist.all_reduce(self.data, group=gpc.get_group(ParallelMode.DATA))
+        else:
+            dist.reduce(self.data, self.src_rank, group=gpc.get_group(ParallelMode.DATA))
+        self._update_tensors_ptr()
+        self._update_tensors_state(TensorState.HOLD)
+
+    def tensor_trans_state(self, tensor: torch.Tensor, tensor_state: TensorState) -> None:
+        assert tensor_state != TensorState.FREE, 'Can only set a whole chunk of tensors to FREE'
+        # As the gradient hook can be triggered either before or after post-backward,
+        # a tensor's state can go compute -> hold_after_bwd -> ready_for_reduce
+        # or compute -> ready_for_reduce -> hold_after_bwd.
+        # The second path is invalid, so ready_for_reduce -> hold_after_bwd is simply ignored.
+        # This function only applies valid state transitions;
+        # invalid calls are ignored and nothing changes.
+        if (self.tensors_info[tensor].state, tensor_state) not in STATE_TRANS:
+            # print(
+            #     f'WARNING: Rank{gpc.get_global_rank()} apply invalid state trans: {self.tensors_info[tensor].state} to {tensor_state}'
+            # )
+            return
+        self.tensors_info[tensor].state = tensor_state
+
+    def copy_tensor_to_chunk_slice(self, tensor: torch.Tensor, data_slice: torch.Tensor) -> None:
+        tensor_info = self.tensors_info[tensor]
+        self.data[tensor_info.offset:tensor_info.end].copy_(data_slice.view(-1))
+        tensor.data = self.data[tensor_info.offset:tensor_info.end].view(tensor.shape)
+
+    @property
+    def can_release(self) -> bool:
+        for tensor_info in self.tensors_info.values():
+            if tensor_info.state != TensorState.HOLD:
+                return False
+        return True
+
+    @property
+    def can_move_device(self) -> bool:
+        for tensor_info in self.tensors_info.values():
+            if tensor_info.state in (TensorState.COMPUTE, TensorState.READY_FOR_REDUCE):
+                return False
+        return True
+
+    @property
+    def can_reduce(self) -> bool:
+        for tensor_info in self.tensors_info.values():
+            if tensor_info.state != TensorState.READY_FOR_REDUCE:
+                return False
+        return True
+
+    @property
+    def is_free(self) -> bool:
+        return self.data.storage().size() == 0
+
+    def __repr__(self) -> str:
+        return f'Chunk: src rank={self.src_rank}, size={self.size}, utilization={self.utilized_size/self.size*100:.2f}%, freed={self.is_free}, tensor states={[info.state.name for info in self.tensors_info.values()]}'
+
+
+class ChunkManager:
+
+    def __init__(self,
+                 chunk_size: Optional[int],
+                 enable_distributed_storage: bool = False,
+                 init_device: Optional[torch.device] = None) -> None:
+        assert chunk_size is None or chunk_size > 0
+        self.chunk_size = chunk_size
+        self.enable_distributed_storage = enable_distributed_storage
+        self.device = init_device or get_current_device()
+        self.chunk_groups: Dict[str, Deque[Chunk]] = {}
+        self.tensor_chunk_map: Dict[torch.Tensor, Chunk] = {}
+        self.accessed_chunks: Set[Chunk] = set()
+        self.lazy_release_tensors: List[torch.Tensor] = []
+        if enable_distributed_storage and chunk_size is None:
+            self.rank_load: Dict[str, torch.Tensor] = {}
+
+    def append_tensor(self, tensor: torch.Tensor, group_name: str) -> None:
+        assert tensor not in self.tensor_chunk_map
+        if self.chunk_size is not None and tensor.numel() > self.chunk_size:
+            raise ValueError(
+                f'Cannot create chunk, got tensor numel ({tensor.numel()}) > chunk size ({self.chunk_size})')
+        if group_name not in self.chunk_groups:
+            self.chunk_groups[group_name] = deque()
+        try:
+            self.chunk_groups[group_name][-1].append(tensor)
+        except (IndexError, ChunkFullError):
+            chunk_size = self.chunk_size or tensor.numel()
+            src_rank = self._get_next_src_rank(group_name)
+            chunk = Chunk(chunk_size, src_rank, tensor.dtype, self.device)
+            if self.enable_distributed_storage and self.chunk_size is None:
+                self.rank_load[group_name][src_rank] += chunk_size
+            self.chunk_groups[group_name].append(chunk)
+            chunk.append(tensor)
+        self.tensor_chunk_map[tensor] = self.chunk_groups[group_name][-1]
+        if not self.enable_distributed_storage:
+            self.accessed_chunks.add(self.chunk_groups[group_name][-1])
+
+    def _get_next_src_rank(self, group_name: str) -> int:
+        if not self.enable_distributed_storage:
+            return gpc.get_local_rank(ParallelMode.DATA)
+        if self.chunk_size is None:
+            if group_name not in self.rank_load:
+                self.rank_load[group_name] = torch.zeros(gpc.get_world_size(ParallelMode.DATA), dtype=torch.int64)
+            src_rank = torch.argmin(self.rank_load[group_name]).item()
+        else:
+            chunk_idx = len(self.chunk_groups[group_name])
+            src_rank = chunk_idx % gpc.get_world_size(ParallelMode.DATA)
+        return src_rank
+
+    def access_chunk(self, tensor: torch.Tensor) -> None:
+        chunk = self.tensor_chunk_map[tensor]
+        if chunk in self.accessed_chunks:
+            return
+        chunk.access()
+        self.accessed_chunks.add(chunk)
+
+    def release_chunk(self, tensor: torch.Tensor) -> None:
+        if not self.enable_distributed_storage:
+            return
+        chunk = self.tensor_chunk_map[tensor]
+        if chunk not in self.accessed_chunks:
+            return
+        if chunk.can_release:
+            chunk.release()
+            self.accessed_chunks.remove(chunk)
+
+    def move_chunk(self, tensor: torch.Tensor, device: torch.device) -> None:
+        chunk = self.tensor_chunk_map[tensor]
+        if chunk.can_move_device:
+            chunk.move_device(device)
+
+    def trans_tensor_state(self, tensor: torch.Tensor, state: TensorState) -> None:
+        chunk = self.tensor_chunk_map[tensor]
+        chunk.tensor_trans_state(tensor, state)
+
+    def reduce_chunk(self, tensor: torch.Tensor) -> None:
+        chunk = self.tensor_chunk_map[tensor]
+        if not chunk.can_reduce:
+            return
+        chunk.reduce(is_all_reduce=not self.enable_distributed_storage)
+
+    def copy_tensor_to_chunk_slice(self, tensor: torch.Tensor, data: torch.Tensor) -> None:
+        chunk = self.tensor_chunk_map[tensor]
+        chunk.copy_tensor_to_chunk_slice(tensor, data)
+
+    def is_chunk_free(self, tensor: torch.Tensor) -> bool:
+        chunk = self.tensor_chunk_map[tensor]
+        return chunk.is_free
+
+    def get_chunk(self, tensor: torch.Tensor) -> Chunk:
+        return self.tensor_chunk_map[tensor]
+
+    def add_lazy_release_tensors(self, tensors: List[torch.Tensor]) -> None:
+        self.lazy_release_tensors.extend(tensors)
+
+    def exec_lazy_release(self) -> None:
+        for tensor in self.lazy_release_tensors:
+            self.release_chunk(tensor)
+        self.lazy_release_tensors.clear()
+
+    def __repr__(self) -> str:
+        msg = f'Rank {gpc.get_local_rank(ParallelMode.DATA)}:\n'
+        for group_name, group in self.chunk_groups.items():
+            msg += f'Group {group_name}:\n'
+            for i, chunk in enumerate(group):
+                msg += f'[{i}] {chunk}\n'
+        return msg
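
A sketch of the chunk lifecycle on plain tensors, following `tests/test_tensor/test_chunk.py` later in this patch; it assumes `colossalai.launch` has already initialized the DATA parallel group, and the chunk size here is arbitrary.

```python
import torch
from colossalai.tensor import ChunkManager, TensorState

chunk_manager = ChunkManager(chunk_size=2048, enable_distributed_storage=True)
params = [torch.rand(32, 32) for _ in range(3)]
for p in params:
    chunk_manager.append_tensor(p, 'param')     # packs tensors into chunks; non-src ranks free their storage

p = params[0]
chunk_manager.access_chunk(p)                              # broadcast the whole chunk from its src rank
chunk_manager.trans_tensor_state(p, TensorState.COMPUTE)   # mark the tensor as in use
# ... run the op that reads p ...
chunk_manager.trans_tensor_state(p, TensorState.HOLD)
chunk_manager.release_chunk(p)   # frees storage on non-src ranks once every tensor in the chunk is HOLD
```
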
diff --git a/colossalai/tensor/colo_parameter.py b/colossalai/tensor/colo_parameter.py
index 948356914..74eaa0b97 100644
--- a/colossalai/tensor/colo_parameter.py
+++ b/colossalai/tensor/colo_parameter.py
@@ -3,8 +3,10 @@ from .const import TensorType
 import torch
 from colossalai.tensor import TensorSpec, distspec
 from copy import copy
+from .param_op_hook import _ParamOpHookWrapper, PreFwdPostBwd, PostFwdPreBwd
 from typing import Optional
 
+
 class ColoParameter(ColoTensor, torch.nn.Parameter):
     r"""A kind of ColoTensor to be considered as a module parameter.
 
@@ -44,6 +46,22 @@ class ColoParameter(ColoTensor, torch.nn.Parameter):
     def __repr__(self):
         return f'ColoParameter: {torch.Tensor.__repr__(self)}'
 
+    @classmethod
+    def __torch_function__(cls, func, types, args=..., kwargs=None):
+        if len(_ParamOpHookWrapper.hooks) > 0:
+            if not func.__name__.startswith('__'):
+                params = list(filter(lambda arg: isinstance(arg, ColoParameter), args))
+                if kwargs is not None:
+                    params.extend(list(filter(lambda arg: isinstance(arg, ColoParameter), kwargs.values())))
+                if len(params) > 0:
+                    with torch._C.DisableTorchFunction():
+                        args = PreFwdPostBwd.apply(params, *args)
+                    ret = super().__torch_function__(func, types, args, kwargs)
+                    with torch._C.DisableTorchFunction():
+                        ret = PostFwdPreBwd.apply(params, ret)
+                    return ret
+        return super().__torch_function__(func, types, args, kwargs)
+
     def __deepcopy__(self, memo):
         if id(self) in memo:
             return memo[id(self)]
@@ -69,4 +87,3 @@ class ColoParameter(ColoTensor, torch.nn.Parameter):
         # TODO(jzy) we don't support object reflection now.
         # distspec cannot be pickled or rebuilt because it's tightly connected to runtime attribute `process_group`.
         raise NotImplementedError
-
diff --git a/colossalai/tensor/param_op_hook.py b/colossalai/tensor/param_op_hook.py
new file mode 100644
index 000000000..7522b62c2
--- /dev/null
+++ b/colossalai/tensor/param_op_hook.py
@@ -0,0 +1,71 @@
+import torch
+from contextlib import contextmanager
+from abc import ABC, abstractmethod
+from typing import List, Tuple
+
+
+class ParamOpHook(ABC):
+
+    @abstractmethod
+    def pre_forward(self, params: List[torch.Tensor]) -> None:
+        pass
+
+    @abstractmethod
+    def post_forward(self, params: List[torch.Tensor]) -> None:
+        pass
+
+    @abstractmethod
+    def pre_backward(self, params: List[torch.Tensor]) -> None:
+        pass
+
+    @abstractmethod
+    def post_backward(self, params: List[torch.Tensor]) -> None:
+        pass
+
+
+class _ParamOpHookWrapper:
+    hooks: Tuple[ParamOpHook, ...] = tuple()
+
+
+class PreFwdPostBwd(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx, params, *args):
+        ctx.params = params
+        for hook in _ParamOpHookWrapper.hooks:
+            hook.pre_forward(ctx.params)
+        if len(args) == 1:
+            return args[0]
+        return args
+
+    @staticmethod
+    def backward(ctx, *grads):
+        for hook in _ParamOpHookWrapper.hooks:
+            hook.post_backward(ctx.params)
+        return (None,) + grads
+
+
+class PostFwdPreBwd(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx, params, args):
+        ctx.params = params
+        for hook in _ParamOpHookWrapper.hooks:
+            hook.post_forward(params)
+        return args
+
+    @staticmethod
+    def backward(ctx, *grads):
+        for hook in _ParamOpHookWrapper.hooks:
+            hook.pre_backward(ctx.params)
+        return (None,) + grads
+
+
+@contextmanager
+def use_param_op_hooks(*hooks: ParamOpHook):
+    try:
+        old_param_op_hooks = _ParamOpHookWrapper.hooks
+        _ParamOpHookWrapper.hooks = hooks
+        yield
+    finally:
+        _ParamOpHookWrapper.hooks = old_param_op_hooks
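
The two autograd functions bracket a parameter op so that hooks fire both in forward and in reverse order during backward. Below is a small CPU-only sketch of the trigger order with a custom hook; `TraceHook` and the tensors are made up for illustration, and in real use `ColoParameter.__torch_function__` (see the diff above) inserts these functions automatically around each op.

```python
import torch
from colossalai.tensor.param_op_hook import ParamOpHook, PreFwdPostBwd, PostFwdPreBwd, use_param_op_hooks


class TraceHook(ParamOpHook):

    def __init__(self):
        self.events = []

    def pre_forward(self, params):
        self.events.append('pre_forward')

    def post_forward(self, params):
        self.events.append('post_forward')

    def pre_backward(self, params):
        self.events.append('pre_backward')

    def post_backward(self, params):
        self.events.append('post_backward')


hook = TraceHook()
w = torch.randn(4, 4, requires_grad=True)
x = torch.randn(4, 4)
with use_param_op_hooks(hook):
    x, w_out = PreFwdPostBwd.apply([w], x, w)   # fires pre_forward now, post_backward during backward
    y = x @ w_out
    y = PostFwdPreBwd.apply([w], y)             # fires post_forward now, pre_backward during backward
    y.sum().backward()                          # hooks must still be active when backward runs
print(hook.events)    # ['pre_forward', 'post_forward', 'pre_backward', 'post_backward']
```
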
diff --git a/colossalai/zero/utils/zero_hook_v2.py b/colossalai/zero/utils/zero_hook_v2.py
new file mode 100644
index 000000000..e5c1619f4
--- /dev/null
+++ b/colossalai/zero/utils/zero_hook_v2.py
@@ -0,0 +1,57 @@
+import torch
+from colossalai.tensor import ParamOpHook, ChunkManager, TensorState
+from enum import Enum
+from typing import List
+from contextlib import contextmanager
+from functools import partialmethod
+
+
+class TrainingPhase(Enum):
+    FORWARD = 0
+    BACKWARD = 1
+
+
+class ZeROHookV2(ParamOpHook):
+
+    def __init__(self, chunk_manager: ChunkManager) -> None:
+        super().__init__()
+        self._chunk_manager = chunk_manager
+        self._training_phase = TrainingPhase.FORWARD
+
+    def pre_op(self, params):
+        for p in params:
+            self._chunk_manager.trans_tensor_state(p, TensorState.COMPUTE)
+        self._chunk_manager.exec_lazy_release()
+        # TODO: evict chunks
+        for p in params:
+            self._chunk_manager.access_chunk(p)
+
+    def post_op(self, params):
+        for p in params:
+            tensor_state = TensorState.HOLD if self._training_phase == TrainingPhase.FORWARD or not p.requires_grad else TensorState.HOLD_AFTER_BWD
+            self._chunk_manager.trans_tensor_state(p, tensor_state)
+        self._chunk_manager.add_lazy_release_tensors(params)
+
+    def pre_forward(self, params: List[torch.Tensor]) -> None:
+        self.pre_op(params)
+
+    def post_forward(self, params: List[torch.Tensor]) -> None:
+        self.post_op(params)
+
+    def pre_backward(self, params: List[torch.Tensor]) -> None:
+        self.pre_op(params)
+
+    def post_backward(self, params: List[torch.Tensor]) -> None:
+        self.post_op(params)
+
+    @contextmanager
+    def switch_training_phase(self, training_phase: TrainingPhase = TrainingPhase.BACKWARD):
+        try:
+            old_training_phase = self._training_phase
+            self._training_phase = training_phase
+            yield
+        finally:
+            self._training_phase = old_training_phase
+
+    switch_to_backward = switch_training_phase
+    switch_to_forward = partialmethod(switch_training_phase, training_phase=TrainingPhase.FORWARD)
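
How the hook is driven across one step, condensed from `ColoDDPV2.forward` and `ColoDDPV2.backward` earlier in this patch; `model`, `inputs`, `loss`, and `chunk_manager` are placeholders, and `model` is assumed to hold ColoParameters already registered with `chunk_manager`.

```python
from colossalai.tensor import use_param_op_hooks
from colossalai.zero.utils.zero_hook_v2 import ZeROHookV2

hook = ZeROHookV2(chunk_manager)

with use_param_op_hooks(hook):       # forward: each param op goes FREE/HOLD -> COMPUTE -> HOLD
    outputs = model(*inputs)
chunk_manager.exec_lazy_release()    # release the chunks queued by post_op

with hook.switch_to_backward(), use_param_op_hooks(hook):
    loss.backward()                  # backward: params go -> COMPUTE -> HOLD_AFTER_BWD; the grad
chunk_manager.exec_lazy_release()    # hooks then mark READY_FOR_REDUCE and reduce chunk by chunk
```
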
diff --git a/tests/test_tensor/test_chunk.py b/tests/test_tensor/test_chunk.py
new file mode 100644
index 000000000..f367753de
--- /dev/null
+++ b/tests/test_tensor/test_chunk.py
@@ -0,0 +1,70 @@
+import torch
+import colossalai
+import pytest
+import torch.multiprocessing as mp
+from typing import List
+from functools import partial
+from colossalai.tensor import ChunkManager
+from colossalai.testing import rerun_if_address_is_in_use, parameterize
+from colossalai.utils import free_port
+from colossalai.core import global_context as gpc
+from colossalai.context import ParallelMode
+
+
+def check_has_params(params: List[torch.Tensor], has_tensors: List[bool]):
+    for p, has_tensor in zip(params, has_tensors):
+        if has_tensor:
+            assert p.storage().size() > 0
+            assert p.device.type == 'cuda'
+        else:
+            assert p.storage().size() == 0
+
+
+# HAS_TENSORS[use_chunk][use_zero]
+HAS_TENSORS = {
+    True: {
+        True: [[True, True, False], [False, False, True]],
+        False: [[True, True, True], [True, True, True]]
+    },
+    False: {
+        True: [[True, False, True], [False, True, False]],
+        False: [[True, True, True], [True, True, True]]
+    }
+}
+
+
+@parameterize('use_chunk', [False, True])
+@parameterize('use_zero', [False, True])
+def run_chunk_zero(use_chunk, use_zero):
+    rank = gpc.get_local_rank(ParallelMode.DATA)
+    if rank == 0:
+        print(f'use_chunk={use_chunk}, use_zero={use_zero}')
+    params = [torch.rand(32, 32) for _ in range(3)]
+    chunk_size = 2048 if use_chunk else None
+    chunk_manager = ChunkManager(chunk_size, enable_distributed_storage=use_zero)
+    for p in params:
+        chunk_manager.append_tensor(p, 'param')
+    check_has_params(params, HAS_TENSORS[use_chunk][use_zero][rank])
+    for p in params:
+        chunk_manager.access_chunk(p)
+    check_has_params(params, [True, True, True])
+    for p in params:
+        chunk_manager.release_chunk(p)
+    check_has_params(params, HAS_TENSORS[use_chunk][use_zero][rank])
+
+
+def run_dist(rank, world_size, port):
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    run_chunk_zero()
+
+
+@pytest.mark.dist
+@pytest.mark.parametrize('world_size', [2])
+@rerun_if_address_is_in_use()
+def test_chunk_mapping(world_size):
+    run_func = partial(run_dist, world_size=world_size, port=free_port())
+    mp.spawn(run_func, nprocs=world_size)
+
+
+if __name__ == '__main__':
+    test_chunk_mapping(2)
diff --git a/tests/test_tensor/test_zero.py b/tests/test_tensor/test_zero.py
new file mode 100644
index 000000000..a87c73359
--- /dev/null
+++ b/tests/test_tensor/test_zero.py
@@ -0,0 +1,80 @@
+import pytest
+import colossalai
+from colossalai.context.parallel_mode import ParallelMode
+import torch.multiprocessing as mp
+from colossalai.testing import rerun_if_address_is_in_use
+from colossalai.utils.cuda import get_current_device
+from colossalai.utils import free_port
+from colossalai.utils import ColoInitContext
+from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction, DistSpecManager, distspec, ColoParameter, ChunkManager
+from colossalai.core import global_context as gpc
+from functools import partial
+from _utils import tensor_equal, tensor_shard_equal, set_seed
+from tests.components_to_test.registry import non_distributed_component_funcs
+from torch.nn.parallel import DistributedDataParallel as DDP
+from colossalai.nn.parallel import ColoDDP, ColoDDPV2
+from colossalai.testing import parameterize
+
+
+def check_param_equal(model, torch_model):
+    for p, torch_p in zip(model.parameters(), torch_model.parameters()):
+        if p.storage().size() > 0:
+            assert tensor_equal(torch_p, p.float()), f'{torch_p} vs {p}'
+
+
+def check_grad_equal(model, torch_model):
+    for p, torch_p in zip(model.parameters(), torch_model.parameters()):
+        if p.grad is not None:
+            assert tensor_equal(torch_p.grad, p.grad.float())
+
+
+@parameterize('use_chunk', [False, True])
+@parameterize('use_zero', [False, True])
+def run_gpt(use_chunk, use_zero):
+    set_seed(42)
+    get_components_func = non_distributed_component_funcs.get_callable('gpt2')
+    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
+    with ColoInitContext(device=get_current_device()):
+        model = model_builder(checkpoint=True)
+    model = model.cuda()
+    torch_model = model_builder().cuda()
+    for torch_p, p in zip(torch_model.parameters(), model.parameters()):
+        torch_p.data.copy_(p)
+    model = model.half()
+    chunk_size = 38 * 1024**2 if use_chunk else None
+    chunk_manager = ChunkManager(chunk_size, enable_distributed_storage=use_zero)
+    model = ColoDDPV2(model, chunk_manager)
+    torch_model = DDP(torch_model, device_ids=[gpc.get_global_rank()], process_group=gpc.get_group(ParallelMode.DATA))
+    print(chunk_manager)
+    check_param_equal(model, torch_model)
+    model.train()
+    torch_model.train()
+    set_seed(gpc.get_local_rank(ParallelMode.DATA))
+    for i, (input_ids, attn_mask) in enumerate(train_dataloader):
+        logits = model(input_ids, attn_mask)
+        torch_logits = torch_model(input_ids, attn_mask)
+        assert tensor_equal(torch_logits, logits.float())
+        loss = criterion(logits, input_ids)
+        torch_loss = criterion(torch_logits, input_ids)
+        model.backward(loss)
+        torch_loss.backward()
+        check_grad_equal(model, torch_model)
+        break
+
+
+def run_dist(rank, world_size, port):
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    run_gpt()
+
+
+@pytest.mark.dist
+@pytest.mark.parametrize('world_size', [1, 4])
+@rerun_if_address_is_in_use()
+def test_gpt(world_size):
+    run_func = partial(run_dist, world_size=world_size, port=free_port())
+    mp.spawn(run_func, nprocs=world_size)
+
+
+if __name__ == '__main__':
+    test_gpt(4)