[zero] support extra dp (#6123)

* [zero] support extra dp

* [zero] update checkpoint

* fix bugs

* fix bugs
Author: Hongxin Liu
Date: 2024-11-12 11:20:46 +08:00
Committed by: GitHub
Parent: 30a9443132
Commit: a2596519fd
8 changed files with 238 additions and 57 deletions


@@ -2,7 +2,7 @@
import copy
from contextlib import contextmanager, nullcontext
from functools import partial
from typing import Dict, Iterator, List, Optional, Tuple
from typing import Dict, Iterator, List, Optional, Tuple, Union
from weakref import proxy
import torch
@@ -23,7 +23,15 @@ from colossalai.logging import get_dist_logger
from colossalai.quantization.fp8 import all_gather_fp8, all_reduce_fp8, reduce_scatter_fp8
from colossalai.tensor.moe_tensor.api import is_moe_tensor
from ._utils import calculate_global_norm_from_list, has_inf_or_nan, release_param_grad, sync_tensor
from ._utils import (
all_gather_into_flat_tensor_nd,
calculate_global_norm_from_list,
get_nd_rank,
get_nd_world_size,
has_inf_or_nan,
release_param_grad,
sync_tensor,
)
from .bookkeeping import BucketStore, GradientStore, TensorBucket
from .zero_hook import set_all_gather_handle, wait_all_gather_handle
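Note: the new helpers (get_nd_world_size, get_nd_rank, all_gather_into_flat_tensor_nd) come from the sibling _utils module and their implementation is not part of this hunk. The rest of the diff relies on the following semantics, shown here only as an illustrative sketch under assumed conventions: a "group" may be a single ProcessGroup or a tuple of orthogonal groups, the flattened world size is the product of the member sizes, the flattened rank is row-major with the first group outermost, and the flat all-gather chains plain all-gathers over the groups in reverse order so the output layout matches that rank convention.

import torch
import torch.distributed as dist
from typing import Tuple, Union

PG = Union[dist.ProcessGroup, Tuple[dist.ProcessGroup, ...]]

def _as_tuple(group: PG) -> Tuple[dist.ProcessGroup, ...]:
    return (group,) if isinstance(group, dist.ProcessGroup) else tuple(group)

def get_nd_world_size(group: PG) -> int:
    # Flattened world size: product of the member group sizes.
    size = 1
    for pg in _as_tuple(group):
        size *= dist.get_world_size(pg)
    return size

def get_nd_rank(group: PG) -> int:
    # Row-major flattened rank; the first group is the outermost dimension,
    # mirroring the order in which gradients are scattered in _run_reduce.
    rank = 0
    for pg in _as_tuple(group):
        rank = rank * dist.get_world_size(pg) + dist.get_rank(pg)
    return rank

def all_gather_into_flat_tensor_nd(output: torch.Tensor, input_: torch.Tensor, group: PG, async_op: bool = False):
    # Chain plain all-gathers, innermost group first, so the concatenation
    # order in `output` matches the row-major rank layout above.
    groups = _as_tuple(group)
    cur = input_.contiguous().view(-1)
    handle = None
    for i, pg in enumerate(reversed(groups)):
        last = i == len(groups) - 1
        buf = output.view(-1) if last else torch.empty(
            cur.numel() * dist.get_world_size(pg), dtype=cur.dtype, device=cur.device
        )
        handle = dist.all_gather_into_tensor(buf, cur, group=pg, async_op=async_op and last)
        cur = buf
    return handle
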
@@ -68,7 +76,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
def __init__(
self,
optimizer: Optimizer,
pg_to_param_list: Optional[Dict[ProcessGroup, List[nn.Parameter]]] = None,
pg_to_param_list: Optional[Dict[Union[ProcessGroup, Tuple[ProcessGroup, ...]], List[nn.Parameter]]] = None,
initial_scale: int = 2**16, # grad scaler config
min_scale: int = 1,
growth_factor: float = 2.0,
@@ -84,6 +92,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
partition_grad: bool = False, # stage 2 flag
cpu_offload: bool = False, # cpu offload
dp_process_group: Optional[ProcessGroup] = None,
extra_dp_group: Optional[ProcessGroup] = None,
forced_dtype: Optional[torch.dtype] = None,
master_weights: bool = True, # master weights
overlap_allgather: bool = False,
@@ -98,9 +107,17 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
if (dp_process_group is not None) and (pg_to_param_list is not None):
raise ValueError("dp_process_group and pg_to_param_list should not be provided at the same time.")
if pg_to_param_list is None and extra_dp_group is not None and dp_process_group is None:
raise ValueError("dp_process_group should be provided when extra_dp_group is provided.")
if pg_to_param_list is None and extra_dp_group is not None and fp8_communication:
raise ValueError(
"fp8_communication is not supported when pg_to_param_list is None and extra_dp_group is provided."
)
if pg_to_param_list is None:
unique_dp_group = dist.group.WORLD if dp_process_group is None else dp_process_group
if extra_dp_group is not None:
unique_dp_group = (extra_dp_group, unique_dp_group)
pg_to_param_list = {unique_dp_group: []}
for group in self.optim.param_groups:
pg_to_param_list[unique_dp_group].extend(group["params"])
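The key normalization above means callers never build tuple keys themselves: passing extra_dp_group alongside dp_process_group makes the optimizer use (extra_dp_group, dp_process_group) as the single bucket key. A minimal construction sketch under torchrun --nproc_per_node 8; the mesh axis names, the toy model, and the top-level LowLevelZeroOptimizer import path are assumptions, not part of this diff:

import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from colossalai.zero import LowLevelZeroOptimizer

mesh = init_device_mesh("cuda", (2, 4), mesh_dim_names=("extra_dp", "dp"))
model = nn.Linear(32, 32).cuda()
zero_optim = LowLevelZeroOptimizer(
    torch.optim.AdamW(model.parameters(), lr=1e-3),
    dp_process_group=mesh.get_group("dp"),       # inner ZeRO sharding group
    extra_dp_group=mesh.get_group("extra_dp"),   # extra data-parallel group
)
# Internally this is equivalent to
# pg_to_param_list={(extra_dp_group, dp_process_group): list(model.parameters())}.
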
@@ -336,10 +353,12 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
flat_grads = flat_grads.to(self._communication_dtype)
if not self._partition_grads:
if self._fp8_communication:
all_reduce_fp8(flat_grads, group=bucket_store.torch_pg)
else:
dist.all_reduce(flat_grads, group=bucket_store.torch_pg)
for i, sz in enumerate(bucket_store.sizes):
grp = bucket_store.torch_pg if len(bucket_store.sizes) == 1 else bucket_store.torch_pg[i]
if self._fp8_communication:
all_reduce_fp8(flat_grads, group=grp)
else:
dist.all_reduce(flat_grads, group=grp)
if flat_grads.dtype != grad_dtype:
flat_grads = flat_grads.to(grad_dtype)
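In the stage-1 branch the single all-reduce becomes one all-reduce per member group; bucket_store.sizes presumably holds the per-group world sizes, so the loop degrades to the old behavior when only one group is present. Because the groups are orthogonal and SUM is associative, reducing over each group in turn yields the same result as one all-reduce over their product. A single-process toy check of that equivalence:

# Partial gradients held by 8 ranks arranged as a 2 x 4 mesh.
partials = [[float(10 * r0 + r1) for r1 in range(4)] for r0 in range(2)]
step1 = [sum(row) for row in partials]   # all-reduce(SUM) inside each inner dp group of 4
step2 = sum(step1)                       # all-reduce(SUM) across the extra dp group of 2
assert step2 == sum(x for row in partials for x in row)   # same as one global all-reduce
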
@@ -347,16 +366,20 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
grad_in_bucket = bucket_store.get_grad()
self._update_unpartitoned_grad(bucket_store, grad_in_bucket.values(), flat_grads_per_rank, group_id)
else:
flat_grads_list = list(flat_grads.split(len(flat_grads) // bucket_store.world_size))
received_grad = torch.zeros_like(flat_grads_list[0])
if self._fp8_communication:
reduce_scatter_fp8(
received_grad,
flat_grads_list,
group=bucket_store.torch_pg,
)
else:
dist.reduce_scatter(received_grad, flat_grads_list, group=bucket_store.torch_pg)
cur_flat_grads = flat_grads
for i, sz in enumerate(bucket_store.sizes):
grp = bucket_store.torch_pg if len(bucket_store.sizes) == 1 else bucket_store.torch_pg[i]
flat_grads_list = list(cur_flat_grads.split(len(cur_flat_grads) // sz))
received_grad = torch.zeros_like(flat_grads_list[0])
if self._fp8_communication:
reduce_scatter_fp8(
received_grad,
flat_grads_list,
group=grp,
)
else:
dist.reduce_scatter_tensor(received_grad, cur_flat_grads, group=grp)
cur_flat_grads = received_grad
if received_grad.dtype != grad_dtype:
received_grad = received_grad.to(grad_dtype)
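The stage-2 path now chains reduce-scatters: the flat gradient is first reduce-scattered across the first group in the tuple, and the resulting shard is reduce-scattered again across the next group, so each rank ends up with a 1/(size_0 * size_1) shard of the fully reduced gradient. Note the accompanying switch from dist.reduce_scatter (list input) to dist.reduce_scatter_tensor (flat tensor input); the flat_grads_list split is kept only for the fp8 path. A single-process sketch of the resulting slice layout on a 2 x 4 mesh, assuming the row-major rank convention from the helper sketch above:

import torch
ws0, ws1 = 2, 4
flat = torch.arange(16.0)                      # values stand in for the reduced gradient
outer = flat.chunk(ws0)                        # shards after the first reduce-scatter
for r0 in range(ws0):
    for r1 in range(ws1):
        shard = outer[r0].chunk(ws1)[r1]       # shard after the second reduce-scatter
        nd_rank = r0 * ws1 + r1
        assert torch.equal(shard, flat.chunk(ws0 * ws1)[nd_rank])
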
@@ -577,11 +600,13 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
pg = self.param_to_pg[working_param]
padded_working_param = self._working_param_to_padded_working_param[working_param]
if self._overlap_allgather:
handle = dist.all_gather_into_tensor(padded_working_param, param_to_gather, pg, async_op=True)
# handle = dist.all_gather_into_tensor(padded_working_param, param_to_gather, pg, async_op=True)
handle = all_gather_into_flat_tensor_nd(padded_working_param, param_to_gather, pg, async_op=True)
set_all_gather_handle(working_param, handle)
else:
if param_to_gather.numel() > self.pg_to_tensor_bucket[pg].max_size:
if self._fp8_communication:
# TODO: fit fp8 communication
all_gather_fp8(
list(padded_working_param.chunk(dist.get_world_size(pg))),
param_to_gather,
@@ -589,7 +614,8 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
fp8_format="e4m3",
)
else:
dist.all_gather_into_tensor(padded_working_param, param_to_gather, pg)
# dist.all_gather_into_tensor(padded_working_param, param_to_gather, pg)
all_gather_into_flat_tensor_nd(padded_working_param, param_to_gather, pg)
continue
try:
self.pg_to_tensor_bucket[pg].add_to_bucket(param_to_gather, write_back_tensor=working_param)
@@ -602,7 +628,9 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
if not tensor_bucket.is_empty():
tensor_bucket.all_gather(pg, fp8_communication=self._fp8_communication)
def _compute_grad_norm(self, dp_pg: ProcessGroup, gradients: List[Tensor], norm_type: int = 2) -> float:
def _compute_grad_norm(
self, dp_pg: Union[ProcessGroup, Tuple[ProcessGroup, ...]], gradients: List[Tensor], norm_type: int = 2
) -> float:
r"""
Compute and return the gradient norm for gradient clipping.
@@ -625,7 +653,11 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
device=get_accelerator().get_current_device(),
dtype=torch.float,
)
dist.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=dp_pg)
if isinstance(dp_pg, tuple):
for grp in dp_pg:
dist.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=grp)
else:
dist.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=dp_pg)
total_norm = total_norm_cuda.item()
else:
@@ -640,11 +672,19 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
device=get_accelerator().get_current_device(),
dtype=torch.float,
)
torch.distributed.all_reduce(
total_norm_exponentiated_cuda,
op=torch.distributed.ReduceOp.SUM,
group=dp_pg,
)
if isinstance(dp_pg, tuple):
for grp in dp_pg:
dist.all_reduce(
total_norm_exponentiated_cuda,
op=torch.distributed.ReduceOp.SUM,
group=grp,
)
else:
torch.distributed.all_reduce(
total_norm_exponentiated_cuda,
op=torch.distributed.ReduceOp.SUM,
group=dp_pg,
)
total_norm = total_norm_exponentiated_cuda.item() ** (1.0 / norm_type)
return total_norm
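For the norm computation the same trick applies twice: with the inf norm the running MAX is all-reduced over each member group in turn, and for finite p-norms the sum of p-th powers is. Both reductions are associative and the groups are orthogonal, so the chained reductions equal a single reduction over the product group. A quick single-process check:

import math
# Local norm contributions of 8 "ranks" arranged as a 2 x 4 mesh.
local = [1.0, 5.0, 2.0, 8.0, 3.0, 7.0, 4.0, 6.0]
rows = [local[0:4], local[4:8]]                  # inner dp groups
# inf-norm: MAX inside each inner group, then MAX across the extra group.
assert max(max(r) for r in rows) == max(local)
# p-norm (p = 2): SUM of squares inside each inner group, then SUM across groups.
p = 2
chained = sum(sum(x ** p for x in r) for r in rows)
assert math.isclose(chained ** (1 / p), sum(x ** p for x in local) ** (1 / p))
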
@@ -744,11 +784,9 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
if isinstance(v, torch.Tensor) and k != "step":
working_param = self.master_to_working_param[id(param)]
pg = self.param_to_pg[working_param]
gather_tensor = [torch.zeros(v.shape, device=device, dtype=v.dtype) for _ in range(pg.size())]
dist.all_gather(gather_tensor, v.to(device), group=pg)
param_state = (
torch.stack(gather_tensor).view(-1)[: working_param.numel()].reshape_as(working_param).cpu()
)
gathered_tensor = torch.empty(v.numel() * get_nd_world_size(pg), device=device, dtype=v.dtype)
all_gather_into_flat_tensor_nd(gathered_tensor, v.to(device).flatten(), pg)
param_state = gathered_tensor[: working_param.numel()].reshape_as(working_param).cpu()
zero_state[param][k] = param_state
states_dict = self._pack_state(zero_state)
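Saving the optimizer state now follows the same pattern as parameter gathering: each rank's shard v is gathered into a flat buffer of v.numel() * get_nd_world_size(pg) elements and then trimmed to working_param.numel() to drop the ZeRO padding. The size arithmetic, with assumed sizes:

# Assumed flattened nd world size of 8 (e.g. extra dp 2 x dp 4), working param of 10 elements.
numel, world_size = 10, 8
padded = ((numel + world_size - 1) // world_size) * world_size   # 16
shard_numel = padded // world_size                               # 2 per rank (== v.numel())
gathered_numel = shard_numel * world_size                        # 16 gathered, trimmed back to 10
assert gathered_numel >= numel
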
@@ -770,15 +808,17 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
cnt += 1
for param_idx, state in zero_state_dict["state"].items():
pg = self.param_to_pg[self.master_to_working_param[id(idx2master[param_idx])]]
world_size = get_nd_world_size(pg)
rank = get_nd_rank(pg)
for k, v in state.items():
if isinstance(v, torch.Tensor) and k != "step":
padding_size = (pg.size() - v.numel() % pg.size()) % pg.size()
padding_size = (world_size - v.numel() % world_size) % world_size
with torch.no_grad():
v = v.flatten()
if padding_size > 0:
v = torch.nn.functional.pad(v, [0, padding_size])
v_list = v.split(v.numel() // pg.size())
zero_state_dict["state"][param_idx][k] = v_list[pg.rank()].detach().clone()
v_list = v.split(v.numel() // world_size)
zero_state_dict["state"][param_idx][k] = v_list[rank].detach().clone()
self.optim.load_state_dict(zero_state_dict)
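Loading reverses the gather: the flattened nd world size and rank replace pg.size() and pg.rank(), so each rank keeps only its slice of every state tensor. A hypothetical stand-alone helper (not part of the PR) mirroring that logic, plus a single-process round-trip check showing that sharding then concatenating and trimming recovers the original tensor, which is exactly what state_dict() relies on:

import torch

def shard_like_load_state_dict(v: torch.Tensor, world_size: int, rank: int) -> torch.Tensor:
    # Pad the flattened state to a multiple of the flattened nd world size,
    # split it, and keep only this rank's slice.
    v = v.flatten()
    pad = (world_size - v.numel() % world_size) % world_size
    if pad > 0:
        v = torch.nn.functional.pad(v, [0, pad])
    return v.split(v.numel() // world_size)[rank].detach().clone()

full = torch.arange(10.0)
shards = [shard_like_load_state_dict(full, 8, r) for r in range(8)]
assert torch.equal(torch.cat(shards)[: full.numel()], full)
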
@@ -814,11 +854,9 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
for k, v in states.items():
if isinstance(v, torch.Tensor) and k != "step":
state_tensor = [torch.zeros(v.shape, device=device, dtype=v.dtype) for _ in range(pg.size())]
dist.all_gather(state_tensor, v.to(device), group=pg)
state_tensor = (
torch.stack(state_tensor).view(-1)[: working_param.numel()].reshape_as(working_param).cpu()
)
state_tensor = torch.empty(v.numel() * get_nd_world_size(pg), device=device, dtype=v.dtype)
all_gather_into_flat_tensor_nd(state_tensor, v.to(device).flatten(), pg)
state_tensor = state_tensor[: working_param.numel()].reshape_as(working_param).cpu()
current_block_size += state_tensor.numel()
current_block[k] = state_tensor
@@ -842,12 +880,14 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
p_id = id(p)
if p_id in self.working_to_master_param:
pg = self.param_to_pg[p]
world_size = get_nd_world_size(pg)
rank = get_nd_rank(pg)
master_param = self.working_to_master_param[p_id]
padding_size = self.get_param_padding_size(p)
working_param = p.data.view(-1)
if padding_size > 0:
working_param = torch.nn.functional.pad(working_param, [0, padding_size])
master_param.copy_(working_param.chunk(pg.size())[pg.rank()])
master_param.copy_(working_param.chunk(world_size)[rank])
def get_working_to_master_map(self) -> Dict[int, torch.Tensor]:
return self.working_to_master_param
@@ -905,9 +945,12 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
grad = grad_store.get_working_grad_by_param_id(id(working_param))
if grad is None:
return None
grad_flat = torch.empty((grad_store.world_size, *grad.shape), dtype=grad.dtype, device=grad.device)
dist.all_gather_into_tensor(grad_flat, grad, group=grad_store.torch_pg)
return grad_flat.view(-1)[: working_param.numel()].view_as(working_param)
grad_flat = grad.flatten()
output_grad = torch.empty(
grad_flat.numel() * grad_store.world_size, device=grad_flat.device, dtype=grad_flat.dtype
)
all_gather_into_flat_tensor_nd(output_grad, grad_flat, grad_store.torch_pg)
return output_grad.view(-1)[: working_param.numel()].view_as(working_param)
def get_working_grads_by_group_id(self, group_id: int) -> List[Tensor]:
working_grads = []
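
For completeness, the rest of the training loop is unchanged from the single-group case. A rough continuation of the construction sketch shown earlier, assuming the optimizer's existing backward()/step()/state_dict() API; none of the shapes or hyperparameters below come from this PR:

x = torch.randn(4, 32, device="cuda")
loss = model(x).sum()
zero_optim.backward(loss)          # reduces gradients over both groups
zero_optim.step()                  # updates local master shards, re-gathers working params
state = zero_optim.state_dict()    # gathers full optimizer states across the flattened nd group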