diff --git a/colossalai/gemini/chunk.py b/colossalai/gemini/chunk.py
index a5a7ae027..c39a06502 100644
--- a/colossalai/gemini/chunk.py
+++ b/colossalai/gemini/chunk.py
@@ -4,9 +4,8 @@ from dataclasses import dataclass
 from enum import Enum
 from typing import Optional, Dict, List
 
-from colossalai.core import global_context as gpc
-from colossalai.context import ParallelMode
 from colossalai.utils import get_current_device
+from colossalai.tensor import ProcessGroup as ColoProcessGroup
 
 
 class TensorState(Enum):
@@ -65,14 +64,16 @@ class Chunk:
     def __init__(self,
                  chunk_size: int,
                  src_rank: int,
+                 process_group: ColoProcessGroup,
                  dtype: torch.dtype,
                  init_device: Optional[torch.device] = None,
                  force_data_on_cuda: bool = False) -> None:
         self.size = chunk_size
         self.utilized_size = 0
         self.src_rank = src_rank
-        self.is_src_rank = gpc.get_local_rank(ParallelMode.DATA) == src_rank
-        self.global_src_rank = gpc.get_ranks_in_group(ParallelMode.DATA)[src_rank]
+        self.process_group = process_group
+        self.is_src_rank = process_group.dp_local_rank() == src_rank
+        self.global_src_rank = process_group.get_ranks_in_dp()[src_rank]
         self.dtype = dtype
         device = init_device or get_current_device()
         if force_data_on_cuda:
@@ -150,7 +151,7 @@ class Chunk:
         if not self.is_src_rank:
             alloc_storage(self._payload)
         self.move_device(get_current_device(), update_ptr=False)
-        dist.broadcast(self.data, self.global_src_rank, group=gpc.get_group(ParallelMode.DATA))
+        dist.broadcast(self.data, self.global_src_rank, group=self.process_group.dp_process_group())
 
         # update tensor meta info
         self._update_tensors_ptr()
@@ -193,9 +194,9 @@ class Chunk:
         """
         self.move_device(get_current_device(), update_ptr=False)
         if is_all_reduce:
-            dist.all_reduce(self.data, group=gpc.get_group(ParallelMode.DATA))
+            dist.all_reduce(self.data, group=self.process_group.dp_process_group())
         else:
-            dist.reduce(self.data, self.global_src_rank, group=gpc.get_group(ParallelMode.DATA))
+            dist.reduce(self.data, self.global_src_rank, group=self.process_group.dp_process_group())
         self._update_tensors_ptr()
         self._update_tensors_state(TensorState.HOLD)
 
@@ -216,7 +217,7 @@ class Chunk:
         # invalid calls will be ignored and nothing changes
         if (self.tensors_info[tensor].state, tensor_state) not in STATE_TRANS:
             # print(
-            #     f'WARNING: Rank{gpc.get_global_rank()} apply invalid state trans: {self.tensors_info[tensor].state} to {tensor_state}'
+            #     f'WARNING: Rank{self.process_group.rank()} apply invalid state trans: {self.tensors_info[tensor].state} to {tensor_state}'
            # )
            return
        self.tensors_info[tensor].state = tensor_state
diff --git a/colossalai/gemini/chunk_mgr.py b/colossalai/gemini/chunk_mgr.py
index 6d6c8b47f..42392d76d 100644
--- a/colossalai/gemini/chunk_mgr.py
+++ b/colossalai/gemini/chunk_mgr.py
@@ -2,9 +2,8 @@ import torch
 from typing import Optional, Dict, Deque, Set, List, Tuple, Iterable
 from collections import deque
 
-from colossalai.context import ParallelMode
-from colossalai.core import global_context as gpc
 from colossalai.utils import get_current_device
+from colossalai.tensor import ProcessGroup as ColoProcessGroup
 
 from .chunk import Chunk, ChunkFullError, TensorState
 
@@ -20,10 +19,13 @@ class ChunkManager:
 
     def __init__(self,
                  chunk_size: Optional[int],
+                 process_group: ColoProcessGroup,
                  enable_distributed_storage: bool = False,
                  init_device: Optional[torch.device] = None) -> None:
         assert chunk_size is None or chunk_size > 0
+        assert isinstance(process_group, ColoProcessGroup)
         self.chunk_size = chunk_size
+        self.process_group = process_group
         self.enable_distributed_storage = enable_distributed_storage
         self.device = init_device or get_current_device()
         self.chunk_groups: Dict[str, Deque[Chunk]] = {}
@@ -69,6 +71,7 @@ class ChunkManager:
             src_rank = self._get_next_src_rank(group_name)
             chunk = Chunk(chunk_size,
                           src_rank,
+                          self.process_group,
                           tensor.dtype,
                           self.device,
                           force_data_on_cuda=self.groups_force_data_on_cuda[group_name])
@@ -89,17 +92,17 @@
 
     def _get_next_src_rank(self, group_name: str) -> int:
         if not self.enable_distributed_storage:
             # the chunk is owned by the current rank if no distributed storage is enabled
-            return gpc.get_local_rank(ParallelMode.DATA)
+            return self.process_group.dp_local_rank()
         if self.chunk_size is None:
             if group_name not in self.rank_load:
-                self.rank_load[group_name] = torch.zeros(gpc.get_world_size(ParallelMode.DATA), dtype=torch.int64)
+                self.rank_load[group_name] = torch.zeros(self.process_group.dp_world_size(), dtype=torch.int64)
             # the process owning the tensor will be the process with the smallest number of elements
             src_rank = torch.argmin(self.rank_load[group_name]).item()
         else:
             # chunk is owned by processes in a round-robin fashion
             chunk_idx = len(self.chunk_groups[group_name])
-            src_rank = chunk_idx % gpc.get_world_size(ParallelMode.DATA)
+            src_rank = chunk_idx % self.process_group.dp_world_size()
         return src_rank
 
     def access_chunk(self, chunk: Chunk) -> None:
@@ -222,7 +225,7 @@ class ChunkManager:
         self.lazy_release_tensors.clear()
 
     def __repr__(self) -> str:
-        msg = f'Rank {gpc.get_local_rank(ParallelMode.DATA)}:\n'
+        msg = f'Rank {self.process_group.dp_local_rank()}:\n'
         msg += 'Total memory: ' + ', '.join([f'{k}={v}B' for k, v in self.total_mem.items()]) + '\n'
         for group_name, group in self.chunk_groups.items():
             msg += f'Group {group_name}:\n'
diff --git a/colossalai/nn/parallel/data_parallel.py b/colossalai/nn/parallel/data_parallel.py
index 9de510e3e..72d6d053f 100644
--- a/colossalai/nn/parallel/data_parallel.py
+++ b/colossalai/nn/parallel/data_parallel.py
@@ -118,7 +118,7 @@ class ColoDDP(torch.nn.Module):
                 return empty_grad
             else:
-                #TODO(jiaruifang) fixme
+                # TODO(jiaruifang) fixme
                 self.process_group.set_cpu_groups()
                 dist.all_reduce(grad, group=self.process_group.cpu_dp_process_group())
                 return grad
 
@@ -191,11 +191,8 @@ class ZeroDDP(ColoDDP):
         For more details, see the API reference of ``GeminiManager``.
     """
 
-    def __init__(self,
-                 module: torch.nn.Module,
-                 gemini_manager: GeminiManager,
-                 process_group: Optional[ColoProcessGroup] = None) -> None:
-        super().__init__(module.half(), process_group=process_group)
+    def __init__(self, module: torch.nn.Module, gemini_manager: GeminiManager) -> None:
+        super().__init__(module.half(), process_group=gemini_manager.chunk_manager.process_group)
         self.gemini_manager = gemini_manager
         self.chunk_manager = gemini_manager.chunk_manager
         self.param_op_hook = ZeROHookV2(gemini_manager)
diff --git a/colossalai/tensor/process_group.py b/colossalai/tensor/process_group.py
index 640ff050e..12fba646d 100644
--- a/colossalai/tensor/process_group.py
+++ b/colossalai/tensor/process_group.py
@@ -171,3 +171,9 @@ class ProcessGroup:
     def cpu_tp_process_group(self):
         assert self._has_cpu_groups
         return PYTORCHPGDICT_.get(self._tp_rank_list, 'gloo')
+
+    def get_ranks_in_dp(self):
+        return self._dp_rank_list
+
+    def get_ranks_in_tp(self):
+        return self._tp_rank_list
diff --git a/tests/test_ddp/test_ddp_ignore_params.py b/tests/test_ddp/test_ddp_ignore_params.py
index 52e755392..8789c18a6 100644
--- a/tests/test_ddp/test_ddp_ignore_params.py
+++ b/tests/test_ddp/test_ddp_ignore_params.py
@@ -33,11 +33,11 @@ def init_ddp(module: torch.nn.Module) -> ColoDDP:
 
 
 def init_ddpv2(module: torch.nn.Module, use_chunk: bool = False) -> ZeroDDP:
-    chunk_size = ChunkManager.search_chunk_size(module, 64, 2) if use_chunk else None
-    chunk_manager = ChunkManager(chunk_size)
-    gemini_manager = GeminiManager('cuda', chunk_manager)
     pg = ProcessGroup()
-    return ZeroDDP(module, gemini_manager, pg)
+    chunk_size = ChunkManager.search_chunk_size(module, 64, 2) if use_chunk else None
+    chunk_manager = ChunkManager(chunk_size, pg)
+    gemini_manager = GeminiManager('cuda', chunk_manager)
+    return ZeroDDP(module, gemini_manager)
 
 
 class Net(torch.nn.Module):
diff --git a/tests/test_ddp/test_ddp_state_dict.py b/tests/test_ddp/test_ddp_state_dict.py
index fc64f7796..121a8f44e 100644
--- a/tests/test_ddp/test_ddp_state_dict.py
+++ b/tests/test_ddp/test_ddp_state_dict.py
@@ -28,11 +28,11 @@ def init_ddp(module: torch.nn.Module) -> ColoDDP:
 
 
 def init_ddpv2(module: torch.nn.Module, use_chunk: bool = False, use_zero: bool = False) -> ZeroDDP:
-    chunk_size = ChunkManager.search_chunk_size(module, 64, 4) if use_chunk else None
-    chunk_manager = ChunkManager(chunk_size, enable_distributed_storage=use_zero)
-    gemini_manager = GeminiManager('cuda', chunk_manager)
     pg = ProcessGroup()
-    return ZeroDDP(module, gemini_manager, process_group=pg)
+    chunk_size = ChunkManager.search_chunk_size(module, 64, 4) if use_chunk else None
+    chunk_manager = ChunkManager(chunk_size, pg, enable_distributed_storage=use_zero)
+    gemini_manager = GeminiManager('cuda', chunk_manager)
+    return ZeroDDP(module, gemini_manager)
 
 
 def run_state_dict(ddp_init_func: Callable[[torch.nn.Module], ColoDDP]):
diff --git a/tests/test_tensor/test_chunk.py b/tests/test_tensor/test_chunk.py
index 0f5d75c82..1f1b6e44b 100644
--- a/tests/test_tensor/test_chunk.py
+++ b/tests/test_tensor/test_chunk.py
@@ -7,8 +7,7 @@ from functools import partial
 from colossalai.gemini import ChunkManager
 from colossalai.testing import rerun_if_address_is_in_use, parameterize
 from colossalai.utils import free_port
-from colossalai.core import global_context as gpc
-from colossalai.context import ParallelMode
+from colossalai.tensor import ProcessGroup as ColoProcessGroup
 
 
 def check_has_params(params: List[torch.Tensor], has_tensors: List[bool]):
@@ -38,12 +37,13 @@ TOTAL_MEM = {True: {True: [512, 512], False: [1024, 1024]}, False: {True: [512,
 @parameterize('use_chunk', [False, True])
 @parameterize('use_zero', [False, True])
 def run_chunk_zero(use_chunk, use_zero):
-    rank = gpc.get_local_rank(ParallelMode.DATA)
+    pg = ColoProcessGroup()
+    rank = pg.rank()
     if rank == 0:
         print(f'use_chunk={use_chunk}, use_zero={use_zero}')
     params = [torch.rand(8, 8) for _ in range(3)]
     chunk_size = 128 if use_chunk else None
-    chunk_manager = ChunkManager(chunk_size, enable_distributed_storage=use_zero)
+    chunk_manager = ChunkManager(chunk_size, pg, enable_distributed_storage=use_zero)
     chunk_manager.create_group('param')
     assert chunk_manager.total_mem['cpu'] == 0
     assert chunk_manager.total_mem['cuda'] == 0
diff --git a/tests/test_tensor/test_zero_optim.py b/tests/test_tensor/test_zero_optim.py
index 32f97d19a..ca5bf94e1 100644
--- a/tests/test_tensor/test_zero_optim.py
+++ b/tests/test_tensor/test_zero_optim.py
@@ -31,8 +31,6 @@ def check_param_equal(model, torch_model, pg: ProcessGroup):
 def check_grad_equal(model, torch_model, pg: ProcessGroup):
     for (n, p), (tn, tp) in zip(model.named_parameters(), torch_model.named_parameters()):
         if p.grad is not None:
-            torch.distributed.barrier()
-            print(torch.distributed.get_rank(), p.grad)
             assert tensor_shard_equal(tp.grad.to(dtype=p.grad.dtype, device=p.grad.device), p.grad,
                                       pg.tp_local_rank(), pg.tp_world_size()), \
                 f'{tp.grad} vs {p.grad}\n{n}:\n\t{tp.grad.shape} vs {p.grad.shape} in {pg.rank()}'
@@ -63,9 +61,9 @@ def init_1d_col_spec(model, pg: ProcessGroup):
             p.set_tensor_spec(*spec)
 
 
-@parameterize('use_chunk', [False])
-@parameterize('use_zero', [False])
-@parameterize('placement_policy', ['cuda'])
+@parameterize('use_chunk', [False, True])
+@parameterize('use_zero', [False, True])
+@parameterize('placement_policy', ['cuda', 'cpu'])
 def run_gpt(use_chunk, use_zero, placement_policy, tp_init_spec_func=None):
     set_seed(42)
     get_components_func = non_distributed_component_funcs.get_callable('gpt2')
@@ -92,10 +90,11 @@ def run_gpt(use_chunk, use_zero, placement_policy, tp_init_spec_func=None):
 
     chunk_size = ChunkManager.search_chunk_size(model, 8192, 8) if use_chunk else None
     chunk_manager = ChunkManager(chunk_size,
+                                 pg,
                                  enable_distributed_storage=use_zero,
                                  init_device=GeminiManager.get_default_device(placement_policy))
     gemini_manager = GeminiManager(placement_policy, chunk_manager)
-    model = ZeroDDP(model, gemini_manager, pg)
+    model = ZeroDDP(model, gemini_manager)
     optim = HybridAdam(model.parameters(), lr=1e-3)
     optim = ZeroOptimizer(optim, model, initial_scale=1)
 
@@ -104,7 +103,7 @@ def run_gpt(use_chunk, use_zero, placement_policy, tp_init_spec_func=None):
     torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config)
     torch_model = DDP(torch_model, device_ids=[pg.rank()], process_group=pg.dp_process_group())
 
-    # print(chunk_manager)
+    print(chunk_manager)
     check_param_equal(model, torch_model, pg)
     model.eval()
 
@@ -129,13 +128,12 @@ def run_dist(rank, world_size, port):
     colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     if world_size == 4:
         run_gpt(tp_init_spec_func=init_1d_col_spec)
-        # run_gpt(tp_init_spec_func=init_1d_row_spec)
+        run_gpt(tp_init_spec_func=init_1d_row_spec)
     else:
         run_gpt(tp_init_spec_func=init_1d_col_spec)
 
 
 @pytest.mark.dist
-@pytest.mark.skip("buggy test")
 @pytest.mark.parametrize('world_size', [1, 4])
 @rerun_if_address_is_in_use()
 def test_gpt(world_size):
diff --git a/tests/test_zero/test_zero_optim_state_dict.py b/tests/test_zero/test_zero_optim_state_dict.py
index 28215f816..7ecb91795 100644
--- a/tests/test_zero/test_zero_optim_state_dict.py
+++ b/tests/test_zero/test_zero_optim_state_dict.py
@@ -20,13 +20,14 @@ from colossalai.tensor import ProcessGroup
 
 
 def init_zero(model, use_chunk, use_zero, placement_policy):
+    pg = ProcessGroup()
     chunk_size = ChunkManager.search_chunk_size(model, 8192, 8) if use_chunk else None
     chunk_manager = ChunkManager(chunk_size,
+                                 pg,
                                  enable_distributed_storage=use_zero,
                                  init_device=GeminiManager.get_default_device(placement_policy))
     gemini_manager = GeminiManager(placement_policy, chunk_manager)
-    pg = ProcessGroup()
-    return ZeroDDP(model, gemini_manager, pg)
+    return ZeroDDP(model, gemini_manager)
 
 
 def run_step(model, optim, criterion, data, label):
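
With this refactor, the data-parallel `ProcessGroup` is created once, handed to `ChunkManager`, propagated to every `Chunk`, and read back by `ZeroDDP` from `gemini_manager.chunk_manager.process_group`, replacing the reads of the global `gpc` context. Below is a minimal sketch of the new construction order, mirroring the updated tests; the `build_zero_ddp` helper and its default arguments are illustrative, not part of this patch:

```python
import torch
from colossalai.gemini import ChunkManager, GeminiManager
from colossalai.nn.parallel import ZeroDDP
from colossalai.tensor import ProcessGroup


def build_zero_ddp(model: torch.nn.Module,
                   use_chunk: bool = True,
                   use_zero: bool = True,
                   placement_policy: str = 'cuda') -> ZeroDDP:
    # The process group must exist before the ChunkManager,
    # since every Chunk now receives it in its constructor.
    pg = ProcessGroup()
    chunk_size = ChunkManager.search_chunk_size(model, 8192, 8) if use_chunk else None
    chunk_manager = ChunkManager(chunk_size,
                                 pg,
                                 enable_distributed_storage=use_zero,
                                 init_device=GeminiManager.get_default_device(placement_policy))
    gemini_manager = GeminiManager(placement_policy, chunk_manager)
    # ZeroDDP no longer accepts a process_group argument;
    # it reuses gemini_manager.chunk_manager.process_group.
    return ZeroDDP(model, gemini_manager)
```

Passing the group explicitly keeps `Chunk` and `ChunkManager` usable without `colossalai.core.global_context`, and the previously skipped `test_gpt` is re-enabled accordingly.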