[hotfix] ZeroDDP use new process group (#1333)

* process group supports getting ranks in group

* chunk mgr receives a process group

* update unit test

* fix unit tests
ver217 2022-07-18 14:14:52 +08:00 committed by GitHub
parent 11d1436a67
commit 0c51ff2c13
9 changed files with 49 additions and 43 deletions
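
For context, the end-to-end wiring after this change looks roughly like the updated unit tests below. This is a minimal sketch, not part of the patch; it assumes a distributed environment has already been launched (e.g. via `colossalai.launch`), and the import paths for `GeminiManager` and `ZeroDDP` are assumptions based on the test files of this era rather than something shown in this diff.

```python
import torch
from colossalai.gemini import ChunkManager
from colossalai.gemini.gemini_mgr import GeminiManager   # import path assumed
from colossalai.nn.parallel import ZeroDDP               # import path assumed
from colossalai.tensor import ProcessGroup


def build_zero_ddp(module: torch.nn.Module, use_chunk: bool = False, use_zero: bool = False) -> ZeroDDP:
    # One ColoProcessGroup is created up front and threaded through the stack.
    pg = ProcessGroup()
    chunk_size = ChunkManager.search_chunk_size(module, 64, 4) if use_chunk else None
    # ChunkManager now takes the process group as its second positional argument.
    chunk_manager = ChunkManager(chunk_size, pg, enable_distributed_storage=use_zero)
    gemini_manager = GeminiManager('cuda', chunk_manager)
    # ZeroDDP no longer accepts a process_group argument; it reuses
    # gemini_manager.chunk_manager.process_group internally.
    return ZeroDDP(module, gemini_manager)
```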

View File

@@ -4,9 +4,8 @@ from dataclasses import dataclass
 from enum import Enum
 from typing import Optional, Dict, List
-from colossalai.core import global_context as gpc
-from colossalai.context import ParallelMode
 from colossalai.utils import get_current_device
+from colossalai.tensor import ProcessGroup as ColoProcessGroup

 class TensorState(Enum):
@@ -65,14 +64,16 @@ class Chunk:
     def __init__(self,
                  chunk_size: int,
                  src_rank: int,
+                 process_group: ColoProcessGroup,
                  dtype: torch.dtype,
                  init_device: Optional[torch.device] = None,
                  force_data_on_cuda: bool = False) -> None:
         self.size = chunk_size
         self.utilized_size = 0
         self.src_rank = src_rank
-        self.is_src_rank = gpc.get_local_rank(ParallelMode.DATA) == src_rank
-        self.global_src_rank = gpc.get_ranks_in_group(ParallelMode.DATA)[src_rank]
+        self.process_group = process_group
+        self.is_src_rank = process_group.dp_local_rank() == src_rank
+        self.global_src_rank = process_group.get_ranks_in_dp()[src_rank]
         self.dtype = dtype
         device = init_device or get_current_device()
         if force_data_on_cuda:
@@ -150,7 +151,7 @@ class Chunk:
         if not self.is_src_rank:
             alloc_storage(self._payload)
         self.move_device(get_current_device(), update_ptr=False)
-        dist.broadcast(self.data, self.global_src_rank, group=gpc.get_group(ParallelMode.DATA))
+        dist.broadcast(self.data, self.global_src_rank, group=self.process_group.dp_process_group())
         # update tensor meta info
         self._update_tensors_ptr()
@@ -193,9 +194,9 @@ class Chunk:
         """
         self.move_device(get_current_device(), update_ptr=False)
         if is_all_reduce:
-            dist.all_reduce(self.data, group=gpc.get_group(ParallelMode.DATA))
+            dist.all_reduce(self.data, group=self.process_group.dp_process_group())
         else:
-            dist.reduce(self.data, self.global_src_rank, group=gpc.get_group(ParallelMode.DATA))
+            dist.reduce(self.data, self.global_src_rank, group=self.process_group.dp_process_group())
         self._update_tensors_ptr()
         self._update_tensors_state(TensorState.HOLD)
@@ -216,7 +217,7 @@ class Chunk:
         # invalid calls will be ignored and nothing changes
         if (self.tensors_info[tensor].state, tensor_state) not in STATE_TRANS:
             # print(
-            #     f'WARNING: Rank{gpc.get_global_rank()} apply invalid state trans: {self.tensors_info[tensor].state} to {tensor_state}'
+            #     f'WARNING: Rank{self.process_group.rank()} apply invalid state trans: {self.tensors_info[tensor].state} to {tensor_state}'
             # )
             return
         self.tensors_info[tensor].state = tensor_state
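
A small illustration (not from the patch) of the rank bookkeeping `Chunk` now performs through the process group instead of `gpc`: `dp_local_rank()` gives the rank within the data-parallel group, `get_ranks_in_dp()` maps a DP-local rank back to a global rank, and `dp_process_group()` supplies the group object for the collectives. Assumes a launched run; the example `src_rank` value is hypothetical.

```python
from colossalai.tensor import ProcessGroup as ColoProcessGroup

pg = ColoProcessGroup()                  # default: all launched ranks form one DP group
src_rank = 1                             # DP-local rank that owns a given chunk (example value)
is_src_rank = pg.dp_local_rank() == src_rank
global_src_rank = pg.get_ranks_in_dp()[src_rank]   # translate the DP-local rank to a global rank
# Collectives then use the group held by the chunk, e.g.
#   dist.broadcast(chunk.data, global_src_rank, group=pg.dp_process_group())
```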

View File

@@ -2,9 +2,8 @@ import torch
 from typing import Optional, Dict, Deque, Set, List, Tuple, Iterable
 from collections import deque
-from colossalai.context import ParallelMode
-from colossalai.core import global_context as gpc
 from colossalai.utils import get_current_device
+from colossalai.tensor import ProcessGroup as ColoProcessGroup

 from .chunk import Chunk, ChunkFullError, TensorState
@@ -20,10 +19,13 @@ class ChunkManager:
     def __init__(self,
                  chunk_size: Optional[int],
+                 process_group: ColoProcessGroup,
                  enable_distributed_storage: bool = False,
                  init_device: Optional[torch.device] = None) -> None:
         assert chunk_size is None or chunk_size > 0
+        assert isinstance(process_group, ColoProcessGroup)
         self.chunk_size = chunk_size
+        self.process_group = process_group
         self.enable_distributed_storage = enable_distributed_storage
         self.device = init_device or get_current_device()
         self.chunk_groups: Dict[str, Deque[Chunk]] = {}
@@ -69,6 +71,7 @@ class ChunkManager:
         src_rank = self._get_next_src_rank(group_name)
         chunk = Chunk(chunk_size,
                       src_rank,
+                      self.process_group,
                       tensor.dtype,
                       self.device,
                       force_data_on_cuda=self.groups_force_data_on_cuda[group_name])
@@ -89,17 +92,17 @@ class ChunkManager:
     def _get_next_src_rank(self, group_name: str) -> int:
         if not self.enable_distributed_storage:
             # the chunk is owned by the current rank if no distributed storage is enabled
-            return gpc.get_local_rank(ParallelMode.DATA)
+            return self.process_group.dp_local_rank()
         if self.chunk_size is None:
             if group_name not in self.rank_load:
-                self.rank_load[group_name] = torch.zeros(gpc.get_world_size(ParallelMode.DATA), dtype=torch.int64)
+                self.rank_load[group_name] = torch.zeros(self.process_group.dp_world_size(), dtype=torch.int64)
             # the process owning the tensor will be the process with the smallest number of elements
             src_rank = torch.argmin(self.rank_load[group_name]).item()
         else:
             # chunk is owned by processes in a round-robin fashion
             chunk_idx = len(self.chunk_groups[group_name])
-            src_rank = chunk_idx % gpc.get_world_size(ParallelMode.DATA)
+            src_rank = chunk_idx % self.process_group.dp_world_size()
         return src_rank

     def access_chunk(self, chunk: Chunk) -> None:
@@ -222,7 +225,7 @@ class ChunkManager:
         self.lazy_release_tensors.clear()

     def __repr__(self) -> str:
-        msg = f'Rank {gpc.get_local_rank(ParallelMode.DATA)}:\n'
+        msg = f'Rank {self.process_group.dp_local_rank()}:\n'
         msg += 'Total memory: ' + ', '.join([f'{k}={v}B' for k, v in self.total_mem.items()]) + '\n'
         for group_name, group in self.chunk_groups.items():
             msg += f'Group {group_name}:\n'
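
Since `process_group` is now a required positional parameter with no default, old call sites such as `ChunkManager(chunk_size)` break, which is why the tests later in this commit were updated. A hedged sketch of the new constructor call; the chunk size here is just an example value:

```python
from colossalai.gemini import ChunkManager
from colossalai.tensor import ProcessGroup as ColoProcessGroup

pg = ColoProcessGroup()
chunk_manager = ChunkManager(1024 * 1024,                      # example fixed chunk size (elements)
                             pg,                               # required ColoProcessGroup
                             enable_distributed_storage=True)  # shard chunk ownership across DP ranks
# With a fixed chunk size, ownership is assigned round-robin over pg.dp_world_size();
# with chunk_size=None, each tensor goes to the currently least-loaded DP rank.
```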

View File

@@ -118,7 +118,7 @@ class ColoDDP(torch.nn.Module):
             return empty_grad
         else:
-            #TODO(jiaruifang) fixme
+            # TODO(jiaruifang) fixme
             self.process_group.set_cpu_groups()
             dist.all_reduce(grad, group=self.process_group.cpu_dp_process_group())
             return grad
@@ -191,11 +191,8 @@ class ZeroDDP(ColoDDP):
     For more details, see the API reference of ``GeminiManager``.
     """

-    def __init__(self,
-                 module: torch.nn.Module,
-                 gemini_manager: GeminiManager,
-                 process_group: Optional[ColoProcessGroup] = None) -> None:
-        super().__init__(module.half(), process_group=process_group)
+    def __init__(self, module: torch.nn.Module, gemini_manager: GeminiManager) -> None:
+        super().__init__(module.half(), process_group=gemini_manager.chunk_manager.process_group)
         self.gemini_manager = gemini_manager
         self.chunk_manager = gemini_manager.chunk_manager
         self.param_op_hook = ZeROHookV2(gemini_manager)
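
The practical effect on callers: `ZeroDDP` is now constructed from just the module and the `GeminiManager`, and the data-parallel group is whatever the chunk manager was built with. A minimal sketch (the import path is an assumption, not shown in this diff):

```python
from colossalai.nn.parallel import ZeroDDP   # import path assumed


def wrap_with_zero(module, gemini_manager) -> ZeroDDP:
    # Before this commit: ZeroDDP(module, gemini_manager, process_group=pg)
    # Now the group comes from gemini_manager.chunk_manager.process_group.
    return ZeroDDP(module, gemini_manager)
```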

View File

@@ -171,3 +171,9 @@ class ProcessGroup:
     def cpu_tp_process_group(self):
         assert self._has_cpu_groups
         return PYTORCHPGDICT_.get(self._tp_rank_list, 'gloo')
+
+    def get_ranks_in_dp(self):
+        return self._dp_rank_list
+
+    def get_ranks_in_tp(self):
+        return self._tp_rank_list
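
The two new accessors simply expose the rank lists the group was built from, which is what `Chunk` uses to turn a DP-local `src_rank` into a global rank. A sketch, assuming a 4-rank launch split into tensor- and data-parallel groups; the constructor arguments shown here are an assumption for illustration:

```python
from colossalai.tensor import ProcessGroup

pg = ProcessGroup(tp_degree=2, dp_degree=2)   # assumes torch.distributed is initialized with 4 ranks
dp_ranks = pg.get_ranks_in_dp()   # global ranks in this rank's data-parallel group
tp_ranks = pg.get_ranks_in_tp()   # global ranks in this rank's tensor-parallel group
# e.g. mapping a DP-local source rank to a global rank:
#   global_src_rank = dp_ranks[src_rank]
```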

View File

@@ -33,11 +33,11 @@ def init_ddp(module: torch.nn.Module) -> ColoDDP:
 def init_ddpv2(module: torch.nn.Module, use_chunk: bool = False) -> ZeroDDP:
-    chunk_size = ChunkManager.search_chunk_size(module, 64, 2) if use_chunk else None
-    chunk_manager = ChunkManager(chunk_size)
-    gemini_manager = GeminiManager('cuda', chunk_manager)
-    pg = ProcessGroup()
-    return ZeroDDP(module, gemini_manager, pg)
+    pg = ProcessGroup()
+    chunk_size = ChunkManager.search_chunk_size(module, 64, 2) if use_chunk else None
+    chunk_manager = ChunkManager(chunk_size, pg)
+    gemini_manager = GeminiManager('cuda', chunk_manager)
+    return ZeroDDP(module, gemini_manager)

 class Net(torch.nn.Module):

View File

@@ -28,11 +28,11 @@ def init_ddp(module: torch.nn.Module) -> ColoDDP:
 def init_ddpv2(module: torch.nn.Module, use_chunk: bool = False, use_zero: bool = False) -> ZeroDDP:
-    chunk_size = ChunkManager.search_chunk_size(module, 64, 4) if use_chunk else None
-    chunk_manager = ChunkManager(chunk_size, enable_distributed_storage=use_zero)
-    gemini_manager = GeminiManager('cuda', chunk_manager)
-    pg = ProcessGroup()
-    return ZeroDDP(module, gemini_manager, process_group=pg)
+    pg = ProcessGroup()
+    chunk_size = ChunkManager.search_chunk_size(module, 64, 4) if use_chunk else None
+    chunk_manager = ChunkManager(chunk_size, pg, enable_distributed_storage=use_zero)
+    gemini_manager = GeminiManager('cuda', chunk_manager)
+    return ZeroDDP(module, gemini_manager)

 def run_state_dict(ddp_init_func: Callable[[torch.nn.Module], ColoDDP]):

View File

@@ -7,8 +7,7 @@ from functools import partial
 from colossalai.gemini import ChunkManager
 from colossalai.testing import rerun_if_address_is_in_use, parameterize
 from colossalai.utils import free_port
-from colossalai.core import global_context as gpc
-from colossalai.context import ParallelMode
+from colossalai.tensor import ProcessGroup as ColoProcessGroup

 def check_has_params(params: List[torch.Tensor], has_tensors: List[bool]):
@@ -38,12 +37,13 @@ TOTAL_MEM = {True: {True: [512, 512], False: [1024, 1024]}, False: {True: [512,
 @parameterize('use_chunk', [False, True])
 @parameterize('use_zero', [False, True])
 def run_chunk_zero(use_chunk, use_zero):
-    rank = gpc.get_local_rank(ParallelMode.DATA)
+    pg = ColoProcessGroup()
+    rank = pg.rank()
     if rank == 0:
         print(f'use_chunk={use_chunk}, use_zero={use_zero}')
     params = [torch.rand(8, 8) for _ in range(3)]
     chunk_size = 128 if use_chunk else None
-    chunk_manager = ChunkManager(chunk_size, enable_distributed_storage=use_zero)
+    chunk_manager = ChunkManager(chunk_size, pg, enable_distributed_storage=use_zero)
     chunk_manager.create_group('param')
     assert chunk_manager.total_mem['cpu'] == 0
     assert chunk_manager.total_mem['cuda'] == 0

View File

@@ -31,8 +31,6 @@ def check_param_equal(model, torch_model, pg: ProcessGroup):
 def check_grad_equal(model, torch_model, pg: ProcessGroup):
     for (n, p), (tn, tp) in zip(model.named_parameters(), torch_model.named_parameters()):
         if p.grad is not None:
-            torch.distributed.barrier()
-            print(torch.distributed.get_rank(), p.grad)
             assert tensor_shard_equal(tp.grad.to(dtype=p.grad.dtype, device=p.grad.device), p.grad,
                                       pg.tp_local_rank(), pg.tp_world_size()), \
                 f'{tp.grad} vs {p.grad}\n{n}:\n\t{tp.grad.shape} vs {p.grad.shape} in {pg.rank()}'
@@ -63,9 +61,9 @@ def init_1d_col_spec(model, pg: ProcessGroup):
         p.set_tensor_spec(*spec)

-@parameterize('use_chunk', [False])
-@parameterize('use_zero', [False])
-@parameterize('placement_policy', ['cuda'])
+@parameterize('use_chunk', [False, True])
+@parameterize('use_zero', [False, True])
+@parameterize('placement_policy', ['cuda', 'cpu'])
 def run_gpt(use_chunk, use_zero, placement_policy, tp_init_spec_func=None):
     set_seed(42)
     get_components_func = non_distributed_component_funcs.get_callable('gpt2')
@@ -92,10 +90,11 @@ def run_gpt(use_chunk, use_zero, placement_policy, tp_init_spec_func=None):
     chunk_size = ChunkManager.search_chunk_size(model, 8192, 8) if use_chunk else None
     chunk_manager = ChunkManager(chunk_size,
+                                 pg,
                                  enable_distributed_storage=use_zero,
                                  init_device=GeminiManager.get_default_device(placement_policy))
     gemini_manager = GeminiManager(placement_policy, chunk_manager)
-    model = ZeroDDP(model, gemini_manager, pg)
+    model = ZeroDDP(model, gemini_manager)
     optim = HybridAdam(model.parameters(), lr=1e-3)
     optim = ZeroOptimizer(optim, model, initial_scale=1)
@@ -104,7 +103,7 @@ def run_gpt(use_chunk, use_zero, placement_policy, tp_init_spec_func=None):
     torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config)
     torch_model = DDP(torch_model, device_ids=[pg.rank()], process_group=pg.dp_process_group())
-    # print(chunk_manager)
+    print(chunk_manager)
     check_param_equal(model, torch_model, pg)
     model.eval()
@@ -129,13 +128,12 @@ def run_dist(rank, world_size, port):
     colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     if world_size == 4:
         run_gpt(tp_init_spec_func=init_1d_col_spec)
-        # run_gpt(tp_init_spec_func=init_1d_row_spec)
+        run_gpt(tp_init_spec_func=init_1d_row_spec)
     else:
         run_gpt(tp_init_spec_func=init_1d_col_spec)

 @pytest.mark.dist
-@pytest.mark.skip("buggy test")
 @pytest.mark.parametrize('world_size', [1, 4])
 @rerun_if_address_is_in_use()
 def test_gpt(world_size):

View File

@@ -20,13 +20,14 @@ from colossalai.tensor import ProcessGroup

 def init_zero(model, use_chunk, use_zero, placement_policy):
+    pg = ProcessGroup()
     chunk_size = ChunkManager.search_chunk_size(model, 8192, 8) if use_chunk else None
     chunk_manager = ChunkManager(chunk_size,
+                                 pg,
                                  enable_distributed_storage=use_zero,
                                  init_device=GeminiManager.get_default_device(placement_policy))
     gemini_manager = GeminiManager(placement_policy, chunk_manager)
-    pg = ProcessGroup()
-    return ZeroDDP(model, gemini_manager, pg)
+    return ZeroDDP(model, gemini_manager)

 def run_step(model, optim, criterion, data, label):