feat: add sub_dp_group

Author: Wenhao Chen, 2024-04-01 14:51:06 +08:00 (committed by アマデウス)
parent 1aaa453706
commit 9291f07964

@@ -5,6 +5,7 @@ from contextlib import contextmanager
 from functools import partial
 from typing import Dict, Iterator, List, Optional, Tuple
 
+import numpy as np
 import torch
 import torch.distributed as dist
 import torch.nn as nn
@@ -80,6 +81,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         partition_grad: bool = False,  # stage 2 flag
         cpu_offload: bool = False,  # cpu offload
         dp_process_group: Optional[ProcessGroup] = None,  # the dp pg for comm
+        sub_dp_size: int = 1,  # further divide zero into sub-dp groups and zero groups
         forced_dtype: Optional[torch.dtype] = None,
         master_weights: bool = True,  # master weights
     ):
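A minimal construction sketch for the new argument, assuming torch.distributed is already initialized (e.g. launched under torchrun) and that the model, the wrapped torch optimizer, and the import path are placeholders; only sub_dp_size and the keyword names visible in this hunk come from the diff:

    import torch
    import torch.distributed as dist
    from colossalai.zero import LowLevelZeroOptimizer  # import path assumed from upstream ColossalAI

    dist.init_process_group(backend="nccl")        # run under torchrun
    model = torch.nn.Linear(32, 32).cuda()         # placeholder model
    optim = torch.optim.AdamW(model.parameters(), lr=1e-4)

    # with 8 ranks and sub_dp_size=2, gradients are sharded within each group of 4
    # ranks (ZeRO) and additionally averaged across the 2 sub-dp groups
    zero_optim = LowLevelZeroOptimizer(
        optim,
        partition_grad=False,   # stage 1
        dp_process_group=None,  # falls back to dist.group.WORLD, per the hunk below
        sub_dp_size=2,          # world size must be divisible by sub_dp_size
    )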
@@ -102,10 +104,37 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         self.require_grad_sync = True
 
         # if process_group is none, will use the default one
-        self.dp_pg = dp_process_group
+        if dp_process_group is None:
+            dp_process_group = dist.group.WORLD
+        assert dist.get_world_size(group=dp_process_group) % sub_dp_size == 0
+        dp_ranks = dist.get_process_group_ranks(group=dp_process_group)
+        dp_ranks = np.array(dp_ranks).reshape(sub_dp_size, -1)
+        sub_dp_rank = dist.get_rank(group=dp_process_group) % dp_ranks.shape[1]
+        zero_rank = dist.get_rank(group=dp_process_group) // dp_ranks.shape[1]
+        if sub_dp_size == 1:
+            self.dp_pg = dp_process_group
+        else:
+            self.dp_pg = None
+            for i in range(dp_ranks.shape[0]):
+                group = dist.new_group(dp_ranks[i])
+                if i == zero_rank:
+                    assert self.dp_pg is None
+                    self.dp_pg = group
+
         self._local_rank = dist.get_rank(group=self.dp_pg)
         self._world_size = dist.get_world_size(group=self.dp_pg)
 
+        self.sub_dp_pg = None
+        if sub_dp_size > 1:
+            for i in range(dp_ranks.shape[1]):
+                group = dist.new_group(dp_ranks[:, i])
+                if i == sub_dp_rank:
+                    assert self.sub_dp_pg is None
+                    self.sub_dp_pg = group
+        if self.sub_dp_pg is not None:
+            self._sub_dp_rank = dist.get_rank(group=self.sub_dp_pg)
+            self._sub_dp_world_size = dist.get_world_size(group=self.sub_dp_pg)
+
         # working and master params for mixed precision training
         self._working_param_groups = dict()
         self._master_param_groups_of_current_rank = dict()
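To make the rank bookkeeping above concrete: the ranks of dp_process_group are laid out as a (sub_dp_size, world_size // sub_dp_size) grid, each row becomes one ZeRO group (self.dp_pg) and each column one sub-dp group (self.sub_dp_pg). A standalone sketch with 8 ranks and sub_dp_size=2 (sizes chosen only for illustration):

    import numpy as np

    world_size, sub_dp_size = 8, 2
    dp_ranks = np.arange(world_size).reshape(sub_dp_size, -1)
    # dp_ranks == [[0, 1, 2, 3],
    #              [4, 5, 6, 7]]
    # rows    -> ZeRO groups   (self.dp_pg):     {0,1,2,3} and {4,5,6,7}
    # columns -> sub-dp groups (self.sub_dp_pg): {0,4}, {1,5}, {2,6}, {3,7}
    for rank in range(world_size):
        zero_rank = rank // dp_ranks.shape[1]   # row index: which ZeRO group this rank joins
        sub_dp_rank = rank % dp_ranks.shape[1]  # column index: which sub-dp group this rank joins
        print(rank, dp_ranks[zero_rank].tolist(), dp_ranks[:, sub_dp_rank].tolist())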
@@ -285,6 +314,9 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         if not self._partition_grads:
             dist.all_reduce(flat_grads, group=self.dp_pg)
+            if self.sub_dp_pg is not None:
+                dist.all_reduce(flat_grads, op=dist.ReduceOp.AVG, group=self.sub_dp_pg)
 
             if flat_grads.dtype != grad_dtype:
                 flat_grads = flat_grads.to(grad_dtype)
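In this stage-1 path, gradients are first summed over the ZeRO group and then averaged across the sub-dp group. Assuming, as elsewhere in this optimizer but not shown in the diff, that the flattened gradients are pre-divided by the ZeRO-group world size, the two-step reduction reproduces the mean over the full data-parallel group. A small single-process check of that arithmetic:

    import torch

    sub_dp_size, zero_size = 2, 4
    # one scalar gradient per rank, laid out as (sub_dp_size, zero_size)
    grads = torch.arange(8, dtype=torch.float32).reshape(sub_dp_size, zero_size)

    # all_reduce(SUM) over dp_pg of grads pre-divided by zero_size == per-row mean
    after_zero = grads.mean(dim=1)              # one value per ZeRO group
    # all_reduce(AVG) over sub_dp_pg then averages the per-row results
    after_sub_dp = after_zero.mean()
    assert torch.allclose(after_sub_dp, grads.mean())   # equals the full-DP mean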
@@ -296,6 +328,8 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
             flat_grads_list = list(flat_grads.split(len(flat_grads) // self._world_size))
             recieved_grad = torch.zeros_like(flat_grads_list[0])
             dist.reduce_scatter(recieved_grad, flat_grads_list, group=self.dp_pg)
+            if self.sub_dp_pg is not None:
+                dist.all_reduce(recieved_grad, op=dist.ReduceOp.AVG, group=self.sub_dp_pg)
 
             if recieved_grad.dtype != grad_dtype:
                 recieved_grad = recieved_grad.to(grad_dtype)
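In this stage-2 path, dist.reduce_scatter leaves each ZeRO rank with only the reduced shard it owns, and the new all_reduce(AVG) then averages that shard across the sub-dp replicas that own the same shard index. A single-process emulation of the two steps, with illustrative sizes and random data:

    import torch

    sub_dp_size, zero_size, numel = 2, 4, 8
    # flat gradients held by every rank, indexed [sub_dp_group][zero_rank]
    flat = [[torch.randn(numel) for _ in range(zero_size)] for _ in range(sub_dp_size)]

    def reduce_scatter_emulated(group_grads, zero_rank):
        # each ZeRO rank keeps the sum of its own shard, like dist.reduce_scatter
        return sum(g.split(numel // zero_size)[zero_rank] for g in group_grads)

    zero_rank = 1                                 # inspect one shard owner
    shards = [reduce_scatter_emulated(flat[s], zero_rank) for s in range(sub_dp_size)]
    recieved_grad = sum(shards) / sub_dp_size     # the extra all_reduce(AVG) over sub_dp_pg
    # every replica of shard 1 now holds this same reduced value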
@@ -498,6 +532,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
                 # HACK: torch optim would skip tensor whose grad is None
                 self.optim.step()
                 real_master_params[group_id][idx].grad = None
+        torch.cuda.current_stream().synchronize()
 
         if not is_first_step:
             # update working partition updated by the current rank
@@ -516,6 +551,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
                     working_param.data.copy_(
                         flatten(all_splited_param)[: working_param.numel()].reshape_as(working_param)
                     )
+            torch.cuda.current_stream().synchronize()
 
             # release the grad
             release_param_grad(self._master_param_groups_of_current_rank[group_id])
@@ -544,7 +580,9 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
             total_norm_cuda = torch.tensor(
                 [float(total_norm)], device=get_accelerator().get_current_device(), dtype=torch.float
             )
-            dist.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=self.dp_pg)
+            dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.MAX, group=self.dp_pg)
+            if self.sub_dp_pg is not None:
+                dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.MAX, group=self.sub_dp_pg)
             total_norm = total_norm_cuda.item()
         else:
@@ -557,9 +595,9 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
             total_norm_exponentiated_cuda = torch.tensor(
                 [float(total_norm_exponentiated)], device=get_accelerator().get_current_device(), dtype=torch.float
             )
-            torch.distributed.all_reduce(
-                total_norm_exponentiated_cuda, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg
-            )
+            dist.all_reduce(total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.dp_pg)
+            if self.sub_dp_pg is not None:
+                dist.all_reduce(total_norm_exponentiated_cuda, op=dist.ReduceOp.AVG, group=self.sub_dp_pg)
             total_norm = total_norm_exponentiated_cuda.item() ** (1.0 / norm_type)
 
         return total_norm
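For the gradient norm, the two group levels are combined in the hunks above: the inf-norm takes MAX over the ZeRO group and then MAX over the sub-dp group, while the p-norm path sums |g|^p over the ZeRO group and then takes AVG over the sub-dp group. The AVG is value-preserving here because, after the gradient synchronization above, every sub-dp replica holds identical shards and therefore computes the same exponentiated sum. A sketch of that arithmetic under the identical-replica assumption (values are illustrative):

    import torch

    norm_type = 2.0
    # |g|^p partial sums computed by the 4 ranks of one ZeRO group
    local_pow_sums = torch.tensor([1.0, 4.0, 9.0, 16.0])
    zero_sum = local_pow_sums.sum()                    # all_reduce(SUM) over dp_pg

    # each sub-dp replica holds the same shards, so each computes the same zero_sum;
    # all_reduce(AVG) over sub_dp_pg therefore leaves the value unchanged
    sub_dp_size = 2
    averaged = torch.stack([zero_sum] * sub_dp_size).mean()
    assert torch.allclose(averaged, zero_sum)

    total_norm = averaged.item() ** (1.0 / norm_type)  # matches the return value above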