[moe] merge moe into main (#4978)

* update moe module
* support openmoe
This commit is contained in:
Xuanlei Zhao
2023-11-02 10:21:24 +08:00
committed by GitHub
parent 8993c8a817
commit dc003c304c
67 changed files with 7618 additions and 1657 deletions

View File

@@ -8,6 +8,7 @@ import torch
import torch.distributed as dist
import torch.nn as nn
from torch import Tensor, inf
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from torch.distributed import ProcessGroup
from torch.optim import Optimizer
@@ -18,6 +19,7 @@ from colossalai.amp.naive_amp.mixed_precision_mixin import (
)
from colossalai.interface import OptimizerWrapper
from colossalai.logging import get_dist_logger
from colossalai.tensor.moe_tensor.api import is_moe_tensor
# from colossalai.tensor import ColoParameter, ProcessGroup
from colossalai.utils.cuda import get_current_device
@@ -75,6 +77,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
cpu_offload: bool = False, # cpu offload
dp_process_group: Optional[ProcessGroup] = None, # the dp pg for comm
forced_dtype: Optional[torch.dtype] = None,
moe_extra_dp_process_group: Optional[ProcessGroup] = None,
master_weights: bool = True, # master weights
):
super(LowLevelZeroOptimizer, self).__init__(optim=optimizer)
@@ -95,6 +98,16 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
self._local_rank = dist.get_rank(group=self.dp_pg)
self._world_size = dist.get_world_size(group=self.dp_pg)
# extra dp
# This group is used to sync moe param, dp_world_size = moe_duplicates * extra_dp_size.
# Non moe param will be sync by global dp pg, moe param will be sync by extra dp pg.
# Moe param grad is be split as non moe param by global dp pg, and grad will be merged in step.
# And moe working and master param are split by extra dp pg.
self.moe_extra_dp_pg = moe_extra_dp_process_group
if self.moe_extra_dp_pg is not None:
self.moe_extra_dp_pg_size = dist.get_world_size(group=self.moe_extra_dp_pg)
self.moe_extra_dp_pg_rank = dist.get_rank(group=self.moe_extra_dp_pg)
# working and master params for mixed precision training
self._working_param_groups = dict()
self._master_param_groups_of_current_rank = dict()
@@ -126,6 +139,12 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
self._grad_store = GradientStore(self.dp_pg, partition_grad=partition_grad)
self._bucket_store = BucketStore(self.dp_pg)
# moe param should not be stored in working_groups
# because they have different parallel strategy
# so we need to store them separately in param_groups
# instead of working_groups
moe_params = list()
# iterate over the param group in the optimizer
# partition these param groups for data parallel training
# and add buffers to parameter store for future access
@@ -133,6 +152,11 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
group_params = list()
for param in param_group["params"]:
if param.requires_grad:
if self.moe_extra_dp_pg is None:
# skip moe param
if is_moe_tensor(param):
moe_params.append(param)
continue
group_params.append(param)
# add the working params to working_param_groups for bookkeeping
@@ -146,6 +170,15 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
# managed by this data parallel rank
param_group["params"] = master_param_current_rank
# if there are moe params, store in addtional group in optim
if len(moe_params) > 0:
param_group = dict()
for key, value in self.optim.param_groups[0].items():
if key != "params":
param_group[key] = value
param_group["params"] = moe_params
self.optim.param_groups.append(param_group)
# intialize communication stream for
# communication-compuation overlapping
if self._overlap_communication:
@@ -208,13 +241,20 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
param.data = padding_param[: param.numel()].view(param.shape)
else:
padding_param = param.data.view(-1)
splited_params = padding_param.split(padding_param.numel() // self._world_size)
if self.moe_extra_dp_pg is not None and is_moe_tensor(param):
splited_params = padding_param.split(padding_param.numel() // self.moe_extra_dp_pg_size)
splited_params = splited_params[self.moe_extra_dp_pg_rank]
else:
splited_params = padding_param.split(padding_param.numel() // self._world_size)
splited_params = splited_params[self._local_rank]
# use fp32 when master_weights is True
if self._master_weights is True:
splited_param_current_rank = splited_params[self._local_rank].detach().float().to(device)
splited_param_current_rank = splited_params.detach().float().to(device)
else:
splited_param_current_rank = splited_params[self._local_rank]
splited_param_current_rank = splited_params
params_current_rank.append(splited_param_current_rank)
self._param_store.link_master_and_working_param(splited_param_current_rank, param)
@@ -247,8 +287,43 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
if self._bucket_store.num_elements_in_bucket() > 0:
self._bucket_store.build_grad_in_bucket()
flat_grads = self._bucket_store.get_flatten_grad()
flat_grads /= self._world_size
if self.moe_extra_dp_pg is None:
flat_grads = self._bucket_store.get_flatten_grad()
flat_grads /= self._world_size
else:
# record moe and non moe param
moe_list = []
for param in self._bucket_store._param_list:
moe_list.append(is_moe_tensor(param))
# divide them into different groups
moe_grad_list = []
non_moe_grad_list = []
for grad_list in self._bucket_store._grad_in_bucket.values():
non_moe_cur_grad = []
moe_cur_grad = []
for i in range(len(grad_list)):
if moe_list[i] == True:
moe_cur_grad.append(grad_list[i])
else:
non_moe_cur_grad.append(grad_list[i])
if len(moe_cur_grad) > 0:
moe_grad_list.append(moe_cur_grad)
if len(non_moe_cur_grad) > 0:
non_moe_grad_list.append(non_moe_cur_grad)
if len(non_moe_grad_list) > 0:
non_moe_flat_grads = []
for grad_list in non_moe_grad_list:
non_moe_flat_grads.append(_flatten_dense_tensors(grad_list))
non_moe_flat_grads = _flatten_dense_tensors(non_moe_flat_grads)
non_moe_flat_grads /= self._world_size
if len(moe_grad_list) > 0:
moe_flat_grads = []
for grad_list in moe_grad_list:
moe_flat_grads.append(_flatten_dense_tensors(grad_list))
moe_flat_grads = _flatten_dense_tensors(moe_flat_grads)
# ready to add other tensors to bucket
self._bucket_store.reset_num_elements_in_bucket()
@@ -256,7 +331,13 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
if self._overlap_communication:
stream = self._comm_stream
# in case of the memory being reused in the default stream
flat_grads.record_stream(stream)
if self.moe_extra_dp_pg is None:
flat_grads.record_stream(stream)
else:
if len(non_moe_grad_list) > 0:
non_moe_flat_grads.record_stream(stream)
if len(moe_grad_list) > 0:
moe_flat_grads.record_stream(stream)
# waiting for ops in the default stream finishing
stream.wait_stream(torch.cuda.current_stream())
else:
@@ -265,49 +346,108 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
with torch.cuda.stream(stream):
group_id = self._bucket_store.current_group_id
grad_dtype = flat_grads.dtype
if self._communication_dtype is not None:
flat_grads = flat_grads.to(self._communication_dtype)
if self.moe_extra_dp_pg is None:
grad_dtype = flat_grads.dtype
if self._communication_dtype is not None:
flat_grads = flat_grads.to(self._communication_dtype)
if not self._partition_grads:
dist.all_reduce(flat_grads, group=self.dp_pg)
if flat_grads.dtype != grad_dtype:
flat_grads = flat_grads.to(grad_dtype)
if self.moe_extra_dp_pg is None:
dist.all_reduce(flat_grads, group=self.dp_pg)
if flat_grads.dtype != grad_dtype:
flat_grads = flat_grads.to(grad_dtype)
flat_grads_per_rank = flat_grads.split(flat_grads.numel() // self._world_size)
grad_in_bucket = self._bucket_store.get_grad()
flat_grads_per_rank = flat_grads.split(flat_grads.numel() // self._world_size)
grad_in_bucket = self._bucket_store.get_grad()
self._update_unpartitoned_grad(grad_in_bucket.values(), flat_grads_per_rank, group_id)
for rank, grad_list in grad_in_bucket.items():
sync_tensor(flat_grads_per_rank[rank], grad_list)
for grad in grad_list:
param_id = self._bucket_store.get_param_id_of_grad(grad)
if (
len(self._grad_store.get_partitioned_gradients_by_param_id(group_id, param_id))
< self._world_size
):
self._grad_store.append_gradients_by_param_id(grad, group_id, param_id)
else:
self._grad_store.add_gradients_by_param_id(grad, rank, group_id, param_id)
# sync extra zero group
else:
# sync non moe param in global dp group
if len(non_moe_grad_list) > 0:
dist.all_reduce(non_moe_flat_grads, group=self.dp_pg)
flat_grads_per_rank = non_moe_flat_grads.split(
non_moe_flat_grads.numel() // self._world_size
)
self._update_unpartitoned_grad(non_moe_grad_list, flat_grads_per_rank, group_id)
# sync moe param only in zero group
if len(moe_grad_list) > 0:
dist.all_reduce(moe_flat_grads, group=self.moe_extra_dp_pg)
flat_grads_per_rank = moe_flat_grads.split(moe_flat_grads.numel() // self._world_size)
self._update_unpartitoned_grad(moe_grad_list, flat_grads_per_rank, group_id)
else:
flat_grads_list = list(flat_grads.split(len(flat_grads) // self._world_size))
recieved_grad = torch.zeros_like(flat_grads_list[0])
dist.reduce_scatter(recieved_grad, flat_grads_list, group=self.dp_pg)
if self.moe_extra_dp_pg is None:
flat_grads_list = list(flat_grads.split(len(flat_grads) // self._world_size))
recieved_grad = torch.zeros_like(flat_grads_list[0])
dist.reduce_scatter(recieved_grad, flat_grads_list, group=self.dp_pg)
if recieved_grad.dtype != grad_dtype:
recieved_grad = recieved_grad.to(grad_dtype)
if recieved_grad.dtype != grad_dtype:
recieved_grad = recieved_grad.to(grad_dtype)
grad_in_bucket_current_rank = self._bucket_store.get_grad()[self._local_rank]
sync_tensor(recieved_grad, grad_in_bucket_current_rank)
for grad in grad_in_bucket_current_rank:
param_id = self._bucket_store.get_param_id_of_grad(grad)
if len(self._grad_store.get_partitioned_gradients_by_param_id(group_id, param_id)) < 1:
self._grad_store.append_gradients_by_param_id(grad, group_id, param_id)
else:
self._grad_store.add_gradients_by_param_id(grad, 0, group_id, param_id)
grad_in_bucket_current_rank = self._bucket_store.get_grad()[self._local_rank]
self._update_partitoned_grad(grad_in_bucket_current_rank, recieved_grad, group_id, 1)
else:
# categorize moe and non moe param
grad_in_bucket_current_rank = self._bucket_store.get_grad()[self._local_rank]
moe_grad_in_bucket_current_rank = []
non_moe_grad_in_bucket_current_rank = []
for idx, grad in enumerate(grad_in_bucket_current_rank):
if moe_list[idx] == True:
moe_grad_in_bucket_current_rank.append(grad)
else:
non_moe_grad_in_bucket_current_rank.append(grad)
if len(non_moe_grad_list) > 0:
flat_grads_list = list(
non_moe_flat_grads.split(len(non_moe_flat_grads) // self._world_size)
)
recieved_grad = torch.zeros_like(flat_grads_list[0])
dist.reduce_scatter(recieved_grad, flat_grads_list, group=self.dp_pg)
self._update_partitoned_grad(
non_moe_grad_in_bucket_current_rank, recieved_grad, group_id, 1
)
if len(moe_grad_list) > 0:
flat_grads_list = list(
moe_flat_grads.split(len(moe_flat_grads) // self.moe_extra_dp_pg_size)
)
recieved_grad = torch.zeros_like(flat_grads_list[0])
dist.reduce_scatter(recieved_grad, flat_grads_list, group=self.moe_extra_dp_pg)
param_slice = self._world_size // self.moe_extra_dp_pg_size
recieved_grad = list(recieved_grad.split(len(recieved_grad) // param_slice))
for split_recieved_grad in recieved_grad:
split_recieved_grad = _unflatten_dense_tensors(
split_recieved_grad, moe_grad_in_bucket_current_rank
)
for real_grad, grad in zip(split_recieved_grad, moe_grad_in_bucket_current_rank):
param_id = self._bucket_store.get_param_id_of_grad(grad)
self._add_grad(real_grad, param_slice, group_id, param_id)
self._bucket_store.reset()
def _update_unpartitoned_grad(self, origin_grad_list: List, flat_grad_list: List, group_id: int) -> None:
for rank, grad_list in enumerate(origin_grad_list):
sync_tensor(flat_grad_list[rank], grad_list)
for grad in grad_list:
param_id = self._bucket_store.get_param_id_of_grad(grad)
self._add_grad(grad, self._world_size, group_id, param_id, rank)
def _update_partitoned_grad(
self, origin_grad_list: List, flat_grad: torch.Tensor, group_id: int, partition_num: int
) -> None:
sync_tensor(flat_grad, origin_grad_list)
for grad in origin_grad_list:
param_id = self._bucket_store.get_param_id_of_grad(grad)
self._add_grad(grad, partition_num, group_id, param_id)
def _add_grad(self, grad: torch.Tensor, partition_num: int, group_id: int, param_id: int, rank: int = 0) -> None:
if len(self._grad_store.get_partitioned_gradients_by_param_id(group_id, param_id)) < partition_num:
self._grad_store.append_gradients_by_param_id(grad, group_id, param_id)
else:
self._grad_store.add_gradients_by_param_id(grad, rank, group_id, param_id)
def _add_to_bucket(self, param, group_id):
param_size = param.numel()
@@ -424,13 +564,23 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
# else the splited grad should be attached to the splited param
grads = self._grad_store.get_partitioned_gradients_by_param_id(group_id, id(working_param))
if len(grads) > 0:
real_working_params[group_id].append(working_param)
# moe hybrid zero
if self.moe_extra_dp_pg is not None and is_moe_tensor(working_param):
real_working_params[group_id].append(working_param)
if self._partition_grads:
grad = grads
else:
param_slice = self._world_size // self.moe_extra_dp_pg_size
grad = grads[
self.moe_extra_dp_pg_rank * param_slice : (self.moe_extra_dp_pg_rank + 1) * param_slice
]
grad = flatten(grad)
else:
real_working_params[group_id].append(working_param)
grad = grads[grad_index]
# no need to copy fp32 grad if master_weights is False
grad = (
grads[grad_index].to(splited_param.dtype).to(splited_param.device)
if self._master_weights
else grads[grad_index]
)
if self._master_weights:
grad = grad.to(splited_param.dtype).to(splited_param.device)
splited_param.grad = grad
grad_partition_groups.append(grad)
real_master_params[group_id].append(splited_param)
@@ -449,24 +599,43 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
global_norm = calculate_global_norm_from_list(norm_list=norm_groups)
self._unscale_and_clip_grads(grad_partition_groups, global_norm)
# TODO: we should store master param for ep
if len(self.param_groups) > len(self._working_param_groups):
for param in self.param_groups[-1]["params"]:
param.data = param.data.to(torch.float32)
param.grad = param.grad.to(torch.float32)
# update the parameters
self.optim.step()
# release the moe gradm
if len(self.param_groups) > len(self._working_param_groups):
for param in self.param_groups[-1]["params"]:
param.grad = None
param.data = param.data.to(self._dtype)
# release the grad
grad_partition_groups = []
for group_id in range(self.num_param_groups):
release_param_grad(self._master_param_groups_of_current_rank[group_id])
# update working partition updated by the current rank
# dtype = real_working_params[0][0].dtype
for group_id in range(self.num_param_groups):
master_working_param = self.optim.param_groups[group_id]["params"]
for idx, splited_param in enumerate(master_working_param):
working_param = real_working_params[group_id][idx]
all_splited_param = [
torch.zeros(splited_param.shape, device="cuda", dtype=self._dtype) for _ in range(self._world_size)
]
dist.all_gather(all_splited_param, splited_param.cuda().to(self._dtype), group=self.dp_pg)
if self.moe_extra_dp_pg is not None and is_moe_tensor(working_param):
all_splited_param = [
torch.zeros(splited_param.shape, device="cuda", dtype=self._dtype)
for _ in range(self.moe_extra_dp_pg_size)
]
dist.all_gather(all_splited_param, splited_param.cuda().to(self._dtype), group=self.moe_extra_dp_pg)
else:
all_splited_param = [
torch.zeros(splited_param.shape, device="cuda", dtype=self._dtype)
for _ in range(self._world_size)
]
dist.all_gather(all_splited_param, splited_param.cuda().to(self._dtype), group=self.dp_pg)
working_param.data.copy_(flatten(all_splited_param)[: working_param.numel()].reshape_as(working_param))
self.optim.param_groups[group_id]["params"] = self._master_param_groups_of_current_rank[group_id]
@@ -488,7 +657,6 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
norm_type = float(norm_type)
if norm_type == inf:
total_norm = max(grad.data.abs().max() for grad in gradients)
total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
dist.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=self.dp_pg)
total_norm = total_norm_cuda.item()
@@ -596,10 +764,16 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
for k, v in state.items():
if isinstance(v, torch.Tensor) and k != "step":
working_param = self._param_store.master_to_working_param[id(param)]
gather_tensor = [
torch.zeros(v.shape, device="cuda", dtype=v.dtype) for _ in range(self._world_size)
]
dist.all_gather(gather_tensor, v.cuda(), group=self.dp_pg)
if self.moe_extra_dp_pg is not None and is_moe_tensor(v):
gather_tensor = [
torch.zeros(v.shape, device="cuda", dtype=v.dtype) for _ in range(self.moe_extra_dp_pg_size)
]
dist.all_gather(gather_tensor, v.cuda(), group=self.moe_extra_dp_pg)
else:
gather_tensor = [
torch.zeros(v.shape, device="cuda", dtype=v.dtype) for _ in range(self._world_size)
]
dist.all_gather(gather_tensor, v.cuda(), group=self.dp_pg)
param_state = (
torch.stack(gather_tensor).view(-1)[: working_param.numel()].reshape_as(working_param).cpu()
)
@@ -624,8 +798,12 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
v = v.flatten()
if padding_size > 0:
v = torch.nn.functional.pad(v, [0, padding_size])
v_list = v.split(v.numel() // self._world_size)
zero_state_dict["state"][param_idx][k] = v_list[self._local_rank].detach().clone()
if self.moe_extra_dp_pg is not None and is_moe_tensor(v):
v_list = v.split(v.numel() // self.moe_extra_dp_pg_size)
zero_state_dict["state"][param_idx][k] = v_list[self.moe_extra_dp_pg_rank].detach().clone()
else:
v_list = v.split(v.numel() // self._world_size)
zero_state_dict["state"][param_idx][k] = v_list[self._local_rank].detach().clone()
self.optim.load_state_dict(zero_state_dict)
@@ -656,8 +834,16 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
for k, v in states.items():
if isinstance(v, torch.Tensor) and k != "step":
state_tensor = [torch.zeros(v.shape, device="cuda", dtype=v.dtype) for _ in range(self._world_size)]
dist.all_gather(state_tensor, v.cuda(), group=self.dp_pg)
if self.moe_extra_dp_pg is not None and is_moe_tensor(v):
state_tensor = [
torch.zeros(v.shape, device="cuda", dtype=v.dtype) for _ in range(self.moe_extra_dp_pg_size)
]
dist.all_gather(state_tensor, v.cuda(), group=self.moe_extra_dp_pg)
else:
state_tensor = [
torch.zeros(v.shape, device="cuda", dtype=v.dtype) for _ in range(self._world_size)
]
dist.all_gather(state_tensor, v.cuda(), group=self.dp_pg)
state_tensor = (
torch.stack(state_tensor).view(-1)[: working_param.numel()].reshape_as(working_param).cpu()
)
@@ -688,7 +874,10 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
working_param = p.data.view(-1)
if padding_size > 0:
working_param = torch.nn.functional.pad(working_param, [0, padding_size])
master_param.copy_(working_param.chunk(self._world_size)[self._local_rank])
if self.moe_extra_dp_pg is not None and is_moe_tensor(p):
master_param.copy_(working_param.chunk(self.extra_dp_pg_size)[self.extra_dp_pg_rank])
else:
master_param.copy_(working_param.chunk(self._world_size)[self._local_rank])
def get_working_to_master_map(self) -> Dict[int, torch.Tensor]:
return self._param_store.working_to_master_param