Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-13 13:11:05 +00:00)

[zero] polish low level optimizer (#2473)
@@ -103,7 +103,11 @@ def split_half_float_double(tensor_list):
     return buckets


-def reduce_tensor_dp_group(tensor, dtype=None, dst_rank=None, pg: Optional[ProcessGroup] = None):
+def reduce_tensor_dp_group(tensor: torch.Tensor,
+                           dtype: Optional[torch.dtype] = None,
+                           dst_local_rank: Optional[int] = None,
+                           dst_global_rank: Optional[int] = None,
+                           group: Optional[dist.ProcessGroup] = None):
     """
     Reduce the tensor in the data parallel process group
@@ -128,36 +132,22 @@ def reduce_tensor_dp_group(tensor, dtype=None, dst_rank=None, pg: Optional[Proce
     else:
         tensor_to_reduce = tensor

-    if isinstance(pg, ProcessGroup):
-        group = pg.dp_process_group()
-        world_size = pg.dp_world_size()
-    else:
-        world_size = gpc.get_world_size(ParallelMode.DATA)
-        group = gpc.get_group(ParallelMode.DATA)
-
+    world_size = dist.get_world_size(group=group)
     tensor_to_reduce.div_(world_size)

     # if rank is None, all reduce will be used
     # else, reduce is used
-    use_all_reduce = dst_rank is None
+    use_all_reduce = dst_local_rank is None

     if use_all_reduce:
         dist.all_reduce(tensor_to_reduce, group=group)
     else:
-        if pg is not None:
-            ranks_in_group = pg.dp_rank_list()
-        else:
-            ranks_in_group = gpc.get_ranks_in_group(ParallelMode.DATA)
-        global_rank = ranks_in_group[dst_rank]
-        dist.reduce(tensor=tensor_to_reduce, dst=global_rank, group=group)
+        dist.reduce(tensor=tensor_to_reduce, dst=dst_global_rank, group=group)

     # recover the original dtype
     if tensor.dtype != dtype and tensor is not tensor_to_reduce:
-        if pg is not None:
-            local_rank = pg.dp_local_rank()
-        else:
-            local_rank = gpc.get_local_rank(ParallelMode.DATA)
-        if use_all_reduce or dst_rank == local_rank:
+        local_rank = dist.get_rank(group=group)
+        if use_all_reduce or dst_local_rank == local_rank:
            tensor.copy_(tensor_to_reduce)

     return tensor
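Usage sketch (not part of the diff) for the new reduce_tensor_dp_group signature: the caller now passes a plain torch ProcessGroup and resolves the destination's global rank itself instead of handing over a ColossalAI ProcessGroup. It assumes torch.distributed is already initialized, that the default world group stands in for the data-parallel group, and that the helper's import path is as guessed below.

    import torch
    import torch.distributed as dist

    from colossalai.zero.sharded_optim._utils import reduce_tensor_dp_group    # import path assumed

    # the default world group stands in for the data-parallel group here
    dp_group = dist.group.WORLD
    dp_global_ranks = list(range(dist.get_world_size(group=dp_group)))

    flat_grads = torch.randn(1024, device='cuda')

    # the caller resolves the destination's global rank itself;
    # the helper no longer derives it from a ColossalAI ProcessGroup
    dst_local_rank = 0
    dst_global_rank = dp_global_ranks[dst_local_rank]

    reduced = reduce_tensor_dp_group(tensor=flat_grads,
                                     dtype=torch.float16,
                                     dst_local_rank=dst_local_rank,
                                     dst_global_rank=dst_global_rank,
                                     group=dp_group)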
@@ -1,19 +1,12 @@
-from typing import Optional
-
-from colossalai.context import ParallelMode
-from colossalai.core import global_context as gpc
-from colossalai.tensor import ProcessGroup
+import torch.distributed as dist
+from torch.distributed import ProcessGroup


 class BaseStore:

-    def __init__(self, pg: Optional[ProcessGroup] = None):
-        if isinstance(pg, ProcessGroup):
-            self._world_size = pg.dp_world_size()
-            self._local_rank = pg.dp_local_rank()
-        else:
-            self._world_size = gpc.get_world_size(ParallelMode.DATA)
-            self._local_rank = gpc.get_local_rank(ParallelMode.DATA)
+    def __init__(self, torch_pg: ProcessGroup):
+        self._world_size = dist.get_world_size(group=torch_pg)
+        self._local_rank = dist.get_rank(group=torch_pg)

     @property
     def world_size(self):
@@ -1,14 +1,12 @@
-from typing import Optional
-
-from colossalai.tensor import ProcessGroup
+from torch.distributed import ProcessGroup

 from .base_store import BaseStore


 class BucketStore(BaseStore):

-    def __init__(self, pg: Optional[ProcessGroup] = None):
-        super().__init__(pg)
+    def __init__(self, torch_pg: ProcessGroup):
+        super().__init__(torch_pg)
         self._grads = dict()
         self._params = dict()
         self._num_elements_in_bucket = dict()
@@ -1,16 +1,15 @@
-from typing import List, Optional
+from typing import List

 from torch import Tensor
-
-from colossalai.tensor import ProcessGroup
+from torch.distributed import ProcessGroup

 from .base_store import BaseStore


 class ParameterStore(BaseStore):

-    def __init__(self, pg: Optional[ProcessGroup] = None):
-        super().__init__(pg)
+    def __init__(self, torch_pg: ProcessGroup):
+        super().__init__(torch_pg)
         # param partitioning data structures
         self._fp16_param_to_rank = dict()
         self._rank_groupid_to_fp16_param_list = dict()
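Construction sketch (not part of the diff) for the reworked bookkeeping stores: BaseStore, BucketStore and ParameterStore now take a plain torch.distributed ProcessGroup instead of a ColossalAI ProcessGroup or a parallel mode. The import path of the bookkeeping package and the use of the default world group are assumptions.

    import torch.distributed as dist

    from colossalai.zero.sharded_optim.bookkeeping import BucketStore, ParameterStore    # import path assumed

    # any torch ProcessGroup works; the default world group stands in for the dp group
    dp_torch_group = dist.group.WORLD

    param_store = ParameterStore(dp_torch_group)
    bucket_store = BucketStore(dp_torch_group)

    # BaseStore now reads rank and world size straight from the torch group
    assert param_store.world_size == dist.get_world_size(group=dp_torch_group)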
@@ -10,7 +10,7 @@ from colossalai.context import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.logging import get_dist_logger
 from colossalai.nn.optimizer import ColossalaiOptimizer
-from colossalai.tensor import ProcessGroup
+from colossalai.tensor import ColoParameter, ProcessGroup
 from colossalai.utils.cuda import get_current_device

 from ._utils import (
@@ -34,32 +34,21 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
     def __init__(
             self,
             optimizer: Optimizer,
-            pg: Optional[ProcessGroup] = None,
-            # grad scaler config
-            initial_scale=2**16,
-            min_scale=1,
-            growth_factor=2,
-            backoff_factor=0.5,
-            growth_interval=2000,
-            hysteresis=2,
+            initial_scale: int = 2**16,    # grad scaler config
+            min_scale: int = 1,
+            growth_factor: float = 2.,
+            backoff_factor: float = .5,
+            growth_interval: int = 2000,
+            hysteresis: int = 2,
             max_scale: int = 2**24,
-
-            # grad clipping
-            clip_grad_norm=0.0,
-            verbose=False,
-
-            # communication
-            reduce_bucket_size=1024 * 1024,
-            communication_dtype=None,
-            overlap_communication=False,
-
-            # stage 2
-            partition_grad=False,
-            # cpu offload
-            cpu_offload=False,
-
-            # forced dtype
-            forced_dtype=None):
+            clip_grad_norm: float = 0.0,    # grad clipping
+            verbose: bool = False,
+            reduce_bucket_size: int = 1024 * 1024,    # communication
+            communication_dtype: Optional[torch.dtype] = None,
+            overlap_communication: bool = False,
+            partition_grad: bool = False,    # stage 2
+            cpu_offload: bool = False,    # cpu offload
+            forced_dtype: Optional[torch.dtype] = None):

         # TODO: add support for
         # 1. fp16 master weights
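Construction sketch (not part of the diff) with the newly annotated arguments: the `pg` argument is gone, so the optimizer now discovers the process group from ColoParameter parameters or falls back to the initialized gpc groups. The argument values, the tiny model, and the import path are illustrative assumptions; a launched and initialized ColossalAI/torch.distributed environment is assumed as well.

    import torch
    from torch.optim import Adam

    from colossalai.zero.sharded_optim import LowLevelZeroOptimizer    # import path assumed

    # assumes colossalai has been launched so gpc (or ColoParameter groups) is available
    model = torch.nn.Linear(32, 32).half().cuda()

    optim = LowLevelZeroOptimizer(Adam(model.parameters(), lr=1e-3),
                                  initial_scale=2**16,
                                  clip_grad_norm=1.0,
                                  reduce_bucket_size=1024 * 1024,
                                  overlap_communication=True,
                                  partition_grad=False)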
@@ -76,16 +65,16 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):

         self._cpu_offload = cpu_offload

-        self._pg = pg
-        if isinstance(pg, ProcessGroup):
-            self._local_rank = pg.dp_local_rank()
-            self._world_size = pg.dp_world_size()
-            self._dp_group = pg.dp_process_group()
-            if pg.tp_world_size() > 1:
-                self._mp_group = pg.tp_process_group()
-            else:
-                self._mp_group = None
-        elif pg is None:
+        colo_pg = self._search_colo_process_group()
+        if isinstance(colo_pg, ProcessGroup):
+            self._local_rank = colo_pg.dp_local_rank()
+            self._world_size = colo_pg.dp_world_size()
+            self._dp_global_ranks = colo_pg.get_ranks_in_dp()
+            self._dp_torch_group = colo_pg.dp_process_group()
+            self._mp_torch_group = None
+            if colo_pg.tp_world_size() > 1:
+                self._mp_torch_group = colo_pg.tp_process_group()
+        elif colo_pg is None:
             dp_parallel_mode = ParallelMode.DATA
             mp_parallel_mode = ParallelMode.MODEL

@@ -93,14 +82,13 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
             self._mp_parallel_mode = mp_parallel_mode
             self._local_rank = gpc.get_local_rank(dp_parallel_mode)
             self._world_size = gpc.get_world_size(dp_parallel_mode)
-
-            self._dp_group = gpc.get_group(dp_parallel_mode)
+            self._dp_global_ranks = gpc.get_ranks_in_group(dp_parallel_mode)
+            self._dp_torch_group = gpc.get_group(dp_parallel_mode)
+            self._mp_torch_group = None
             if gpc.is_initialized(mp_parallel_mode) and gpc.get_world_size(mp_parallel_mode) > 1:
-                self._mp_group = gpc.get_group(mp_parallel_mode)
-            else:
-                self._mp_group = None
+                self._mp_torch_group = gpc.get_group(mp_parallel_mode)
         else:
-            raise TypeError(f"pg should be None or a ProcesGroup")
+            raise NotImplementedError
         # fp16 and fp32 params for mixed precision training
         self._fp16_param_groups = dict()
         self._fp32_flat_param_groups_of_current_rank = dict()
@@ -136,14 +124,9 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):

         # ParameterStore will manage the tensor buffers used for zero
         # it will not manage the tensors used by mixed precision training
-        if self._pg is not None:
-            self._param_store = ParameterStore(self._pg)
-            self._grad_store = GradientStore(self._pg)
-            self._bucket_store = BucketStore(self._pg)
-        else:
-            self._param_store = ParameterStore(self._dp_parallel_mode)
-            self._grad_store = GradientStore(self._dp_parallel_mode)
-            self._bucket_store = BucketStore(self._dp_parallel_mode)
+        self._param_store = ParameterStore(self._dp_torch_group)
+        self._grad_store = GradientStore(self._dp_torch_group)
+        self._bucket_store = BucketStore(self._dp_torch_group)

         # iterate over the param group in the optimizer
         # partition these param groups for data parallel training
@@ -224,6 +207,30 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
     def num_param_groups(self):
         return len(self._fp16_param_groups)

+    def _sanity_checks(self):
+        assert torch.cuda.is_available(), 'CUDA is required'
+        for param_group in self.optim.param_groups:
+            group_params = param_group['params']
+            for param in group_params:
+                assert param.dtype == self._dtype, \
+                    f"Parameters are expected to have the same dtype `{self._dtype}`, but got `{param.dtype}`"
+
+    def _search_colo_process_group(self):
+        colo_flag = False
+        colo_pg = None
+        for param_group in self.optim.param_groups:
+            group_params = param_group['params']
+            for param in group_params:
+                if isinstance(param, ColoParameter):
+                    colo_flag = True
+                    if colo_pg is None:
+                        colo_pg = param.get_process_group()
+                    else:
+                        assert colo_pg == param.get_process_group(), "All parameters should be in a same process group"
+                elif colo_flag:
+                    raise RuntimeError("All parameters should be ColoParameter if you use ColoParameter.")
+        return colo_pg
+
     def _partition_param_list(self, param_list):
         params_per_rank = [[] for _ in range(self._world_size)]
         numel_per_rank = [0 for _ in range(self._world_size)]
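A self-contained sketch (not part of the diff) of the rule `_search_colo_process_group` enforces, with ColoParameter and its process group replaced by hypothetical stubs: either no parameter is a ColoParameter (the search returns None and the optimizer falls back to the gpc groups), or all of them are and they must share one process group.

    # stand-in classes: ColoParameter and its process group are stubbed out
    class _StubGroup:
        pass

    class _StubColoParameter:

        def __init__(self, pg):
            self._pg = pg

        def get_process_group(self):
            return self._pg

    def search(params, colo_type=_StubColoParameter):
        colo_flag, colo_pg = False, None
        for param in params:
            if isinstance(param, colo_type):
                colo_flag = True
                if colo_pg is None:
                    colo_pg = param.get_process_group()
                else:
                    assert colo_pg == param.get_process_group(), "All parameters should be in a same process group"
            elif colo_flag:
                raise RuntimeError("All parameters should be ColoParameter if you use ColoParameter.")
        return colo_pg

    pg = _StubGroup()
    assert search([_StubColoParameter(pg), _StubColoParameter(pg)]) is pg    # shared group is returned
    assert search(['plain', 'plain']) is None                                # no ColoParameter -> gpc fallback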
@@ -241,14 +248,6 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
         self._logger.info(f'Number of elements on ranks: {numel_per_rank}', ranks=[0])
         return params_per_rank

-    def _sanity_checks(self):
-        assert torch.cuda.is_available(), 'CUDA is required'
-        for param_group in self.optim.param_groups:
-            group_params = param_group['params']
-            for param in group_params:
-                assert param.dtype == self._dtype, \
-                    f"Parameters are expected to have the same dtype `{self._dtype}`, but got `{param.dtype}`"
-
     ###########################################################
     # Backward Reduction Hook
     ###########################################################
@@ -384,10 +383,14 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):

         with torch.cuda.stream(stream):
             flat = bucket.flatten()
+            reduce_global_rank = None
+            if reduce_rank is not None:
+                reduce_global_rank = self._dp_global_ranks[reduce_rank]
             reduced_flat = reduce_tensor_dp_group(tensor=flat,
                                                   dtype=self._communication_dtype,
-                                                  dst_rank=reduce_rank,
-                                                  pg=self._pg)
+                                                  dst_local_rank=reduce_rank,
+                                                  dst_global_rank=reduce_global_rank,
+                                                  group=self._dp_torch_group)

             # update the reduced tensor
             if reduce_rank is None or reduce_rank == self._local_rank:
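Sketch (not part of the diff) of the local-to-global rank translation done above: `reduce_rank` is a position inside the data-parallel group, while `dist.reduce` expects a global rank, so the stored `_dp_global_ranks` list provides the mapping. The rank list below is an illustrative assumption.

    import torch.distributed as dist

    # hypothetical dp group made of four global ranks
    # (dist.new_group must be called collectively by all processes)
    dp_global_ranks = [1, 3, 5, 7]
    dp_torch_group = dist.new_group(ranks=dp_global_ranks)

    reduce_rank = 2                                      # position within the dp group
    reduce_global_rank = dp_global_ranks[reduce_rank]    # 5: what dist.reduce takes as dst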
@@ -456,8 +459,8 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
             norm_group = compute_norm(gradients=self._grad_store._averaged_gradients[group_id],
                                       params=self._param_store.get_fp16_params_by_rank_group(group_id=group_id,
                                                                                              rank=self._local_rank),
-                                      dp_group=self._dp_group,
-                                      mp_group=self._mp_group)
+                                      dp_group=self._dp_torch_group,
+                                      mp_group=self._mp_torch_group)
             norm_groups.append(norm_group)

         # create flat gradient for the flat fp32 params
@@ -497,7 +500,7 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
         for group_id in range(self.num_param_groups):
             for rank in range(self._world_size):
                 fp16_param = self._param_store.get_flat_fp16_param_by_rank_group(rank=rank, group_id=group_id)
-                handle = dist.broadcast(fp16_param, src=rank, group=self._dp_group, async_op=True)
+                handle = dist.broadcast(fp16_param, src=rank, group=self._dp_torch_group, async_op=True)
                 handles.append(handle)

         for handle in handles:
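Sketch (not part of the diff) of the async broadcast pattern above: each rank owns one flat fp16 shard per param group and broadcasts it to the rest of the data-parallel group, then every handle is waited on. The tensors and the use of the default world group are illustrative assumptions.

    import torch
    import torch.distributed as dist

    dp_torch_group = dist.group.WORLD
    world_size = dist.get_world_size(group=dp_torch_group)

    # one flat fp16 shard per rank; contents are filled in from their owning rank
    flat_shards = [torch.empty(1024, dtype=torch.float16, device='cuda') for _ in range(world_size)]

    handles = []
    for rank, shard in enumerate(flat_shards):
        handles.append(dist.broadcast(shard, src=rank, group=dp_torch_group, async_op=True))
    for handle in handles:
        handle.wait()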
@@ -519,11 +522,11 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
                 break

         # all-reduce across dp group
-        dist.all_reduce(self._found_overflow, op=dist.ReduceOp.MAX, group=self._dp_group)
+        dist.all_reduce(self._found_overflow, op=dist.ReduceOp.MAX, group=self._dp_torch_group)

         # all-reduce over model parallel group
-        if self._mp_group:
-            dist.all_reduce(self._found_overflow, op=dist.ReduceOp.MAX, group=self._mp_group)
+        if self._mp_torch_group:
+            dist.all_reduce(self._found_overflow, op=dist.ReduceOp.MAX, group=self._mp_torch_group)

         if self._found_overflow.item() > 0:
             return True
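Sketch (not part of the diff) of the overflow check pattern above: the flag is max-reduced across the data-parallel torch group, then across the model-parallel torch group when one exists, so every rank observes the same decision. The group choices are illustrative assumptions.

    import torch
    import torch.distributed as dist

    found_overflow = torch.zeros(1, device='cuda')    # set to 1.0 on any rank that saw inf/nan
    dp_torch_group = dist.group.WORLD                 # stands in for self._dp_torch_group
    mp_torch_group = None                             # stands in for self._mp_torch_group

    dist.all_reduce(found_overflow, op=dist.ReduceOp.MAX, group=dp_torch_group)
    if mp_torch_group is not None:
        dist.all_reduce(found_overflow, op=dist.ReduceOp.MAX, group=mp_torch_group)

    overflow = found_overflow.item() > 0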