mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-06-21 13:11:27 +00:00)

[zero] support extra dp (#6123)

* [zero] support extra dp
* [zero] update checkpoint
* fix bugs
* fix bugs

parent 30a9443132
commit a2596519fd
--- a/colossalai/booster/plugin/low_level_zero_plugin.py
+++ b/colossalai/booster/plugin/low_level_zero_plugin.py
@@ -29,6 +29,7 @@ from colossalai.checkpoint_io.utils import (
     save_state_dict,
     sharded_optimizer_loading_epilogue,
 )
+from colossalai.cluster import ProcessGroupMesh
 from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
 from colossalai.interface.optimizer import DistributedOptim
 from colossalai.logging import get_dist_logger
@@ -333,6 +334,7 @@ class LowLevelZeroPlugin(DPPluginBase):
         verbose (bool, optional): verbose mode. Debug info including grad overflow will be printed. Defaults to False.
         use_fp8 (bool, optional): Whether to enable fp8 mixed precision training. Defaults to False.
         fp8_communication (bool, optional): Whether to enable fp8 communication. Defaults to False.
+        extra_dp_size (int, optional): The number of extra data parallel groups. Defaults to 1.
     """

     def __init__(
@@ -358,11 +360,16 @@ class LowLevelZeroPlugin(DPPluginBase):
         cast_inputs: bool = True,
         fp8_communication: bool = False,
         use_fp8: bool = False,
+        extra_dp_size: int = 1,
     ) -> None:
         super().__init__()
         assert stage in (1, 2), f"LowLevelZeroPlugin only supports stage 1/2 training"
         assert precision in SUPPORTED_PRECISION, f"LowLevelZeroPlugin only supports {SUPPORTED_PRECISION} training"
         assert norm_type == 2.0, f"LowLevelZeroPlugin only supports norm_type=2.0 now"
+        if extra_dp_size > 1:
+            assert dist.get_world_size() % extra_dp_size == 0, "extra_dp_size should be a factor of world_size"
+            inner_dp_size = dist.get_world_size() // extra_dp_size
+            self.pg_mesh = ProcessGroupMesh(extra_dp_size, inner_dp_size)
         self.stage = stage
         self.precision = precision
         self.zero_optim_kwargs = dict(
@@ -383,6 +390,9 @@ class LowLevelZeroPlugin(DPPluginBase):
             overlap_allgather=overlap_allgather,
             fp8_communication=fp8_communication,
         )
+        if extra_dp_size > 1:
+            self.zero_optim_kwargs["extra_dp_group"] = self.pg_mesh.get_group_along_axis(0)
+            self.zero_optim_kwargs["dp_process_group"] = self.pg_mesh.get_group_along_axis(1)
         self.lora_enabled = False
         self.verbose = verbose
         self.logger = get_dist_logger()
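With this change, turning on the extra data parallel dimension from the plugin is a single argument. A rough usage sketch (not taken from this commit; it assumes the usual Booster workflow, a torchrun launch on 8 ranks, and a toy model/optimizer as placeholders):

# torchrun --nproc_per_node 8 train.py
import torch
import torch.nn as nn

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin

colossalai.launch_from_torch()

model = nn.Linear(32, 32).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

# extra_dp_size must divide the world size (asserted above): 8 ranks -> 2 x 4 mesh.
# Axis 0 of the mesh becomes extra_dp_group, axis 1 becomes dp_process_group.
plugin = LowLevelZeroPlugin(stage=1, precision="bf16", extra_dp_size=2)
booster = Booster(plugin=plugin)
model, optimizer, *_ = booster.boost(model, optimizer)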
--- a/colossalai/zero/low_level/_utils.py
+++ b/colossalai/zero/low_level/_utils.py
@@ -1,6 +1,7 @@
 import math
-from typing import Optional
+from typing import Optional, Tuple, Union

+import numpy as np
 import torch
 import torch.distributed as dist
 from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
@@ -209,3 +210,42 @@ def sync_tensor(flat_tensor, tensor_list):
     # update the tensor data
     for p, q in zip(tensor_list, updated_params):
         p.data = q.data
+
+
+def all_gather_into_flat_tensor_nd(
+    output_tensor: torch.Tensor,
+    input_tensor: torch.Tensor,
+    group: Union[dist.ProcessGroup, Tuple[dist.ProcessGroup, ...]],
+    async_op: bool = False,
+):
+    if isinstance(group, dist.ProcessGroup):
+        group = (group,)
+    sizes = [dist.get_world_size(pg) for pg in group]
+    ranks = [dist.get_rank(pg) for pg in group]
+    for i, pg in list(enumerate(group))[::-1]:
+        if i == 0:
+            out = output_tensor
+        else:
+            prev_sizes = sizes[:i]
+            prev_ranks = ranks[:i]
+            chunks = output_tensor.chunk(np.prod(prev_sizes))
+            out = chunks[np.ravel_multi_index(prev_ranks, prev_sizes)]
+        handle = dist.all_gather_into_tensor(out, input_tensor, group=pg, async_op=async_op)
+        input_tensor = out
+    return handle
+
+
+def get_nd_world_size(group) -> int:
+    if isinstance(group, tuple):
+        return int(np.prod([dist.get_world_size(pg) for pg in group]))
+    else:
+        return dist.get_world_size(group)
+
+
+def get_nd_rank(group) -> int:
+    if isinstance(group, tuple):
+        return np.ravel_multi_index(
+            tuple(dist.get_rank(group=pg) for pg in group), [dist.get_world_size(pg) for pg in group]
+        )
+    else:
+        return dist.get_rank(group)
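all_gather_into_flat_tensor_nd gathers over the groups from the innermost axis to the outermost one and, at each intermediate step, writes into the output chunk addressed by np.ravel_multi_index, so the finished flat buffer ends up ordered by the flattened N-D rank. A single-process check of that index math (illustration only, not part of the commit):

import numpy as np

sizes = (2, 2)  # (extra_dp_size, inner_dp_size), matching the 2D test added below
for outer in range(sizes[0]):
    for inner in range(sizes[1]):
        # chunk of the flat output owned by rank (outer, inner)
        print((outer, inner), "->", np.ravel_multi_index((outer, inner), sizes))
# (0, 0) -> 0, (0, 1) -> 1, (1, 0) -> 2, (1, 1) -> 3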
--- a/colossalai/zero/low_level/bookkeeping/base_store.py
+++ b/colossalai/zero/low_level/bookkeeping/base_store.py
@@ -1,11 +1,20 @@
+from typing import Tuple, Union
+
+import numpy as np
 import torch.distributed as dist
 from torch.distributed import ProcessGroup


 class BaseStore:
-    def __init__(self, torch_pg: ProcessGroup):
-        self._world_size = dist.get_world_size(group=torch_pg)
-        self._local_rank = dist.get_rank(group=torch_pg)
+    def __init__(self, torch_pg: Union[ProcessGroup, Tuple[ProcessGroup, ...]]):
+        if isinstance(torch_pg, tuple):
+            self.sizes = [dist.get_world_size(group=pg) for pg in torch_pg]
+            self._world_size = int(np.prod(self.sizes))
+            self._local_rank = np.ravel_multi_index(tuple(dist.get_rank(group=pg) for pg in torch_pg), self.sizes)
+        else:
+            self._world_size = dist.get_world_size(group=torch_pg)
+            self._local_rank = dist.get_rank(group=torch_pg)
+            self.sizes = [self._world_size]
         self.torch_pg = torch_pg

     @property
--- a/colossalai/zero/low_level/bookkeeping/tensor_bucket.py
+++ b/colossalai/zero/low_level/bookkeeping/tensor_bucket.py
@@ -1,10 +1,12 @@
 from typing import Optional

+import numpy as np
 import torch
 import torch.distributed as dist
 from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

 from colossalai.quantization.fp8 import all_gather_fp8
+from colossalai.zero.low_level._utils import all_gather_into_flat_tensor_nd


 class TensorBucket:
@@ -65,12 +67,18 @@ class TensorBucket:

     def all_gather(self, group=None, fp8_communication: bool = False):
         flat = self.flatten()
-        buffer = torch.empty(flat.numel() * dist.get_world_size(group), device=flat.device, dtype=flat.dtype)
+        if isinstance(group, tuple):
+            world_size = np.prod([dist.get_world_size(pg) for pg in group])
+        else:
+            world_size = dist.get_world_size(group)
+        buffer = torch.empty(flat.numel() * world_size, device=flat.device, dtype=flat.dtype)
         if fp8_communication:
+            # TODO: fit fp8
             all_gather_fp8(list(buffer.chunk(dist.get_world_size(group))), flat, group=group, fp8_format="e4m3")
         else:
-            dist.all_gather_into_tensor(buffer, flat, group=group)
-        unflat_buffers = [self.unflatten(buffer) for buffer in buffer.chunk(dist.get_world_size(group))]
+            # dist.all_gather_into_tensor(buffer, flat, group=group)
+            all_gather_into_flat_tensor_nd(buffer, flat, group=group)
+        unflat_buffers = [self.unflatten(buffer) for buffer in buffer.chunk(world_size)]
         # transpose the list of list
         unflat_buffers = list(map(list, zip(*unflat_buffers)))
         for unflat_shards, tensor in zip(unflat_buffers, self._bucket):
--- a/colossalai/zero/low_level/low_level_optim.py
+++ b/colossalai/zero/low_level/low_level_optim.py
@@ -2,7 +2,7 @@
 import copy
 from contextlib import contextmanager, nullcontext
 from functools import partial
-from typing import Dict, Iterator, List, Optional, Tuple
+from typing import Dict, Iterator, List, Optional, Tuple, Union
 from weakref import proxy

 import torch
@@ -23,7 +23,15 @@ from colossalai.logging import get_dist_logger
 from colossalai.quantization.fp8 import all_gather_fp8, all_reduce_fp8, reduce_scatter_fp8
 from colossalai.tensor.moe_tensor.api import is_moe_tensor

-from ._utils import calculate_global_norm_from_list, has_inf_or_nan, release_param_grad, sync_tensor
+from ._utils import (
+    all_gather_into_flat_tensor_nd,
+    calculate_global_norm_from_list,
+    get_nd_rank,
+    get_nd_world_size,
+    has_inf_or_nan,
+    release_param_grad,
+    sync_tensor,
+)
 from .bookkeeping import BucketStore, GradientStore, TensorBucket
 from .zero_hook import set_all_gather_handle, wait_all_gather_handle

@@ -68,7 +76,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
     def __init__(
         self,
         optimizer: Optimizer,
-        pg_to_param_list: Optional[Dict[ProcessGroup, List[nn.Parameter]]] = None,
+        pg_to_param_list: Optional[Dict[Union[ProcessGroup, Tuple[ProcessGroup, ...]], List[nn.Parameter]]] = None,
         initial_scale: int = 2**16,  # grad scaler config
         min_scale: int = 1,
         growth_factor: float = 2.0,
@@ -84,6 +92,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         partition_grad: bool = False,  # stage 2 flag
         cpu_offload: bool = False,  # cpu offload
         dp_process_group: Optional[ProcessGroup] = None,
+        extra_dp_group: Optional[ProcessGroup] = None,
         forced_dtype: Optional[torch.dtype] = None,
         master_weights: bool = True,  # master weights
         overlap_allgather: bool = False,
@@ -98,9 +107,17 @@ class LowLevelZeroOptimizer(OptimizerWrapper):

         if (dp_process_group is not None) and (pg_to_param_list is not None):
             raise ValueError("dp_process_group and pg_to_param_list should not be provided at the same time.")
+        if pg_to_param_list is None and extra_dp_group is not None and dp_process_group is None:
+            raise ValueError("dp_process_group should be provided when extra_dp_group is provided.")
+        if pg_to_param_list is None and extra_dp_group is not None and fp8_communication:
+            raise ValueError(
+                "fp8_communication is not supported when pg_to_param_list is None and extra_dp_group is provided."
+            )

         if pg_to_param_list is None:
             unique_dp_group = dist.group.WORLD if dp_process_group is None else dp_process_group
+            if extra_dp_group is not None:
+                unique_dp_group = (extra_dp_group, unique_dp_group)
             pg_to_param_list = {unique_dp_group: []}
             for group in self.optim.param_groups:
                 pg_to_param_list[unique_dp_group].extend(group["params"])
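When extra_dp_group is passed next to dp_process_group, the optimizer keys its bookkeeping on the tuple (extra_dp_group, dp_process_group), which is what the N-D-aware stores and collectives below consume. A minimal construction sketch mirroring the updated tests (it assumes the distributed backend is already initialized, e.g. via colossalai.launch, and uses a toy model):

import torch
import torch.distributed as dist

from colossalai.cluster import ProcessGroupMesh
from colossalai.zero import LowLevelZeroOptimizer

extra_dp_size = 2
pg_mesh = ProcessGroupMesh(extra_dp_size, dist.get_world_size() // extra_dp_size)
extra_dp_group = pg_mesh.get_group_along_axis(0)  # axis 0: extra dp
dp_group = pg_mesh.get_group_along_axis(1)        # axis 1: inner (ZeRO) dp

model = torch.nn.Linear(32, 32).cuda()
optimizer = LowLevelZeroOptimizer(
    torch.optim.Adam(model.parameters(), lr=1e-3),
    overlap_communication=True,
    initial_scale=1,
    dp_process_group=dp_group,
    extra_dp_group=extra_dp_group,
)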
@@ -336,10 +353,12 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
            flat_grads = flat_grads.to(self._communication_dtype)

        if not self._partition_grads:
-            if self._fp8_communication:
-                all_reduce_fp8(flat_grads, group=bucket_store.torch_pg)
-            else:
-                dist.all_reduce(flat_grads, group=bucket_store.torch_pg)
+            for i, sz in enumerate(bucket_store.sizes):
+                grp = bucket_store.torch_pg if len(bucket_store.sizes) == 1 else bucket_store.torch_pg[i]
+                if self._fp8_communication:
+                    all_reduce_fp8(flat_grads, group=grp)
+                else:
+                    dist.all_reduce(flat_grads, group=grp)
            if flat_grads.dtype != grad_dtype:
                flat_grads = flat_grads.to(grad_dtype)

@@ -347,16 +366,20 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
            grad_in_bucket = bucket_store.get_grad()
            self._update_unpartitoned_grad(bucket_store, grad_in_bucket.values(), flat_grads_per_rank, group_id)
        else:
-            flat_grads_list = list(flat_grads.split(len(flat_grads) // bucket_store.world_size))
-            received_grad = torch.zeros_like(flat_grads_list[0])
-            if self._fp8_communication:
-                reduce_scatter_fp8(
-                    received_grad,
-                    flat_grads_list,
-                    group=bucket_store.torch_pg,
-                )
-            else:
-                dist.reduce_scatter(received_grad, flat_grads_list, group=bucket_store.torch_pg)
+            cur_flat_grads = flat_grads
+            for i, sz in enumerate(bucket_store.sizes):
+                grp = bucket_store.torch_pg if len(bucket_store.sizes) == 1 else bucket_store.torch_pg[i]
+                flat_grads_list = list(cur_flat_grads.split(len(cur_flat_grads) // sz))
+                received_grad = torch.zeros_like(flat_grads_list[0])
+                if self._fp8_communication:
+                    reduce_scatter_fp8(
+                        received_grad,
+                        flat_grads_list,
+                        group=grp,
+                    )
+                else:
+                    dist.reduce_scatter_tensor(received_grad, cur_flat_grads, group=grp)
+                cur_flat_grads = received_grad

            if received_grad.dtype != grad_dtype:
                received_grad = received_grad.to(grad_dtype)
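In the stage-2 branch the flat gradient is now reduce-scattered once per axis of the group tuple, each pass keeping 1/size of the previous result, so the final shard is the same size a single reduce-scatter over the full mesh would produce. A plain-Python sanity check of the shard sizes (made-up numbers, illustration only):

numel = 1024      # flattened gradient size in the bucket (made-up)
sizes = [2, 2]    # bucket_store.sizes for an (extra_dp_group, dp_process_group) tuple

shard = numel
for sz in sizes:
    shard //= sz  # each reduce_scatter_tensor keeps 1/sz of the previous result
assert shard == numel // (sizes[0] * sizes[1])  # 256, same as one pass over all 4 ranks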
@@ -577,11 +600,13 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
                pg = self.param_to_pg[working_param]
                padded_working_param = self._working_param_to_padded_working_param[working_param]
                if self._overlap_allgather:
-                    handle = dist.all_gather_into_tensor(padded_working_param, param_to_gather, pg, async_op=True)
+                    # handle = dist.all_gather_into_tensor(padded_working_param, param_to_gather, pg, async_op=True)
+                    handle = all_gather_into_flat_tensor_nd(padded_working_param, param_to_gather, pg, async_op=True)
                    set_all_gather_handle(working_param, handle)
                else:
                    if param_to_gather.numel() > self.pg_to_tensor_bucket[pg].max_size:
                        if self._fp8_communication:
+                            # TODO: fit fp8 communication
                            all_gather_fp8(
                                list(padded_working_param.chunk(dist.get_world_size(pg))),
                                param_to_gather,
@@ -589,7 +614,8 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
                                fp8_format="e4m3",
                            )
                        else:
-                            dist.all_gather_into_tensor(padded_working_param, param_to_gather, pg)
+                            # dist.all_gather_into_tensor(padded_working_param, param_to_gather, pg)
+                            all_gather_into_flat_tensor_nd(padded_working_param, param_to_gather, pg)
                        continue
                    try:
                        self.pg_to_tensor_bucket[pg].add_to_bucket(param_to_gather, write_back_tensor=working_param)
@@ -602,7 +628,9 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
            if not tensor_bucket.is_empty():
                tensor_bucket.all_gather(pg, fp8_communication=self._fp8_communication)

-    def _compute_grad_norm(self, dp_pg: ProcessGroup, gradients: List[Tensor], norm_type: int = 2) -> float:
+    def _compute_grad_norm(
+        self, dp_pg: Union[ProcessGroup, Tuple[ProcessGroup, ...]], gradients: List[Tensor], norm_type: int = 2
+    ) -> float:
        r"""
        Compute and return the gradient norm for gradient clipping.

@@ -625,7 +653,11 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
                device=get_accelerator().get_current_device(),
                dtype=torch.float,
            )
-            dist.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=dp_pg)
+            if isinstance(dp_pg, tuple):
+                for grp in dp_pg:
+                    dist.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=grp)
+            else:
+                dist.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=dp_pg)
            total_norm = total_norm_cuda.item()

        else:
@@ -640,11 +672,19 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
                device=get_accelerator().get_current_device(),
                dtype=torch.float,
            )
-            torch.distributed.all_reduce(
-                total_norm_exponentiated_cuda,
-                op=torch.distributed.ReduceOp.SUM,
-                group=dp_pg,
-            )
+            if isinstance(dp_pg, tuple):
+                for grp in dp_pg:
+                    dist.all_reduce(
+                        total_norm_exponentiated_cuda,
+                        op=torch.distributed.ReduceOp.SUM,
+                        group=grp,
+                    )
+            else:
+                torch.distributed.all_reduce(
+                    total_norm_exponentiated_cuda,
+                    op=torch.distributed.ReduceOp.SUM,
+                    group=dp_pg,
+                )
            total_norm = total_norm_exponentiated_cuda.item() ** (1.0 / norm_type)

        return total_norm
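Looping the MAX (or SUM) all-reduce over each group in the tuple is equivalent to a single all-reduce over the whole mesh, because both reductions are associative across axes. A tiny numpy check of that equivalence (illustration only):

import numpy as np

local_norm = np.array([[1.0, 3.0],
                       [2.0, 0.5]])  # per-rank values on a 2 x 2 mesh (made-up)

# reducing axis by axis gives the same result as one global reduction
assert local_norm.max(axis=0).max() == local_norm.max()
assert local_norm.sum(axis=0).sum() == local_norm.sum()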
@@ -744,11 +784,9 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
                if isinstance(v, torch.Tensor) and k != "step":
                    working_param = self.master_to_working_param[id(param)]
                    pg = self.param_to_pg[working_param]
-                    gather_tensor = [torch.zeros(v.shape, device=device, dtype=v.dtype) for _ in range(pg.size())]
-                    dist.all_gather(gather_tensor, v.to(device), group=pg)
-                    param_state = (
-                        torch.stack(gather_tensor).view(-1)[: working_param.numel()].reshape_as(working_param).cpu()
-                    )
+                    gathered_tensor = torch.empty(v.numel() * get_nd_world_size(pg), device=device, dtype=v.dtype)
+                    all_gather_into_flat_tensor_nd(gathered_tensor, v.to(device).flatten(), pg)
+                    param_state = gathered_tensor[: working_param.numel()].reshape_as(working_param).cpu()
                    zero_state[param][k] = param_state

        states_dict = self._pack_state(zero_state)
@@ -770,15 +808,17 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
            cnt += 1
        for param_idx, state in zero_state_dict["state"].items():
            pg = self.param_to_pg[self.master_to_working_param[id(idx2master[param_idx])]]
+            world_size = get_nd_world_size(pg)
+            rank = get_nd_rank(pg)
            for k, v in state.items():
                if isinstance(v, torch.Tensor) and k != "step":
-                    padding_size = (pg.size() - v.numel() % pg.size()) % pg.size()
+                    padding_size = (world_size - v.numel() % world_size) % world_size
                    with torch.no_grad():
                        v = v.flatten()
                        if padding_size > 0:
                            v = torch.nn.functional.pad(v, [0, padding_size])
-                        v_list = v.split(v.numel() // pg.size())
-                        zero_state_dict["state"][param_idx][k] = v_list[pg.rank()].detach().clone()
+                        v_list = v.split(v.numel() // world_size)
+                        zero_state_dict["state"][param_idx][k] = v_list[rank].detach().clone()

        self.optim.load_state_dict(zero_state_dict)

@@ -814,11 +854,9 @@ class LowLevelZeroOptimizer(OptimizerWrapper):

                for k, v in states.items():
                    if isinstance(v, torch.Tensor) and k != "step":
-                        state_tensor = [torch.zeros(v.shape, device=device, dtype=v.dtype) for _ in range(pg.size())]
-                        dist.all_gather(state_tensor, v.to(device), group=pg)
-                        state_tensor = (
-                            torch.stack(state_tensor).view(-1)[: working_param.numel()].reshape_as(working_param).cpu()
-                        )
+                        state_tensor = torch.empty(v.numel() * get_nd_world_size(pg), device=device, dtype=v.dtype)
+                        all_gather_into_flat_tensor_nd(state_tensor, v.to(device).flatten(), pg)
+                        state_tensor = state_tensor[: working_param.numel()].reshape_as(working_param).cpu()
                        current_block_size += state_tensor.numel()
                        current_block[k] = state_tensor

@@ -842,12 +880,14 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
            p_id = id(p)
            if p_id in self.working_to_master_param:
                pg = self.param_to_pg[p]
+                world_size = get_nd_world_size(pg)
+                rank = get_nd_rank(pg)
                master_param = self.working_to_master_param[p_id]
                padding_size = self.get_param_padding_size(p)
                working_param = p.data.view(-1)
                if padding_size > 0:
                    working_param = torch.nn.functional.pad(working_param, [0, padding_size])
-                master_param.copy_(working_param.chunk(pg.size())[pg.rank()])
+                master_param.copy_(working_param.chunk(world_size)[rank])

    def get_working_to_master_map(self) -> Dict[int, torch.Tensor]:
        return self.working_to_master_param
@@ -905,9 +945,12 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
        grad = grad_store.get_working_grad_by_param_id(id(working_param))
        if grad is None:
            return None
-        grad_flat = torch.empty((grad_store.world_size, *grad.shape), dtype=grad.dtype, device=grad.device)
-        dist.all_gather_into_tensor(grad_flat, grad, group=grad_store.torch_pg)
-        return grad_flat.view(-1)[: working_param.numel()].view_as(working_param)
+        grad_flat = grad.flatten()
+        output_grad = torch.empty(
+            grad_flat.numel() * grad_store.world_size, device=grad_flat.device, dtype=grad_flat.dtype
+        )
+        all_gather_into_flat_tensor_nd(output_grad, grad_flat, grad_store.torch_pg)
+        return output_grad.view(-1)[: working_param.numel()].view_as(working_param)

    def get_working_grads_by_group_id(self, group_id: int) -> List[Tensor]:
        working_grads = []
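The checkpoint paths now shard and regather optimizer states by the flattened N-D rank via get_nd_world_size/get_nd_rank. A single-process sketch of the pad-split-slice step the load path performs (world_size and rank stand in for get_nd_world_size(pg) and get_nd_rank(pg); the numbers are made up):

import torch

world_size, rank = 4, 2                    # stand-ins for get_nd_world_size(pg) / get_nd_rank(pg)
v = torch.arange(10, dtype=torch.float32)  # a full optimizer state tensor from a torch checkpoint

padding_size = (world_size - v.numel() % world_size) % world_size
v = torch.nn.functional.pad(v.flatten(), [0, padding_size])
local_shard = v.split(v.numel() // world_size)[rank].detach().clone()
print(local_shard)  # tensor([6., 7., 8.]) -- the slice this rank keeps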
--- /dev/null
+++ b/tests/test_zero/test_low_level/test_coll_nd.py
@@ -0,0 +1,42 @@
+import numpy as np
+import pytest
+import torch
+import torch.distributed as dist
+
+import colossalai
+from colossalai.cluster import ProcessGroupMesh
+from colossalai.testing import rerun_if_address_is_in_use, spawn
+from colossalai.testing.random import seed_all
+from colossalai.utils import get_current_device
+from colossalai.zero.low_level._utils import all_gather_into_flat_tensor_nd
+
+
+def check_all_gather_2d():
+    seed_all(1024)
+    tensor = torch.rand(128, device=get_current_device())
+    extra_dp_size, inner_dp_size = 2, 2
+    pg_mesh = ProcessGroupMesh(extra_dp_size, inner_dp_size)
+    extra_dp_group = pg_mesh.get_group_along_axis(0)
+    inner_dp_group = pg_mesh.get_group_along_axis(1)
+    ranks = [dist.get_rank(extra_dp_group), dist.get_rank(inner_dp_group)]
+    sizes = [dist.get_world_size(extra_dp_group), dist.get_world_size(inner_dp_group)]
+    chunk = tensor.chunk(dist.get_world_size())[np.ravel_multi_index(ranks, sizes)].clone()
+    out = torch.zeros_like(tensor)
+    all_gather_into_flat_tensor_nd(out, chunk, group=(extra_dp_group, inner_dp_group))
+    assert torch.equal(out, tensor)
+
+
+def run_dist(rank, world_size, port):
+    colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost")
+
+    check_all_gather_2d()
+
+
+@pytest.mark.dist
+@rerun_if_address_is_in_use()
+def test_comm_nd():
+    spawn(run_dist, 4)
+
+
+if __name__ == "__main__":
+    test_comm_nd()
--- a/tests/test_zero/test_low_level/test_zero1_2.py
+++ b/tests/test_zero/test_low_level/test_zero1_2.py
@@ -2,11 +2,13 @@ import copy

 import pytest
 import torch
+import torch.distributed as dist
 import torch.nn as nn
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.testing import assert_close

 import colossalai
+from colossalai.cluster import ProcessGroupMesh
 from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
 from colossalai.testing.random import seed_all
 from colossalai.zero import LowLevelZeroOptimizer
@@ -123,7 +125,8 @@ def exam_zero_1_2(fp8_communication: bool):

 @parameterize("dtype", [torch.float16, torch.bfloat16])
 @parameterize("master_weights", [True, False])
-def exam_zero_1_torch_ddp(world_size, dtype: torch.dtype, master_weights: bool):
+@parameterize("extra_dp_size", [1, 2])
+def exam_zero_1_torch_ddp(dtype: torch.dtype, master_weights: bool, extra_dp_size: int):
     """
     In this test, two pairs of model and optimizers are created.
     1. zero: use sharded optimizer and fp16 parameters
@@ -132,6 +135,15 @@ def exam_zero_1_torch_ddp(world_size, dtype: torch.dtype, master_weights: bool):
     We feed these two sets of models with the same input and check if the
     differences in model output and updated parameters are within tolerance.
     """
+    if extra_dp_size > 1 and dtype != torch.bfloat16:
+        return
+    if extra_dp_size > 1:
+        pg_mesh = ProcessGroupMesh(extra_dp_size, dist.get_world_size() // extra_dp_size)
+        extra_dp_group = pg_mesh.get_group_along_axis(0)
+        dp_group = pg_mesh.get_group_along_axis(1)
+    else:
+        extra_dp_group = None
+        dp_group = None
     local_rank = torch.distributed.get_rank()
     seed_all(1453)

@@ -153,6 +165,8 @@ def exam_zero_1_torch_ddp(world_size, dtype: torch.dtype, master_weights: bool):
         initial_scale=1,
         reduce_bucket_size=1024 * 1024,
         master_weights=master_weights,
+        dp_process_group=dp_group,
+        extra_dp_group=extra_dp_group,
     )

     torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1)
@@ -200,14 +214,14 @@ def exam_zero_1_torch_ddp(world_size, dtype: torch.dtype, master_weights: bool):
 def run_dist(rank, world_size, port):
     colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost")

-    exam_zero_1_torch_ddp(world_size=world_size)
+    exam_zero_1_torch_ddp()
     exam_zero_1_2()


 @pytest.mark.dist
 @rerun_if_address_is_in_use()
 def test_zero_1_2():
-    spawn(run_dist, 2)
+    spawn(run_dist, 4)


 if __name__ == "__main__":
--- a/tests/test_zero/test_low_level/test_zero_ckpt.py
+++ b/tests/test_zero/test_low_level/test_zero_ckpt.py
@@ -2,12 +2,14 @@ import copy

 import pytest
 import torch
+import torch.distributed as dist
 import torch.nn as nn
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.testing import assert_close

 import colossalai
-from colossalai.testing import rerun_if_address_is_in_use, spawn
+from colossalai.cluster import ProcessGroupMesh
+from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
 from colossalai.testing.random import seed_all
 from colossalai.zero import LowLevelZeroOptimizer

@@ -40,11 +42,19 @@ def loose_close(a, b, dtype: torch.dtype = torch.float32):
     assert_close(a, b, rtol=rtol, atol=atol)


-def exam_zero_1_torch_ddp_ckpt():
+@parameterize("extra_dp_size", [1, 2])
+def exam_zero_1_torch_ddp_ckpt(extra_dp_size: int):
     """
     We examine the state_dict of zero and DDP.
     Moreover, we examine the zero's loading checkpoint of a torch ckpt.
     """
+    if extra_dp_size > 1:
+        pg_mesh = ProcessGroupMesh(extra_dp_size, dist.get_world_size() // extra_dp_size)
+        extra_dp_group = pg_mesh.get_group_along_axis(0)
+        dp_group = pg_mesh.get_group_along_axis(1)
+    else:
+        dp_group = None
+        extra_dp_group = None
     local_rank = torch.distributed.get_rank()
     seed_all(1453)

@@ -60,7 +70,12 @@ def exam_zero_1_torch_ddp_ckpt():
     # we only test stage 1 here
     # the state dicts of stage 1 and stage 2 are the same
     zero_optimizer = LowLevelZeroOptimizer(
-        zero_optimizer, overlap_communication=True, initial_scale=1, reduce_bucket_size=262144
+        zero_optimizer,
+        overlap_communication=True,
+        initial_scale=1,
+        reduce_bucket_size=262144,
+        dp_process_group=dp_group,
+        extra_dp_group=extra_dp_group,
     )

     torch_optimizer = torch.optim.Adam(torch_model.parameters(), lr=1)
@@ -111,7 +126,7 @@ def run_dist(rank, world_size, port):
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
 def test_zero_ckpt():
-    spawn(run_dist, 2)
+    spawn(run_dist, 4)


 if __name__ == "__main__":