mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-04 02:26:51 +00:00
[zero] support extra dp (#6123)
* [zero] support extra dp * [zero] update checkpoint * fix bugs * fix bugs
This commit is contained in:
@@ -2,7 +2,7 @@
|
||||
import copy
|
||||
from contextlib import contextmanager, nullcontext
|
||||
from functools import partial
|
||||
from typing import Dict, Iterator, List, Optional, Tuple
|
||||
from typing import Dict, Iterator, List, Optional, Tuple, Union
|
||||
from weakref import proxy
|
||||
|
||||
import torch
|
||||
@@ -23,7 +23,15 @@ from colossalai.logging import get_dist_logger
|
||||
from colossalai.quantization.fp8 import all_gather_fp8, all_reduce_fp8, reduce_scatter_fp8
|
||||
from colossalai.tensor.moe_tensor.api import is_moe_tensor
|
||||
|
||||
from ._utils import calculate_global_norm_from_list, has_inf_or_nan, release_param_grad, sync_tensor
|
||||
from ._utils import (
|
||||
all_gather_into_flat_tensor_nd,
|
||||
calculate_global_norm_from_list,
|
||||
get_nd_rank,
|
||||
get_nd_world_size,
|
||||
has_inf_or_nan,
|
||||
release_param_grad,
|
||||
sync_tensor,
|
||||
)
|
||||
from .bookkeeping import BucketStore, GradientStore, TensorBucket
|
||||
from .zero_hook import set_all_gather_handle, wait_all_gather_handle
|
||||
|
||||
@@ -68,7 +76,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
||||
def __init__(
|
||||
self,
|
||||
optimizer: Optimizer,
|
||||
pg_to_param_list: Optional[Dict[ProcessGroup, List[nn.Parameter]]] = None,
|
||||
pg_to_param_list: Optional[Dict[Union[ProcessGroup, Tuple[ProcessGroup, ...]], List[nn.Parameter]]] = None,
|
||||
initial_scale: int = 2**16, # grad scaler config
|
||||
min_scale: int = 1,
|
||||
growth_factor: float = 2.0,
|
||||
@@ -84,6 +92,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
||||
partition_grad: bool = False, # stage 2 flag
|
||||
cpu_offload: bool = False, # cpu offload
|
||||
dp_process_group: Optional[ProcessGroup] = None,
|
||||
extra_dp_group: Optional[ProcessGroup] = None,
|
||||
forced_dtype: Optional[torch.dtype] = None,
|
||||
master_weights: bool = True, # master weights
|
||||
overlap_allgather: bool = False,
|
||||
@@ -98,9 +107,17 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
||||
|
||||
if (dp_process_group is not None) and (pg_to_param_list is not None):
|
||||
raise ValueError("dp_process_group and pg_to_param_list should not be provided at the same time.")
|
||||
if pg_to_param_list is None and extra_dp_group is not None and dp_process_group is None:
|
||||
raise ValueError("dp_process_group should be provided when extra_dp_group is provided.")
|
||||
if pg_to_param_list is None and extra_dp_group is not None and fp8_communication:
|
||||
raise ValueError(
|
||||
"fp8_communication is not supported when pg_to_param_list is None and extra_dp_group is provided."
|
||||
)
|
||||
|
||||
if pg_to_param_list is None:
|
||||
unique_dp_group = dist.group.WORLD if dp_process_group is None else dp_process_group
|
||||
if extra_dp_group is not None:
|
||||
unique_dp_group = (extra_dp_group, unique_dp_group)
|
||||
pg_to_param_list = {unique_dp_group: []}
|
||||
for group in self.optim.param_groups:
|
||||
pg_to_param_list[unique_dp_group].extend(group["params"])
|
||||
@@ -336,10 +353,12 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
||||
flat_grads = flat_grads.to(self._communication_dtype)
|
||||
|
||||
if not self._partition_grads:
|
||||
if self._fp8_communication:
|
||||
all_reduce_fp8(flat_grads, group=bucket_store.torch_pg)
|
||||
else:
|
||||
dist.all_reduce(flat_grads, group=bucket_store.torch_pg)
|
||||
for i, sz in enumerate(bucket_store.sizes):
|
||||
grp = bucket_store.torch_pg if len(bucket_store.sizes) == 1 else bucket_store.torch_pg[i]
|
||||
if self._fp8_communication:
|
||||
all_reduce_fp8(flat_grads, group=grp)
|
||||
else:
|
||||
dist.all_reduce(flat_grads, group=grp)
|
||||
if flat_grads.dtype != grad_dtype:
|
||||
flat_grads = flat_grads.to(grad_dtype)
|
||||
|
||||
@@ -347,16 +366,20 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
||||
grad_in_bucket = bucket_store.get_grad()
|
||||
self._update_unpartitoned_grad(bucket_store, grad_in_bucket.values(), flat_grads_per_rank, group_id)
|
||||
else:
|
||||
flat_grads_list = list(flat_grads.split(len(flat_grads) // bucket_store.world_size))
|
||||
received_grad = torch.zeros_like(flat_grads_list[0])
|
||||
if self._fp8_communication:
|
||||
reduce_scatter_fp8(
|
||||
received_grad,
|
||||
flat_grads_list,
|
||||
group=bucket_store.torch_pg,
|
||||
)
|
||||
else:
|
||||
dist.reduce_scatter(received_grad, flat_grads_list, group=bucket_store.torch_pg)
|
||||
cur_flat_grads = flat_grads
|
||||
for i, sz in enumerate(bucket_store.sizes):
|
||||
grp = bucket_store.torch_pg if len(bucket_store.sizes) == 1 else bucket_store.torch_pg[i]
|
||||
flat_grads_list = list(cur_flat_grads.split(len(cur_flat_grads) // sz))
|
||||
received_grad = torch.zeros_like(flat_grads_list[0])
|
||||
if self._fp8_communication:
|
||||
reduce_scatter_fp8(
|
||||
received_grad,
|
||||
flat_grads_list,
|
||||
group=grp,
|
||||
)
|
||||
else:
|
||||
dist.reduce_scatter_tensor(received_grad, cur_flat_grads, group=grp)
|
||||
cur_flat_grads = received_grad
|
||||
|
||||
if received_grad.dtype != grad_dtype:
|
||||
received_grad = received_grad.to(grad_dtype)
|
||||
@@ -577,11 +600,13 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
||||
pg = self.param_to_pg[working_param]
|
||||
padded_working_param = self._working_param_to_padded_working_param[working_param]
|
||||
if self._overlap_allgather:
|
||||
handle = dist.all_gather_into_tensor(padded_working_param, param_to_gather, pg, async_op=True)
|
||||
# handle = dist.all_gather_into_tensor(padded_working_param, param_to_gather, pg, async_op=True)
|
||||
handle = all_gather_into_flat_tensor_nd(padded_working_param, param_to_gather, pg, async_op=True)
|
||||
set_all_gather_handle(working_param, handle)
|
||||
else:
|
||||
if param_to_gather.numel() > self.pg_to_tensor_bucket[pg].max_size:
|
||||
if self._fp8_communication:
|
||||
# TODO: fit fp8 communication
|
||||
all_gather_fp8(
|
||||
list(padded_working_param.chunk(dist.get_world_size(pg))),
|
||||
param_to_gather,
|
||||
@@ -589,7 +614,8 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
||||
fp8_format="e4m3",
|
||||
)
|
||||
else:
|
||||
dist.all_gather_into_tensor(padded_working_param, param_to_gather, pg)
|
||||
# dist.all_gather_into_tensor(padded_working_param, param_to_gather, pg)
|
||||
all_gather_into_flat_tensor_nd(padded_working_param, param_to_gather, pg)
|
||||
continue
|
||||
try:
|
||||
self.pg_to_tensor_bucket[pg].add_to_bucket(param_to_gather, write_back_tensor=working_param)
|
||||
@@ -602,7 +628,9 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
||||
if not tensor_bucket.is_empty():
|
||||
tensor_bucket.all_gather(pg, fp8_communication=self._fp8_communication)
|
||||
|
||||
def _compute_grad_norm(self, dp_pg: ProcessGroup, gradients: List[Tensor], norm_type: int = 2) -> float:
|
||||
def _compute_grad_norm(
|
||||
self, dp_pg: Union[ProcessGroup, Tuple[ProcessGroup, ...]], gradients: List[Tensor], norm_type: int = 2
|
||||
) -> float:
|
||||
r"""
|
||||
Compute and return the gradient norm for gradient clipping.
|
||||
|
||||
@@ -625,7 +653,11 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
||||
device=get_accelerator().get_current_device(),
|
||||
dtype=torch.float,
|
||||
)
|
||||
dist.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=dp_pg)
|
||||
if isinstance(dp_pg, tuple):
|
||||
for grp in dp_pg:
|
||||
dist.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=grp)
|
||||
else:
|
||||
dist.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=dp_pg)
|
||||
total_norm = total_norm_cuda.item()
|
||||
|
||||
else:
|
||||
@@ -640,11 +672,19 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
||||
device=get_accelerator().get_current_device(),
|
||||
dtype=torch.float,
|
||||
)
|
||||
torch.distributed.all_reduce(
|
||||
total_norm_exponentiated_cuda,
|
||||
op=torch.distributed.ReduceOp.SUM,
|
||||
group=dp_pg,
|
||||
)
|
||||
if isinstance(dp_pg, tuple):
|
||||
for grp in dp_pg:
|
||||
dist.all_reduce(
|
||||
total_norm_exponentiated_cuda,
|
||||
op=torch.distributed.ReduceOp.SUM,
|
||||
group=grp,
|
||||
)
|
||||
else:
|
||||
torch.distributed.all_reduce(
|
||||
total_norm_exponentiated_cuda,
|
||||
op=torch.distributed.ReduceOp.SUM,
|
||||
group=dp_pg,
|
||||
)
|
||||
total_norm = total_norm_exponentiated_cuda.item() ** (1.0 / norm_type)
|
||||
|
||||
return total_norm
|
||||
@@ -744,11 +784,9 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
||||
if isinstance(v, torch.Tensor) and k != "step":
|
||||
working_param = self.master_to_working_param[id(param)]
|
||||
pg = self.param_to_pg[working_param]
|
||||
gather_tensor = [torch.zeros(v.shape, device=device, dtype=v.dtype) for _ in range(pg.size())]
|
||||
dist.all_gather(gather_tensor, v.to(device), group=pg)
|
||||
param_state = (
|
||||
torch.stack(gather_tensor).view(-1)[: working_param.numel()].reshape_as(working_param).cpu()
|
||||
)
|
||||
gathered_tensor = torch.empty(v.numel() * get_nd_world_size(pg), device=device, dtype=v.dtype)
|
||||
all_gather_into_flat_tensor_nd(gathered_tensor, v.to(device).flatten(), pg)
|
||||
param_state = gathered_tensor[: working_param.numel()].reshape_as(working_param).cpu()
|
||||
zero_state[param][k] = param_state
|
||||
|
||||
states_dict = self._pack_state(zero_state)
|
||||
@@ -770,15 +808,17 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
||||
cnt += 1
|
||||
for param_idx, state in zero_state_dict["state"].items():
|
||||
pg = self.param_to_pg[self.master_to_working_param[id(idx2master[param_idx])]]
|
||||
world_size = get_nd_world_size(pg)
|
||||
rank = get_nd_rank(pg)
|
||||
for k, v in state.items():
|
||||
if isinstance(v, torch.Tensor) and k != "step":
|
||||
padding_size = (pg.size() - v.numel() % pg.size()) % pg.size()
|
||||
padding_size = (world_size - v.numel() % world_size) % world_size
|
||||
with torch.no_grad():
|
||||
v = v.flatten()
|
||||
if padding_size > 0:
|
||||
v = torch.nn.functional.pad(v, [0, padding_size])
|
||||
v_list = v.split(v.numel() // pg.size())
|
||||
zero_state_dict["state"][param_idx][k] = v_list[pg.rank()].detach().clone()
|
||||
v_list = v.split(v.numel() // world_size)
|
||||
zero_state_dict["state"][param_idx][k] = v_list[rank].detach().clone()
|
||||
|
||||
self.optim.load_state_dict(zero_state_dict)
|
||||
|
||||
@@ -814,11 +854,9 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
||||
|
||||
for k, v in states.items():
|
||||
if isinstance(v, torch.Tensor) and k != "step":
|
||||
state_tensor = [torch.zeros(v.shape, device=device, dtype=v.dtype) for _ in range(pg.size())]
|
||||
dist.all_gather(state_tensor, v.to(device), group=pg)
|
||||
state_tensor = (
|
||||
torch.stack(state_tensor).view(-1)[: working_param.numel()].reshape_as(working_param).cpu()
|
||||
)
|
||||
state_tensor = torch.empty(v.numel() * get_nd_world_size(pg), device=device, dtype=v.dtype)
|
||||
all_gather_into_flat_tensor_nd(state_tensor, v.to(device).flatten(), pg)
|
||||
state_tensor = state_tensor[: working_param.numel()].reshape_as(working_param).cpu()
|
||||
current_block_size += state_tensor.numel()
|
||||
current_block[k] = state_tensor
|
||||
|
||||
@@ -842,12 +880,14 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
||||
p_id = id(p)
|
||||
if p_id in self.working_to_master_param:
|
||||
pg = self.param_to_pg[p]
|
||||
world_size = get_nd_world_size(pg)
|
||||
rank = get_nd_rank(pg)
|
||||
master_param = self.working_to_master_param[p_id]
|
||||
padding_size = self.get_param_padding_size(p)
|
||||
working_param = p.data.view(-1)
|
||||
if padding_size > 0:
|
||||
working_param = torch.nn.functional.pad(working_param, [0, padding_size])
|
||||
master_param.copy_(working_param.chunk(pg.size())[pg.rank()])
|
||||
master_param.copy_(working_param.chunk(world_size)[rank])
|
||||
|
||||
def get_working_to_master_map(self) -> Dict[int, torch.Tensor]:
|
||||
return self.working_to_master_param
|
||||
@@ -905,9 +945,12 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
||||
grad = grad_store.get_working_grad_by_param_id(id(working_param))
|
||||
if grad is None:
|
||||
return None
|
||||
grad_flat = torch.empty((grad_store.world_size, *grad.shape), dtype=grad.dtype, device=grad.device)
|
||||
dist.all_gather_into_tensor(grad_flat, grad, group=grad_store.torch_pg)
|
||||
return grad_flat.view(-1)[: working_param.numel()].view_as(working_param)
|
||||
grad_flat = grad.flatten()
|
||||
output_grad = torch.empty(
|
||||
grad_flat.numel() * grad_store.world_size, device=grad_flat.device, dtype=grad_flat.dtype
|
||||
)
|
||||
all_gather_into_flat_tensor_nd(output_grad, grad_flat, grad_store.torch_pg)
|
||||
return output_grad.view(-1)[: working_param.numel()].view_as(working_param)
|
||||
|
||||
def get_working_grads_by_group_id(self, group_id: int) -> List[Tensor]:
|
||||
working_grads = []
|
||||
|
Reference in New Issue
Block a user