Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-06-24 14:33:20 +00:00)

feat: add sub_dp_group

commit 9291f07964
parent 1aaa453706
@@ -5,6 +5,7 @@ from contextlib import contextmanager
 from functools import partial
 from typing import Dict, Iterator, List, Optional, Tuple
 
+import numpy as np
 import torch
 import torch.distributed as dist
 import torch.nn as nn
@@ -80,6 +81,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         partition_grad: bool = False,  # stage 2 flag
         cpu_offload: bool = False,  # cpu offload
         dp_process_group: Optional[ProcessGroup] = None,  # the dp pg for comm
+        sub_dp_size: int = 1,  # further divide zero into sub-dp groups and zero groups
         forced_dtype: Optional[torch.dtype] = None,
         master_weights: bool = True,  # master weights
     ):
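
The new sub_dp_size argument splits the data-parallel group into sub_dp_size replicated sub-DP groups, each wrapping its own ZeRO group. Below is a hypothetical usage sketch, not taken from the diff: the toy model, the SGD optimizer, and the colossalai.zero import path are assumptions, and all other arguments keep their defaults.

# Hypothetical usage sketch of the new argument: with 8 data-parallel ranks and
# sub_dp_size=2, optimizer states are sharded inside each ZeRO group of 4 ranks
# and replicated across the 2 sub-DP groups.
import torch
import torch.nn as nn
from colossalai.zero import LowLevelZeroOptimizer  # import path assumed

model = nn.Linear(32, 32).cuda()
optim = torch.optim.SGD(model.parameters(), lr=1e-3)
zero_optim = LowLevelZeroOptimizer(
    optim,
    partition_grad=False,   # ZeRO stage 1
    dp_process_group=None,  # falls back to dist.group.WORLD (see the next hunk)
    sub_dp_size=2,          # new in this commit
)
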
@@ -102,10 +104,37 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         self.require_grad_sync = True
 
         # if process_group is none, will use the default one
-        self.dp_pg = dp_process_group
+        if dp_process_group is None:
+            dp_process_group = dist.group.WORLD
+        assert dist.get_world_size(group=dp_process_group) % sub_dp_size == 0
+        dp_ranks = dist.get_process_group_ranks(group=dp_process_group)
+        dp_ranks = np.array(dp_ranks).reshape(sub_dp_size, -1)
+        sub_dp_rank = dist.get_rank(group=dp_process_group) % dp_ranks.shape[1]
+        zero_rank = dist.get_rank(group=dp_process_group) // dp_ranks.shape[1]
+
+        if sub_dp_size == 1:
+            self.dp_pg = dp_process_group
+        else:
+            self.dp_pg = None
+            for i in range(dp_ranks.shape[0]):
+                group = dist.new_group(dp_ranks[i])
+                if i == zero_rank:
+                    assert self.dp_pg is None
+                    self.dp_pg = group
         self._local_rank = dist.get_rank(group=self.dp_pg)
         self._world_size = dist.get_world_size(group=self.dp_pg)
 
+        self.sub_dp_pg = None
+        if sub_dp_size > 1:
+            for i in range(dp_ranks.shape[1]):
+                group = dist.new_group(dp_ranks[:, i])
+                if i == sub_dp_rank:
+                    assert self.sub_dp_pg is None
+                    self.sub_dp_pg = group
+        if self.sub_dp_pg is not None:
+            self._sub_dp_rank = dist.get_rank(group=self.sub_dp_pg)
+            self._sub_dp_world_size = dist.get_world_size(group=self.sub_dp_pg)
+
         # working and master params for mixed precision training
         self._working_param_groups = dict()
         self._master_param_groups_of_current_rank = dict()
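
The constructor reshapes the ranks of dp_process_group into a (sub_dp_size, zero_group_size) grid: each row becomes one ZeRO group (kept as self.dp_pg) and each column one sub-DP group (self.sub_dp_pg). A minimal standalone sketch of that mapping, assuming the group spans ranks 0 to 7 and sub_dp_size=2:

# Illustration of the rank grid built above (ranks 0..7 assumed).
import numpy as np

dp_ranks = np.array(range(8)).reshape(2, -1)  # sub_dp_size=2 -> shape (2, 4)
# dp_ranks = [[0, 1, 2, 3],
#             [4, 5, 6, 7]]

rank = 5
zero_rank = rank // dp_ranks.shape[1]     # row index    -> 1
sub_dp_rank = rank % dp_ranks.shape[1]    # column index -> 1
print(dp_ranks[zero_rank])                # rank 5's ZeRO group:   [4 5 6 7]
print(dp_ranks[:, sub_dp_rank])           # rank 5's sub-DP group: [1 5]

Since dist.new_group is a collective call, every rank walks over all rows and all columns in the same order and only keeps the group it actually belongs to.
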
@@ -285,6 +314,9 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
 
             if not self._partition_grads:
                 dist.all_reduce(flat_grads, group=self.dp_pg)
+                if self.sub_dp_pg is not None:
+                    dist.all_reduce(flat_grads, op=dist.ReduceOp.AVG, group=self.sub_dp_pg)
+
                 if flat_grads.dtype != grad_dtype:
                     flat_grads = flat_grads.to(grad_dtype)
 
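
For ZeRO stage 1 the gradients are first all-reduced inside the ZeRO group and, with this change, averaged once more across the sub-DP group. The sketch below is a plain-tensor illustration (no process groups) of the underlying identity, assuming the ZeRO-group reduction ends up normalized by the group size as usual: the mean of the per-group means equals the mean over all ranks.

# Averaging inside each ZeRO group (rows) and then averaging across sub-DP
# replicas (the ReduceOp.AVG above) equals one global average.
import torch

per_rank_grads = torch.stack([torch.full((4,), float(r)) for r in range(8)])
rows = per_rank_grads.reshape(2, 4, -1)    # (sub_dp_size, zero_group_size, numel)

zero_group_avg = rows.mean(dim=1)          # normalized all_reduce inside self.dp_pg
global_avg = zero_group_avg.mean(dim=0)    # dist.ReduceOp.AVG across self.sub_dp_pg

assert torch.allclose(global_avg, per_rank_grads.mean(dim=0))
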
@@ -296,6 +328,8 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
                 flat_grads_list = list(flat_grads.split(len(flat_grads) // self._world_size))
                 recieved_grad = torch.zeros_like(flat_grads_list[0])
                 dist.reduce_scatter(recieved_grad, flat_grads_list, group=self.dp_pg)
+                if self.sub_dp_pg is not None:
+                    dist.all_reduce(recieved_grad, op=dist.ReduceOp.AVG, group=self.sub_dp_pg)
 
                 if recieved_grad.dtype != grad_dtype:
                     recieved_grad = recieved_grad.to(grad_dtype)
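
For ZeRO stage 2 each rank keeps only its own shard after the reduce-scatter, so the added collective averages that shard across the ranks holding its replicas, one per sub-DP group. A rough plain-tensor sketch of what the two collectives leave on a single rank:

# Stage-2 path for the shard owned by local rank 0 of each ZeRO group:
# reduce_scatter sums the matching shard inside the ZeRO group, then the new
# ReduceOp.AVG averages that shard across its sub-DP replicas.
import torch

zero_group_size, sub_dp_size, numel = 4, 2, 8
# grads[s][r]: flat gradient held by rank r of sub-DP replica s
grads = [[torch.randn(numel) for _ in range(zero_group_size)] for _ in range(sub_dp_size)]

def shard(flat, idx):
    return flat.split(numel // zero_group_size)[idx]

per_replica = [sum(shard(g, 0) for g in replica) for replica in grads]  # reduce_scatter in dp_pg
recieved_grad = torch.stack(per_replica).mean(dim=0)                    # AVG in sub_dp_pg
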
@@ -498,6 +532,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
                 # HACK: torch optim would skip tensor whose grad is None
                 self.optim.step()
                 real_master_params[group_id][idx].grad = None
+        torch.cuda.current_stream().synchronize()
 
         if not is_first_step:
             # update working partition updated by the current rank
@@ -516,6 +551,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
                 working_param.data.copy_(
                     flatten(all_splited_param)[: working_param.numel()].reshape_as(working_param)
                 )
+            torch.cuda.current_stream().synchronize()
 
             # release the grad
             release_param_grad(self._master_param_groups_of_current_rank[group_id])
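
The two torch.cuda.current_stream().synchronize() calls make the host wait for the asynchronous optimizer-step and parameter-copy work queued on the current CUDA stream before later code reuses those tensors. A minimal sketch of the pattern, with placeholder tensors:

# Host blocks on the current CUDA stream, as in the two hunks above.
import torch

if torch.cuda.is_available():
    src = torch.randn(1024, device="cuda")
    dst = torch.empty(1024, device="cuda")
    dst.copy_(src)                               # queued asynchronously on the current stream
    torch.cuda.current_stream().synchronize()    # host waits until the queued work finishes
    # safe to hand `dst` to later host-side logic or collectives here
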
@@ -544,7 +580,9 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
             total_norm_cuda = torch.tensor(
                 [float(total_norm)], device=get_accelerator().get_current_device(), dtype=torch.float
             )
-            dist.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=self.dp_pg)
+            dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.MAX, group=self.dp_pg)
+            if self.sub_dp_pg is not None:
+                dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.MAX, group=self.sub_dp_pg)
             total_norm = total_norm_cuda.item()
 
         else:
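
For the infinity norm the change mirrors the gradient path: a MAX inside the ZeRO group followed by a MAX across the sub-DP group, which together give the maximum over every data-parallel rank. A small sanity-check sketch with made-up per-rank values:

# MAX over rows (self.dp_pg) then MAX over the result (self.sub_dp_pg) equals one
# global MAX over all ranks; the values below are made up.
import torch

per_rank_max = torch.tensor([3.0, 7.0, 1.0, 5.0, 2.0, 9.0, 4.0, 6.0])
grid = per_rank_max.reshape(2, 4)            # (sub_dp_size, zero_group_size)
zero_group_max = grid.max(dim=1).values      # ReduceOp.MAX inside each ZeRO group
total_norm = zero_group_max.max()            # ReduceOp.MAX across sub-DP groups
assert total_norm == per_rank_max.max()
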
@@ -557,9 +595,9 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
             total_norm_exponentiated_cuda = torch.tensor(
                 [float(total_norm_exponentiated)], device=get_accelerator().get_current_device(), dtype=torch.float
             )
-            torch.distributed.all_reduce(
-                total_norm_exponentiated_cuda, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg
-            )
+            dist.all_reduce(total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.dp_pg)
+            if self.sub_dp_pg is not None:
+                dist.all_reduce(total_norm_exponentiated_cuda, op=dist.ReduceOp.AVG, group=self.sub_dp_pg)
             total_norm = total_norm_exponentiated_cuda.item() ** (1.0 / norm_type)
 
         return total_norm
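
For finite norms each rank contributes its shard's norm raised to the power p, the SUM over the ZeRO group yields the full norm to the power p, and since the earlier sub-DP averaging should leave every replica of a shard with identical gradients, the extra AVG across the sub-DP group is expected to leave that value unchanged. A sketch of the arithmetic under that assumption:

# Norm composition sketch (assumes replica shards are already identical after the
# sub-DP gradient averaging, so the AVG acts as a consistency no-op).
import torch

norm_type = 2.0
shards = [torch.randn(4) for _ in range(4)]                    # one shard per ZeRO rank
exp_sum = sum(s.norm(norm_type) ** norm_type for s in shards)  # ReduceOp.SUM in self.dp_pg
exp_sum = sum([exp_sum, exp_sum.clone()]) / 2                  # ReduceOp.AVG in self.sub_dp_pg
total_norm = exp_sum.item() ** (1.0 / norm_type)

assert abs(total_norm - torch.cat(shards).norm(norm_type).item()) < 1e-4
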