[moe] support mixtral (#5309)

* [moe] add mixtral block for single expert

* [moe] mixtral block fwd support uneven ep

* [moe] mixtral block bwd support uneven ep

* [moe] add mixtral moe layer

* [moe] simplify replace

* [meo] support save sharded mixtral

* [meo] support load sharded mixtral

* [meo] support save sharded optim

* [meo] integrate moe manager into plug

* [meo] fix optimizer load

* [meo] fix mixtral layer
This commit is contained in:
Hongxin Liu 2024-01-25 15:48:46 +08:00 committed by ver217
parent c904d2ae99
commit da39d21b71
14 changed files with 996 additions and 550 deletions

View File

@ -1,205 +1,617 @@
import copy
import logging import logging
import os import os
from pathlib import Path from pathlib import Path
from shutil import rmtree
from typing import Dict, Iterator, Optional, OrderedDict, Tuple
import torch import torch
import torch.distributed as dist import torch.distributed as dist
import torch.nn as nn import torch.nn as nn
from torch.distributed import ProcessGroup
from colossalai.checkpoint_io import CheckpointIndexFile from colossalai.checkpoint_io import CheckpointIndexFile
from colossalai.checkpoint_io.utils import is_safetensors_available, load_shard_state_dict, load_state_dict_into_model from colossalai.checkpoint_io.hybrid_parallel_checkpoint_io import HybridParallelCheckpointIO
from colossalai.moe import MoECheckpintIO from colossalai.checkpoint_io.index_file import CheckpointIndexFile
from colossalai.tensor.moe_tensor.api import get_dp_rank, get_ep_group, get_ep_rank, get_ep_size, is_moe_tensor from colossalai.checkpoint_io.utils import (
StateDictSharder,
gather_distributed_param,
get_model_base_filenames,
get_optimizer_base_filenames,
load_shard_state_dict,
load_states_into_optimizer,
save_config_file,
save_param_groups,
save_state_dict_shards,
search_tp_partition_dim,
sharded_optimizer_loading_epilogue,
)
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.moe import MOE_MANAGER
from colossalai.tensor.moe_tensor.api import is_moe_tensor
try:
from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX
except ImportError:
_EXTRA_STATE_KEY_SUFFIX = "_extra_state"
class MixtralMoECheckpointIO(MoECheckpintIO): class MixtralMoEHybridParallelCheckpointIO(HybridParallelCheckpointIO):
def __init__(self, *args, **kwargs): def __init__(
super().__init__(*args, **kwargs) self,
dp_group: ProcessGroup,
pp_group: ProcessGroup,
tp_group: ProcessGroup,
zero_stage: int,
verbose: bool = True,
) -> None:
super().__init__(dp_group, pp_group, tp_group, zero_stage, verbose)
moe_info = MOE_MANAGER.parallel_info_dict[MOE_MANAGER.ep_size]
self.ep_group = moe_info.ep_group
self.ep_size = moe_info.ep_size
self.ep_rank = moe_info.ep_rank
self.real_dp_rank = moe_info.dp_rank
@torch.no_grad() @staticmethod
def pre_load_model(self, model: nn.Module, state_dict: dict) -> dict: def _model_sharder(
""" model: nn.Module,
Preprocess state_dict before loading and slice the state_dict of MOE tensors. prefix: str = "",
""" keep_vars: bool = False,
model_param_dict = dict(model.named_parameters()) size_per_shard: int = 1024,
for name, param in list(state_dict.items()): param_name_pattern: Optional[str] = None,
if ".gate.weight" in name: ) -> Iterator[Tuple[OrderedDict, int]]:
new_name = "module." + name.replace(".gate.weight", ".gate_weight") # An internel method that breaks state_dict of model into shards within limited size.
state_dict[new_name] = state_dict.pop(name)
elif ".experts." in name:
# if is moe tensor
# in our moe module, expert is cat as one tensor
# but mixtral's experts is not cat
# we will insert the loaded expert into the position of cat tensor
# get model param state_dict_sharder = StateDictSharder(size_per_shard)
str_idx = name.index(".experts.")
expert_idx = int(name.split(".")[-3]) # Save parameters.
if ".w1." in name: for name, param in model.named_parameters():
model_param_name = name.replace(name[str_idx:], ".experts.wi_gate") if param is None:
elif ".w2." in name:
model_param_name = name.replace(name[str_idx:], ".experts.wo")
elif ".w3." in name:
model_param_name = name.replace(name[str_idx:], ".experts.wi_up")
model_param_name = "module." + model_param_name
# skip for pipeline
if model_param_name not in model_param_dict:
continue continue
model_param = model_param_dict[model_param_name] if param_name_pattern is not None and param_name_pattern not in name:
assert is_moe_tensor(model_param) continue
# get expert range # Gather tensor pieces when using tensor parallel.
ep_rank = get_ep_rank(model_param) param_ = gather_distributed_param(param, keep_vars=False)
ep_size = get_ep_size(model_param) block, block_size = state_dict_sharder.append_param(prefix + name, param_)
expert_num = 8 // ep_size if block is not None:
expert_range = list(range(ep_rank * expert_num, (ep_rank + 1) * expert_num)) yield block, block_size
# insert new param
if expert_idx in expert_range:
new_param = model_param
new_param[expert_idx - ep_rank * expert_num] = param.transpose(0, 1)
state_dict[model_param_name] = new_param
state_dict.pop(name)
else:
new_name = "module." + name
state_dict[new_name] = state_dict.pop(name)
dist.barrier() # Save buffers.
return state_dict for name, buf in model.named_buffers():
if buf is not None and name not in model._non_persistent_buffers_set:
buffer = buf if keep_vars else buf.detach()
block, block_size = state_dict_sharder.append_param(prefix + name, buffer)
if block is not None:
yield block, block_size
def load_sharded_model(self, model: nn.Module, checkpoint_index_file: Path, strict: bool = False): # Save extra states.
extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
if (
getattr(model.__class__, "get_extra_state", torch.nn.Module.get_extra_state)
is not torch.nn.Module.get_extra_state
):
extra_state = model.get_extra_state()
block, block_size = state_dict_sharder.append_param(extra_state_key, extra_state)
if block is not None:
yield block, block_size
# Return the last block in sharder.
yield state_dict_sharder.current_block, state_dict_sharder.current_block_size
def save_sharded_model(
self,
model: ModelWrapper,
checkpoint: str,
gather_dtensor: bool = True,
prefix: Optional[str] = None,
size_per_shard: int = 1024,
use_safetensors: bool = False,
) -> None:
""" """
Load sharded model with the given path to index file of checkpoint folder. Save sharded model checkpoint under the given checkpointing path.
The following files will be created under the path:
- An index file (pytorch_model.bin.index.json) containing a map between model params/buffers and file names.
- Multiple files that store state tensors of models.
If pipeline parallelism is used, the filenames are in the form of "pytorch_model.<prefix>-stage-000XX-shard-000XX.bin".
If pipeline parallelism is not used, "pytorch_model.<prefix>-000XX.bin"
Args: Args:
model (nn.Module): The model to be loaded. model (nn.Module): Model on local device to be saved.
checkpoint_index_file (str): Path to the index file of checkpointing folder. checkpoint (str): Checkpointing path which should be a directory path.
strict (bool, optional): For name matching during loading state_dict. Defaults to False. gather_dtensor (bool, optional): Whether to gather_dtensor, currently not used. Defaults to True.
This argument should be manually set to False since params on same device might be stored in different files. prefix (str, optional): Perfix of file to save. Defaults to None.
size_per_shard (int, optional): Size per shard in MB. Defaults to 1024.
use_safetensors (bool, optional): Whether to use safe tensors. Defaults to False.
""" """
# Check whether the checkpoint uses safetensors. assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
use_safetensors = False model = model.unwrap()
if "safetensors" in checkpoint_index_file.name:
use_safetensors = True
if use_safetensors and not is_safetensors_available(): if os.path.isfile(checkpoint):
raise ImportError("`safe_serialization` requires the `safetensors` library: `pip install safetensors`.") logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
return
Path(checkpoint).mkdir(parents=True, exist_ok=True)
if self.real_dp_rank != 0:
return
# ep_rank 0 saves all the parameters and buffers.
# other ep_ranks save only experts
ep_param_pattern = "experts." if self.ep_rank != 0 else None
# Then collect the sharded parameters & buffers along tp_group.
# Only devices with tp_rank == 0 are responsible for model saving.
state_dict_shard = MixtralMoEHybridParallelCheckpointIO._model_sharder(
model, size_per_shard=size_per_shard, param_name_pattern=ep_param_pattern
)
weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors)
index_file = CheckpointIndexFile(checkpoint)
control_saving = self.tp_rank == 0
if self.pp_size == 1 and self.ep_size == 1:
# When pipeline is not used, save the model shards as in general checkpointIO
total_size = save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint,
index_file=index_file,
base_filename=weights_name,
is_master=control_saving,
use_safetensors=use_safetensors,
)
if control_saving:
index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
save_config_file(model, checkpoint)
if self.verbose and self.coordinator.is_master():
logging.info(
f"The model is split into checkpoint shards. "
f"You can find where each parameters has been saved in the "
f"index located at {save_index_file}."
)
else:
# When pipeline is used, each stage produces its own shard files and index files.
# Index files belonging to each stage are saved under a temporary folder ./tmp_index_files/
# After all the state_dicts have been saved, the master rank integrates all the index files into one final index file and deletes the tmp folder.
final_index_file_path = copy.deepcopy(save_index_file)
tmp_index_file_folder = os.path.join(checkpoint, "tmp_index_files")
Path(tmp_index_file_folder).mkdir(parents=True, exist_ok=True)
# Manage filenames of sharded weights and index file for each pipeline stage.
weights_name = weights_name.replace(".bin", f"-stage-{self.pp_rank+1:05d}-{self.ep_rank+1:05d}-shard.bin")
weights_name = weights_name.replace(
".safetensors", f"-stage-{self.pp_rank+1:05d}-{self.ep_rank+1:05d}-shard.safetensors"
)
save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank+1:05d}-{self.ep_rank+1:05d}.json")
save_index_file = os.path.join("tmp_index_files", save_index_file)
total_size = save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint,
index_file=index_file,
base_filename=weights_name,
is_master=control_saving,
use_safetensors=use_safetensors,
use_pp_format=True,
)
if control_saving:
index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
else:
return
dist.barrier(self.pp_group)
dist.barrier(self.ep_group)
# The global master rank integrates the index files and clean the folder.
if self.coordinator.is_master():
final_index_file = CheckpointIndexFile(checkpoint)
final_index_file.append_meta_data("total_size", 0)
for filename in os.listdir(tmp_index_file_folder):
stage_index_file = CheckpointIndexFile.from_file(os.path.join(tmp_index_file_folder, filename))
final_index_file.metadata["total_size"] += stage_index_file.metadata["total_size"]
for weight, weight_filename in stage_index_file.weight_map.items():
final_index_file.append_weight_map(weight, weight_filename)
final_index_file.write_index_file(final_index_file_path)
save_config_file(model, checkpoint)
rmtree(tmp_index_file_folder)
if self.verbose and self.coordinator.is_master():
logging.info(
f"The model is split into checkpoint shards. "
f"You can find where each parameters has been saved in the "
f"index located at {final_index_file_path}."
)
@staticmethod
def gather_from_sharded_optimizer_state(
state: OrderedDict,
param: torch.Tensor,
original_shape: torch.Size,
dp_group: ProcessGroup,
tp_group: ProcessGroup,
use_zero: bool,
inplace: bool,
is_moe_param: bool,
device: torch.device = torch.device("cpu"),
) -> OrderedDict:
"""
With given parameter and its optimizer states, gather the complete optimizer state for saving.
Args:
state (OrderedDict): Optimizer states of given parameter, might be distributed among tp/dp group if using TP/Zero.
param (torch.Tensor): The given parameter. It should be working_param when using Zero.
original_shape (torch.Size): The size of parameter before sharding.
dp_group (ProcessGroup): The process group of data parallel.
tp_group (ProcessGroup): The process group of tensor parallel.
use_zero (bool): Whether Zero is used.
inplace (bool): If set to True, will update the values of argument 'state' in place. Else will make a copy of state.
device (torch.device): The destination device of loaded optimizer states. Defaults to torch.device('cpu').
Returns:
OrderedDict: The complete optimizer state of given parameter.
"""
dp_size = dist.get_world_size(dp_group)
tp_size = dist.get_world_size(tp_group)
current_shape = param.shape
state_ = state if inplace else copy.deepcopy(state)
for k, v in state_.items():
if isinstance(v, torch.Tensor) and k != "step":
# First gather Zero shards.
if use_zero and not is_moe_param:
v = v.cuda()
gather_tensor = [torch.zeros_like(v) for _ in range(dp_size)]
dist.all_gather(gather_tensor, v, group=dp_group)
v = torch.stack(gather_tensor).view(-1)[: param.numel()].reshape_as(param)
# Then gather TP shards.
partition_dim = search_tp_partition_dim(current_shape, original_shape, tp_size)
if partition_dim is not None:
gather_tensor = [torch.zeros_like(v) for _ in range(tp_size)]
dist.all_gather(gather_tensor, v, group=tp_group)
v = torch.cat(gather_tensor, dim=partition_dim)
state_[k] = v.detach().clone().to(device)
return state_
@staticmethod
def _optimizer_sharder(
optimizer: OptimizerWrapper,
use_zero: bool,
dp_group: ProcessGroup,
tp_group: ProcessGroup,
size_per_shard: int = 1024,
only_moe_param: bool = False,
):
# An internel method that breaks state_dict of optimizer into shards within limited size.
state_dict_sharder = StateDictSharder(size_per_shard)
param_info = optimizer.param_info
master_to_working_map = optimizer.get_master_to_working_map()
for param, state in optimizer.optim.state.items():
if param is None:
continue
if master_to_working_map is not None:
working_param = master_to_working_map[id(param)]
else:
working_param = param
param_id = param_info["param2id"][id(working_param)]
original_shape = param_info["param2shape"][id(working_param)]
state_ = MixtralMoEHybridParallelCheckpointIO.gather_from_sharded_optimizer_state(
state,
working_param,
original_shape=original_shape,
dp_group=dp_group,
tp_group=tp_group,
use_zero=use_zero,
inplace=False,
is_moe_param=is_moe_tensor(working_param),
)
if only_moe_param and not is_moe_tensor(working_param):
continue
block, block_size = state_dict_sharder.append_optim_state(param_id, state_)
if block is not None:
yield block, block_size
# Return the last block in sharder.
yield state_dict_sharder.current_block, state_dict_sharder.current_block_size
def save_sharded_optimizer(
self,
optimizer: OptimizerWrapper,
checkpoint: str,
gather_dtensor: bool = True,
prefix: Optional[str] = None,
size_per_shard: int = 1024,
):
"""
Save sharded optimizer checkpoint under the given checkpointing path.
The following files will be created under the path:
- An index file (pytorch_optim.bin.index.json) containing a map between optimizer states and file names
- A group file (pytorch_optim_group.bin) recording information of param_groups
- Multiple files that store state tensors of optimizers.
If pipeline parallelism is used, the filenames are in the form of "pytorch_optim.<prefix>-stage-000XX-shard-000XX.bin".
If pipeline parallelism is not used, "pytorch_optim.<prefix>-000XX.bin"
Args:
optimizer (OptimizerWrapper): Optimizer to save sharded state_dict
checkpoint (str): Path to save optimizer state_dict
gather_dtensor (bool): Whether to gather_dtensor, not used
prefix (str): Perfix of file to save
size_per_shard (int): Max file size of each file shard that store state tensors
"""
assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before saving!"
if os.path.isfile(checkpoint):
logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
return
Path(checkpoint).mkdir(parents=True, exist_ok=True)
# Devices along the same dp_group share the same copies of states when zero is not used.
# In this case only let the device with dp_rank == 0 save the model.
if not self.use_zero and self.real_dp_rank != 0:
return
# Then collect the sharded states along dp_group(if using zero)/tp_group.
# Only devices with (dp_rank == 0 and tp_rank == 0) are responsible for states saving.
state_dict_shard = MixtralMoEHybridParallelCheckpointIO._optimizer_sharder(
optimizer,
use_zero=self.use_zero,
dp_group=self.dp_group,
tp_group=self.tp_group,
size_per_shard=size_per_shard,
only_moe_param=self.ep_rank != 0,
)
states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix)
index_file = CheckpointIndexFile(checkpoint)
control_saving = self.real_dp_rank == 0 and self.tp_rank == 0
if self.pp_size == 1 and self.ep_size == 1:
# When pipeline is not used, save the optimizer shards as in general checkpointIO
total_size = save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint,
index_file=index_file,
base_filename=states_name,
is_master=control_saving,
)
if control_saving:
# Store param groups.
index_file.append_meta_data("param_groups", param_group_file)
group_file_path = os.path.join(checkpoint, param_group_file)
save_param_groups(optimizer.param_info, group_file_path)
# Store index file.
index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
if self.verbose and self.coordinator.is_master():
logging.info(
f"The optimizer is going to be split to checkpoint shards. "
f"You can find where each parameters has been saved in the "
f"index located at {save_index_file}."
)
else:
# When pipeline is used, each stage produces its own shard files and index files.
# Index files belonging to each stage are saved under a temporary folder ./tmp_index_files/
# After all the state_dicts have been saved, the master rank integrates all the index files into one final index file and deletes the tmp folder.
final_index_file_path = copy.deepcopy(save_index_file)
tmp_index_file_folder = os.path.join(checkpoint, "tmp_index_files")
Path(tmp_index_file_folder).mkdir(parents=True, exist_ok=True)
# Manage filenames of sharded weights and index file for each pipeline stage.
states_name = states_name.replace(".bin", f"-stage-{self.pp_rank+1:05d}-{self.ep_rank+1:05d}-shard.bin")
save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank+1:05d}-{self.ep_rank+1:05d}.json")
save_index_file = os.path.join("tmp_index_files", save_index_file)
total_size = save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint,
index_file=index_file,
base_filename=states_name,
is_master=control_saving,
use_pp_format=True,
)
if control_saving:
index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
else:
return
dist.barrier(self.pp_group)
dist.barrier(self.ep_group)
# The global master rank integrates the index files and clean the folder.
if self.coordinator.is_master():
final_index_file = CheckpointIndexFile(checkpoint)
final_index_file.append_meta_data("total_size", 0)
for filename in os.listdir(tmp_index_file_folder):
stage_index_file = CheckpointIndexFile.from_file(os.path.join(tmp_index_file_folder, filename))
final_index_file.metadata["total_size"] += stage_index_file.metadata["total_size"]
for param_id, state_filename in stage_index_file.weight_map.items():
final_index_file.append_weight_map(param_id, state_filename)
# Store param groups.
final_index_file.append_meta_data("param_groups", param_group_file)
group_file_path = os.path.join(checkpoint, param_group_file)
save_param_groups(optimizer.param_info, group_file_path)
final_index_file.write_index_file(final_index_file_path)
rmtree(tmp_index_file_folder)
if self.verbose and self.coordinator.is_master():
logging.info(
f"The model is split into checkpoint shards. "
f"You can find where each parameters has been saved in the "
f"index located at {final_index_file_path}."
)
def load_sharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint_index_file: str, prefix: str = ""):
"""
Load sharded optimizer with the given path to index file of checkpoint folder.
Args:
optimizer (OptimizerWrapper): The optimizer to be loaded.
checkpoint_index_file (str): Path to the index file of checkpointing folder.
prefix (str): Not used.
"""
assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!"
def _get_param_id_from_optimizer_param(
param: torch.Tensor, master_to_working_map: Optional[Dict[int, torch.Tensor]] = None
):
if master_to_working_map is not None:
working_param = master_to_working_map[id(param)]
else:
working_param = param
return optimizer.param_info["param2id"][id(working_param)]
# id_map is a mapping from param ids kept by current pipeline, to their corresponding parameter objects.
# When Zero is used, the mapped parameter objects should be fp32 master parameters.
# IDs should be obtained through saved param2id mapping earlier saved in optimizer.param_info.
id_map = {}
master_to_working_map = optimizer.get_master_to_working_map()
for pg in optimizer.optim.param_groups:
for param in pg["params"]:
param_id = _get_param_id_from_optimizer_param(param, master_to_working_map)
id_map[param_id] = param
# Read checkpoint index file. # Read checkpoint index file.
ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file) ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file)
ckpt_root_path = ckpt_index_file.root_path ckpt_root_path = ckpt_index_file.root_path
weight_map = ckpt_index_file.weight_map weight_map = ckpt_index_file.weight_map
strict = False weight_map = {int(k): v for k, v in weight_map.items()} # convert saved id from str to int
# Load params & buffers to model. # Load param_groups
param_group_path = ckpt_index_file.get_param_group_filename()
if param_group_path is None:
raise RuntimeError(
f"Invalid index file path {checkpoint_index_file} for an optimizer. \
Lacking param group file under current directory."
)
saved_groups = torch.load(param_group_path)
updated_groups = []
for old_pg, saved_pg in zip(optimizer.optim.param_groups, saved_groups):
# obtain updated param group
new_pg = copy.deepcopy(saved_pg)
new_pg["params"] = old_pg["params"] # The parameters in the same group shouln't change.
updated_groups.append(new_pg)
# ep param groups
if len(optimizer.optim.param_groups) == len(saved_groups) + 1:
new_pg = copy.deepcopy(saved_pg)
new_pg["params"] = optimizer.optim.param_groups[-1]["params"]
updated_groups.append(new_pg)
optimizer.optim.__dict__.update({"param_groups": updated_groups})
# Load saved states to optimizer.
# Keep a record of loaded files so that file will not be repeatedly loaded. # Keep a record of loaded files so that file will not be repeatedly loaded.
loaded_file = set() loaded_file = set()
for pg in optimizer.optim.param_groups:
for param in pg["params"]:
if param is None:
continue
param_id = _get_param_id_from_optimizer_param(param, master_to_working_map)
if param_id not in weight_map:
continue
filename = weight_map[param_id]
def _load(name: str): # If this param's states has been loaded before, directly return.
if name not in weight_map:
raise ValueError(f"{name} is not stored in checkpoint, please check your checkpointing configuration!")
filename = weight_map[name]
# If this param/buffer has been loaded before, directly return.
if filename in loaded_file: if filename in loaded_file:
return continue
file_path = os.path.join(ckpt_root_path, filename) file_path = os.path.join(ckpt_root_path, filename)
state_dict = load_shard_state_dict(Path(file_path), use_safetensors) state_dict = load_shard_state_dict(Path(file_path), use_safetensors=False)
state_dict = self.pre_load_model(model, state_dict) load_states_into_optimizer(optimizer.optim, state_dict, id_map, strict=True)
missing_keys = []
load_state_dict_into_model(
model,
state_dict,
missing_keys=missing_keys,
strict=strict,
load_sub_module=True,
)
loaded_file.add(filename) loaded_file.add(filename)
# Load parameters. # Then shard the loaded optimizer states if using tp/zero.
for name, _ in model.named_parameters(): for param, state in optimizer.optim.state.items():
name = name.replace("module.", "") device = param.device
name = name.replace(".gate_weight", ".gate.weight") if master_to_working_map is not None:
if ".experts.wi_gate" in name: working_param = master_to_working_map[id(param)]
for i in range(8):
new_name = name.replace(".experts.wi_gate", f".experts.{i}.w1.weight")
_load(new_name)
elif ".experts.wi_up" in name:
for i in range(8):
new_name = name.replace(".experts.wi_up", f".experts.{i}.w3.weight")
_load(new_name)
elif ".experts.wo" in name:
for i in range(8):
new_name = name.replace(".experts.wo", f".experts.{i}.w2.weight")
_load(new_name)
else: else:
_load(name) working_param = param
original_shape = optimizer.param_info["param2shape"][id(working_param)]
sharded_state = self.shard_from_complete_optimizer_state(
state,
current_shape=working_param.shape,
original_shape=original_shape,
device=device,
inplace=True,
is_moe_param=is_moe_tensor(working_param),
)
optimizer.optim.state[param] = sharded_state
if self.verbose: sharded_optimizer_loading_epilogue(optimizer.optim)
logging.info(f"The model has been successfully loaded from sharded checkpoint: {ckpt_root_path}.") if self.verbose and self.coordinator.is_master():
logging.info(f"The optimizer has been successfully loaded from sharded checkpoint: {ckpt_root_path}.")
@torch.no_grad() def shard_from_complete_optimizer_state(
def pre_save_model(self, model: nn.Module) -> dict: self,
torch.cuda.empty_cache() state: OrderedDict,
state_dict = model.state_dict() current_shape: torch.Size,
for name, param in list(model.named_parameters()): original_shape: torch.Size,
if ".gate_weight" in name: device: torch.device,
new_name = name.replace(".gate_weight", ".gate.weight") inplace: bool,
state_dict[new_name] = state_dict.pop(name).cpu() is_moe_param: bool,
elif ".experts." in name: ) -> OrderedDict:
ep_group = get_ep_group(param) """
ep_rank = get_ep_rank(param) With complete optimizer states of a specific parameter loaded from checkpoint,
ep_size = get_ep_size(param) slice out the sharded optimizer states kept by current device.
dp_rank = get_dp_rank(param)
if dp_rank == 0: Args:
param = param.data.cuda() state (OrderedDict): Complete optimizer states of a given parameter, loaded from checkpoint.
all_param = [torch.zeros_like(param) for _ in range(ep_size)] current_shape (torch.Size): The size of parameter after sharding.
# gather param from every ep rank original_shape (torch.Size): The size of parameter before sharding.
dist.all_gather(all_param, param, group=ep_group) device (torch.device): The destination device of loaded optimizer states.
if ep_rank == 0: inplace (bool): If set to True, will update the values of argument 'state' in place. Else will make a copy of state.
all_param = torch.cat(all_param, dim=0)
assert all_param.shape[0] == 8
for i in range(8):
if ".wi_gate" in name:
new_name = name.replace(".experts.wi_gate", f".experts.{i}.w1.weight")
elif ".wi_up" in name:
new_name = name.replace(".experts.wi_up", f".experts.{i}.w3.weight")
elif ".wo" in name:
new_name = name.replace(".experts.wo", f".experts.{i}.w2.weight")
new_name = new_name.replace("module.", "")
new_param = all_param[i].transpose(-1, -2)
state_dict[new_name] = new_param.cpu()
state_dict.pop(name)
else:
state_dict[name] = param.cpu()
for name, param in list(state_dict.items()): Returns:
new_name = name.replace("module.", "") OrderedDict: The sharded optimizer state of the given parameter.
state_dict[new_name] = state_dict.pop(name) """
state_ = state if inplace else copy.deepcopy(state)
torch.cuda.empty_cache() for k, v in state_.items():
if self.pp_size > 1: if isinstance(v, torch.Tensor) and k != "step":
if self.dp_rank == 0: # Shard state along tensor parallel group.
# gather state_dict from every pp rank partition_dim = search_tp_partition_dim(current_shape, original_shape, self.tp_size)
# because ckpt is large, we split it into 10 parts if partition_dim is not None:
# and gather them one by one slice_size = current_shape[partition_dim]
new_state_dict = {} v = v.split(slice_size, dim=partition_dim)[self.tp_rank]
state_dict_keys = list(state_dict.keys())
gap_key_num = min(30, len(state_dict_keys)) # Shard state along data parallel group when using Zero.
gap_keys = (len(state_dict_keys) + gap_key_num - 1) // gap_key_num if self.use_zero and not is_moe_param:
for i in range(gap_key_num): padding_size = (self.dp_size - v.numel() % self.dp_size) % self.dp_size
cur_keys = state_dict_keys[i * gap_keys : (i + 1) * gap_keys] with torch.no_grad():
cur_state_dict = {} v = v.flatten()
for k in cur_keys: if padding_size > 0:
cur_state_dict[k] = state_dict[k] v = torch.nn.functional.pad(v, [0, padding_size])
out = [None for _ in range(self.pp_size)] slice_size = v.numel() // self.dp_size
dist.all_gather_object(out, cur_state_dict, group=self.pp_group) v = v.split(slice_size, dim=0)[self.dp_rank]
if self.pp_rank == 0:
for o in out: state_[k] = v.detach().clone().to(device)
for k, v in o.items():
new_state_dict[k] = v.cpu() return state_
state_dict = new_state_dict
dist.barrier() def save_unsharded_model(self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
return state_dict raise NotImplementedError
def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool):
raise NotImplementedError
def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, strict: bool = False):
raise NotImplementedError

View File

@ -1,80 +1,92 @@
import torch import torch
import torch.nn as nn import torch.distributed as dist
from transformers.models.mixtral.modeling_mixtral import MixtralDecoderLayer, MixtralSparseMoeBlock import torch.nn.functional as F
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
from colossalai.lazy import LazyInitContext from colossalai.lazy import LazyInitContext
from colossalai.moe import SparseMLP from colossalai.moe import MOE_MANAGER
from colossalai.moe._operation import MoeInGradScaler, MoeOutGradScaler, all_to_all_uneven
from colossalai.shardformer.shard.utils import set_tensors_to_none
from colossalai.tensor.moe_tensor.api import set_moe_tensor_info
class MixtralSparseMLP: class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
r""" def __init__(self, config):
This is a wrapper around the apex fused layernorm implementation. It is meant to be used only with the from_native_module interface. super().__init__(config)
""" self.setup_ep()
def __init__(self) -> None: def setup_ep(self):
raise NotImplementedError( _, moe_info = MOE_MANAGER.get_info(self.num_experts)
"FusedLayerNorm is not implemented as a physical class. " ep_group = moe_info.ep_group
"It is meant to be used only with the from_native_module interface convert a native pytorch layer norm module to FusedLayerNorm module provided by apex." self.ep_size = dist.get_world_size(ep_group) if ep_group is not None else 1
) self.ep_rank = dist.get_rank(ep_group) if ep_group is not None else 0
assert self.num_experts % self.ep_size == 0
self.ep_group = ep_group
self.num_experts_per_ep = self.num_experts // self.ep_size
self.expert_start_idx = self.ep_rank * self.num_experts_per_ep
held_experts = self.experts[self.expert_start_idx : self.expert_start_idx + self.num_experts_per_ep]
set_tensors_to_none(self.experts, exclude=set(held_experts))
for p in self.experts.parameters():
set_moe_tensor_info(p, moe_info)
@staticmethod @staticmethod
def from_native_module(module: MixtralSparseMoeBlock, enable_kernel: bool) -> nn.Module: def from_native_module(module: MixtralSparseMoeBlock, *args, **kwargs) -> "EPMixtralSparseMoeBlock":
r"""
Convert a native pytorch layer norm module to FusedLayerNorm module provided by apex,
and optionally marking parameters for gradient aggregation.
Args:
module (nn.LayerNorm): The native PyTorch LayerNorm module to be converted.
sp_partial_derived (bool): Whether this module's gradients are partially derived in sequence parallelism.
Returns:
nn.Module: Union[FastLayerNorm, FusedLayerNorm].
Raises:
AssertionError: If the provided module is not an instance of nn.LayerNorm.
"""
with torch.no_grad():
LazyInitContext.materialize(module) LazyInitContext.materialize(module)
module.__class__ = EPMixtralSparseMoeBlock
module.setup_ep()
return module
# get the attributes of the module def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
moe_kwargs = dict( batch_size, sequence_length, hidden_dim = hidden_states.shape
num_experts=8, hidden_states = hidden_states.view(-1, hidden_dim)
hidden_size=module.hidden_dim, # router_logits: (batch * sequence_length, n_experts)
intermediate_size=module.ffn_dim, router_logits = self.gate(hidden_states)
router_top_k=module.top_k,
router_norm=True,
router_loss=False,
# router_capacity_factor_train=
# router_capacity_factor_eval=
mlp_activation="silu",
mlp_gated=True,
# enable_load_balance=
# load_balance_tolerance=
# load_balance_beam_width=
# load_balance_group_swap_factor=
enable_kernel=enable_kernel,
# enable_comm_overlap=
# enable_hierarchical_comm=
return_gate_logits=True,
)
dtype = module.gate.weight.dtype
device = module.gate.weight.device
sparse_mlp = SparseMLP(**moe_kwargs).to(dtype).to(device)
return sparse_mlp routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
# we cast back to the input dtype
routing_weights = routing_weights.to(hidden_states.dtype)
selected_experts = selected_experts.t().reshape(-1)
selected_experts_idx = selected_experts.argsort()
dispatch_states = hidden_states.repeat(self.top_k, 1)[selected_experts_idx]
input_split_sizes = selected_experts.bincount(minlength=self.num_experts)
output_split_sizes = torch.zeros_like(input_split_sizes)
dist.all_to_all_single(output_split_sizes, input_split_sizes, group=self.ep_group)
def replace_moe_layer(model: nn.Module, enable_kernel: bool = False) -> nn.Module: input_split_list = input_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
""" output_split_list = output_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
Reverse the replace layer operation output_states, _ = all_to_all_uneven(dispatch_states, input_split_list, output_split_list, self.ep_group)
# compute expert output
Args: output_states = MoeInGradScaler.apply(output_states, self.ep_size)
module (torch.nn.Module): The object of layer to shard if output_states.size(0) > 0:
""" if self.num_experts_per_ep == 1:
if isinstance(model, MixtralDecoderLayer): # no need to split
model.block_sparse_moe = MixtralSparseMLP.from_native_module( expert = self.experts[self.expert_start_idx]
model.block_sparse_moe, enable_kernel=enable_kernel output_states = expert.act_fn(expert.w1(output_states)) * expert.w3(output_states)
) output_states = expert.w2(output_states)
else: else:
for _, child in model.named_children(): output_states_splits = output_states.split(output_split_sizes.tolist())
replace_moe_layer(child, enable_kernel) output_states_list = []
for i, split_states in enumerate(output_states_splits):
if split_states.size(0) == 0:
continue
expert = self.experts[self.expert_start_idx + i % self.num_experts_per_ep]
split_states = expert.act_fn(expert.w1(split_states)) * expert.w3(split_states)
split_states = expert.w2(split_states)
output_states_list.append(split_states)
output_states = torch.cat(output_states_list)
output_states = MoeOutGradScaler.apply(output_states, self.ep_size)
dispatch_states, _ = all_to_all_uneven(output_states, output_split_list, input_split_list, self.ep_group)
recover_experts_idx = torch.empty_like(selected_experts_idx)
recover_experts_idx[selected_experts_idx] = torch.arange(
selected_experts_idx.size(0), device=selected_experts_idx.device
)
dispatch_states = dispatch_states[recover_experts_idx]
k_hidden_states = dispatch_states.chunk(self.top_k)
output_states = k_hidden_states[0] * routing_weights[:, 0, None]
for i in range(1, self.top_k):
output_states += k_hidden_states[i] * routing_weights[:, i, None]
output_states = output_states.reshape(batch_size, sequence_length, hidden_dim)
return output_states, router_logits

View File

@ -20,6 +20,8 @@ from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
from colossalai.shardformer.shard import ShardConfig from colossalai.shardformer.shard import ShardConfig
from .mixtral_layer import EPMixtralSparseMoeBlock
__all__ = ["MixtralPolicy", "MixtralForCausalLMPolicy"] __all__ = ["MixtralPolicy", "MixtralForCausalLMPolicy"]
@ -51,6 +53,18 @@ class MixtralPolicy(Policy):
if self.shard_config.enable_tensor_parallelism: if self.shard_config.enable_tensor_parallelism:
raise NotImplementedError("Tensor parallelism is not supported for Mixtral model now.") raise NotImplementedError("Tensor parallelism is not supported for Mixtral model now.")
# expert parallel
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(
suffix="block_sparse_moe",
target_module=EPMixtralSparseMoeBlock,
)
],
policy=policy,
target_key=MixtralDecoderLayer,
)
# optimization configuration # optimization configuration
if self.shard_config.enable_fused_normalization: if self.shard_config.enable_fused_normalization:
self.append_or_create_submodule_replacement( self.append_or_create_submodule_replacement(

View File

@ -3,7 +3,6 @@ import os
from typing import Any, Dict, Tuple, Union from typing import Any, Dict, Tuple, Union
import torch import torch
from huggingface_hub import snapshot_download
from torch.optim.lr_scheduler import _LRScheduler from torch.optim.lr_scheduler import _LRScheduler
from torch.optim.optimizer import Optimizer from torch.optim.optimizer import Optimizer
@ -15,23 +14,6 @@ def move_to_cuda(batch, device):
return {k: v.to(device) for k, v in batch.items()} return {k: v.to(device) for k, v in batch.items()}
@torch.no_grad()
def load_model(ckpt_path: str, model, booster: Booster, optimizer=None):
# pytorch ckpt
if os.path.exists(os.path.join(ckpt_path, "model.safetensors.index.json")):
ckpt_path = os.path.join(ckpt_path, "model.safetensors.index.json")
# saved ckpt
elif os.path.exists(os.path.join(ckpt_path, "pytorch_model.bin.index.json")):
ckpt_path = os.path.join(ckpt_path, "pytorch_model.bin.index.json")
# download
else:
ckpt_path = snapshot_download(ckpt_path)
booster.load_model(model, ckpt_path)
if optimizer is not None:
optimizer.sync_moe_master_param()
optimizer.update_master_params(model)
def load_json(file_path: Union[str, os.PathLike]) -> Dict[str, Any]: def load_json(file_path: Union[str, os.PathLike]) -> Dict[str, Any]:
""" """
Load file in JSON format Load file in JSON format
@ -90,7 +72,7 @@ def load_checkpoint(
""" """
# Update booster params states. # Update booster params states.
load_model(os.path.join(load_dir, "modeling"), model, booster, optimizer) booster.load_model(model, os.path.join(load_dir, "modeling"))
booster.load_optimizer(optimizer=optimizer, checkpoint=os.path.join(load_dir, "optimizer")) booster.load_optimizer(optimizer=optimizer, checkpoint=os.path.join(load_dir, "optimizer"))
booster.load_lr_scheduler(lr_scheduler=lr_scheduler, checkpoint=os.path.join(load_dir, "lr_scheduler")) booster.load_lr_scheduler(lr_scheduler=lr_scheduler, checkpoint=os.path.join(load_dir, "lr_scheduler"))

View File

@ -2,10 +2,8 @@ import argparse
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from colossal_moe.models.mixtral_checkpoint import MixtralMoECheckpointIO from colossal_moe.models.mixtral_checkpoint import MixtralMoEHybridParallelCheckpointIO
from colossal_moe.models.mixtral_layer import replace_moe_layer
from colossal_moe.models.mixtral_policy import MixtralForCausalLMPolicy from colossal_moe.models.mixtral_policy import MixtralForCausalLMPolicy
from colossal_moe.utils import load_model
from transformers import AutoTokenizer from transformers import AutoTokenizer
from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM
@ -13,9 +11,6 @@ import colossalai
from colossalai.booster import Booster from colossalai.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.cluster import DistCoordinator from colossalai.cluster import DistCoordinator
from colossalai.moe import MOE_MANAGER
from colossalai.moe.utils import skip_init
from colossalai.utils import get_current_device
def parse_args(): def parse_args():
@ -30,16 +25,10 @@ def parse_args():
parser.add_argument( parser.add_argument(
"--plugin", "--plugin",
type=str, type=str,
default="hybrid", default="ep",
choices=["ep"], choices=["ep"],
help="Parallel methos.", help="Parallel methos.",
) )
parser.add_argument(
"--output_path",
type=str,
default="./outputs",
help="The path of your saved model after finetuning.",
)
parser.add_argument( parser.add_argument(
"--precision", "--precision",
type=str, type=str,
@ -71,60 +60,38 @@ def main():
colossalai.launch_from_torch(config={}, seed=args.seed) colossalai.launch_from_torch(config={}, seed=args.seed)
coordinator = DistCoordinator() coordinator = DistCoordinator()
config = MixtralConfig.from_pretrained(args.model_name)
ep_size = min(dist.get_world_size(), config.num_local_experts)
# Set plugin # Set plugin
booster_kwargs = {}
hybrid_dict = {
"tp_size": 1,
"custom_policy": MixtralForCausalLMPolicy(),
"enable_fused_normalization": args.use_layernorm_kernel,
"enable_jit_fused": args.use_kernel,
"precision": args.precision,
"checkpoint_io": MixtralMoECheckpointIO,
"zero_stage": 1,
}
mgr_dict = {}
if args.plugin == "ep": if args.plugin == "ep":
dp_size = dist.get_world_size()
plugin = MoeHybridParallelPlugin( plugin = MoeHybridParallelPlugin(
tp_size=1,
pp_size=1, pp_size=1,
**hybrid_dict, ep_size=ep_size,
) zero_stage=1,
MOE_MANAGER.setup( precision=args.precision,
parallel="EP", custom_policy=MixtralForCausalLMPolicy(),
max_ep_size=dp_size, checkpoint_io=MixtralMoEHybridParallelCheckpointIO,
**mgr_dict, enable_fused_normalization=args.use_layernorm_kernel,
enable_jit_fused=args.use_kernel,
) )
else: else:
raise ValueError(f"Invalid plugin {args.plugin}") raise ValueError(f"Invalid plugin {args.plugin}")
coordinator.print_on_master(f"Set plugin as {plugin.__class__.__name__}") coordinator.print_on_master(f"Set plugin as {plugin.__class__.__name__}")
# Build mixtral model # Build mixtral model
config = MixtralConfig.from_pretrained(args.model_name) model = MixtralForCausalLM.from_pretrained(args.model_name)
config.num_local_experts = 1 # dont change this. it will not affect model coordinator.print_on_master(f"Finish load model")
with skip_init():
model = MixtralForCausalLM(config)
model.num_experts = 8
model = model.to(torch.bfloat16) if args.precision == "bf16" else model.to(torch.float16)
model = model.to(get_current_device())
coordinator.print_on_master(f"Finish init model with config:\n{config}")
# Replace moe
with skip_init():
replace_moe_layer(model)
model.eval()
coordinator.print_on_master(f"Finish replace moe module")
# Prepare tokenizer and dataloader # Prepare tokenizer and dataloader
tokenizer = AutoTokenizer.from_pretrained(args.model_name) tokenizer = AutoTokenizer.from_pretrained(args.model_name)
# Set booster # Set booster
booster = Booster(plugin=plugin, **booster_kwargs) booster = Booster(plugin=plugin)
model, _, _, _, _ = booster.boost(model=model) model, _, _, _, _ = booster.boost(model=model)
coordinator.print_on_master(f"Finish init booster") coordinator.print_on_master(f"Finish init booster")
# load ckpt model.eval()
load_model(args.model_name, model, booster)
coordinator.print_on_master(f"Finish load ckpt")
if coordinator.rank == 0: if coordinator.rank == 0:
text = ["Hello my name is"] text = ["Hello my name is"]
@ -132,10 +99,13 @@ def main():
text = ["What's the largest country in the world?", "How many people live in China?", "帮我续写这首诗:离离原上草"] text = ["What's the largest country in the world?", "How many people live in China?", "帮我续写这首诗:离离原上草"]
tokenizer.pad_token = tokenizer.unk_token tokenizer.pad_token = tokenizer.unk_token
inputs = tokenizer(text, return_tensors="pt", padding=True).to(torch.cuda.current_device()) inputs = tokenizer(text, return_tensors="pt", padding=True).to(torch.cuda.current_device())
with torch.no_grad():
outputs = model.module.generate(**inputs, max_new_tokens=20) outputs = model.module.generate(**inputs, max_new_tokens=20)
outputs = tokenizer.batch_decode(outputs) outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
print(f"[{coordinator.rank}] {outputs}") print(f"[{coordinator.rank}] {outputs}")
if __name__ == "__main__": if __name__ == "__main__":
main() main()

View File

@ -0,0 +1,63 @@
from copy import deepcopy
import pytest
import torch
import torch.distributed as dist
from colossal_moe.models.mixtral_layer import EPMixtralSparseMoeBlock
from torch.testing import assert_close
from transformers.models.mixtral.configuration_mixtral import MixtralConfig
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
import colossalai
from colossalai.moe import MOE_MANAGER
from colossalai.testing.utils import spawn
tokens, n_experts = 7, 4
hidden_size = 8
top_k = 2
def check_mixtral_moe_layer():
torch.cuda.set_device(dist.get_rank())
MOE_MANAGER.setup(
parallel="EP", mode="fixed", fixed_dp_size=1, fixed_ep_size=dist.get_world_size(), fixed_pp_size=1
)
config = MixtralConfig(
hidden_size=hidden_size,
intermediate_size=hidden_size * 2,
num_local_experts=n_experts,
num_experts_per_tok=top_k,
)
torch.manual_seed(0)
orig_model = MixtralSparseMoeBlock(config).cuda()
x = torch.rand(1, tokens, hidden_size, requires_grad=True).cuda()
orig_output, orig_logits = orig_model(x)
model = deepcopy(orig_model)
model = EPMixtralSparseMoeBlock.from_native_module(model)
ep_output, ep_logits = model(x)
assert_close(orig_logits, ep_logits)
assert_close(orig_output, ep_output)
orig_loss = orig_output.mean()
orig_loss.backward()
ep_loss = ep_output.mean()
ep_loss.backward()
assert_close(orig_loss, ep_loss)
name_to_p = {n: p for n, p in orig_model.named_parameters()}
for n, ep_p in model.named_parameters():
p = name_to_p[n]
if ep_p.grad is not None:
assert_close(p.grad, ep_p.grad)
def run_dist(rank: int, world_size: int, port: int):
colossalai.launch({}, rank, world_size, "localhost", port)
check_mixtral_moe_layer()
@pytest.mark.parametrize("world_size", [2, 4])
def test_mixtral_moe_layer(world_size: int):
spawn(run_dist, world_size)
if __name__ == "__main__":
test_mixtral_moe_layer(2)

View File

@ -1,185 +1,144 @@
import os from copy import deepcopy
import shutil
import pytest import pytest
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from colossal_moe.models.mixtral_checkpoint import MixtralMoECheckpointIO from colossal_moe.models.mixtral_checkpoint import MixtralMoEHybridParallelCheckpointIO
from colossal_moe.models.mixtral_layer import replace_moe_layer
from colossal_moe.models.mixtral_policy import MixtralForCausalLMPolicy from colossal_moe.models.mixtral_policy import MixtralForCausalLMPolicy
from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM from torch.optim import Adam
from transformers.models.mixtral.configuration_mixtral import MixtralConfig
from transformers.models.mixtral.modeling_mixtral import MixtralForCausalLM
import colossalai import colossalai
from colossalai.booster import Booster from colossalai.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.moe.manager import MOE_MANAGER from colossalai.testing.utils import spawn
from colossalai.testing import DummyDataloader, check_state_dict_equal, rerun_if_address_is_in_use, spawn
from colossalai.utils import get_current_device tokens, n_experts = 7, 4
hidden_size = 8
top_k = 2
def data_gen_fn(batch_size: int = 2, max_length: int = 4, vocab_size: int = 20): def check_model_equal(model1, model2):
input_ids = torch.randint(0, vocab_size, (batch_size, max_length), device=get_current_device()) assert set(model1.state_dict().keys()) == set(model2.state_dict().keys())
attention_mask = torch.ones_like(input_ids) for p1, p2 in zip(model1.parameters(), model2.parameters()):
assert torch.equal(p1.half(), p2.half())
def get_optimizer_snapshot(optim):
state = {id(k): deepcopy(v) for k, v in optim.state.items()}
param_groups = []
for group in optim.param_groups:
params = [id(p) for p in group["params"]]
new_group = {"params": params}
for k, v in group.items():
if k != "params":
new_group[k] = v
param_groups.append(new_group)
return { return {
"input_ids": input_ids, "state": state,
"attention_mask": attention_mask, "param_groups": param_groups,
"labels": input_ids,
} }
def run_fwd_bwd( def check_optimizer_snapshot_equal(snapshot1, snapshot2):
model, data, label, criterion, optimizer, enable_autocast=False, pipeline=False, booster=None, plugin=None # check param_groups
): assert len(snapshot1["param_groups"]) == len(snapshot2["param_groups"])
model.train() for group1, group2 in zip(snapshot1["param_groups"], snapshot2["param_groups"]):
if pipeline: assert set(group1.keys()) == set(group2.keys())
train_dataloader_iter = DummyDataloader(data_gen_fn, length=1) for k in group1.keys():
is_pp_last_stage = booster.plugin.stage_manager.is_last_stage() assert group1[k] == group2[k]
y = booster.execute_pipeline( # check state
train_dataloader_iter, assert set(snapshot1["state"].keys()) == set(
model, snapshot2["state"].keys()
lambda x, y: x.loss, ), f"{snapshot1['state'].keys()}, {snapshot2['state'].keys()}"
optimizer, for pid in snapshot1["state"].keys():
return_loss=True, state1, state2 = snapshot1["state"][pid], snapshot2["state"][pid]
return_outputs=True, assert set(state1.keys()) == set(state2.keys())
) for k in state1.keys():
# Backward and optimize if isinstance(state1[k], torch.Tensor):
if is_pp_last_stage: assert torch.equal(state1[k], state2[k]), f"{k}, {state1[k]}, {state2[k]}"
loss = y["loss"]
else: else:
if criterion: assert state1[k] == state2[k]
y = model(data).logits
loss = criterion(y)
else:
loss = model(data, label)
loss = loss.float()
if optimizer is not None:
optimizer.backward(loss)
else:
loss.backward()
return y
def get_config(): def check_mixtral_moe_layer():
torch.cuda.set_device(dist.get_rank())
config = MixtralConfig( config = MixtralConfig(
vocab_size=300, hidden_size=hidden_size,
hidden_size=32, intermediate_size=hidden_size * 2,
intermediate_size=16, num_local_experts=n_experts,
num_hidden_layers=2, num_experts_per_tok=top_k,
dropout_rate=0.0, num_attention_heads=2,
num_key_value_heads=2,
) )
return config torch.manual_seed(0)
input_ids = torch.randint(0, 100, (2, tokens)).cuda()
orig_model = MixtralForCausalLM(config).cuda()
def get_model(parallel): model = deepcopy(orig_model)
config = get_config() optimizer = Adam(model.parameters(), lr=1e-3)
model = MixtralForCausalLM(config).to(torch.bfloat16) plugin = MoeHybridParallelPlugin(
replace_moe_layer(model)
optim = torch.optim.Adam(model.parameters())
args = dict(
precision="bf16",
tp_size=1, tp_size=1,
zero_stage=1,
custom_policy=MixtralForCausalLMPolicy(),
checkpoint_io=MixtralMoECheckpointIO,
)
if parallel == "ep":
plugin = MoeHybridParallelPlugin(
pp_size=1,
**args,
)
elif parallel == "hybrid":
plugin = MoeHybridParallelPlugin(
pp_size=2, pp_size=2,
ep_size=2,
custom_policy=MixtralForCausalLMPolicy(),
checkpoint_io=MixtralMoEHybridParallelCheckpointIO,
microbatch_size=1, microbatch_size=1,
**args, zero_stage=1,
) )
booster = Booster(plugin=plugin) booster = Booster(plugin=plugin)
model, optim, _, _, _ = booster.boost(model=model, optimizer=optim) model, optimizer, *_ = booster.boost(model=model, optimizer=optimizer)
return model, booster, optim # initialize grads
data_iter = iter(
[{"input_ids": input_ids, "attention_mask": torch.ones_like(input_ids), "labels": input_ids.clone()}]
)
booster.execute_pipeline(
data_iter,
model,
lambda outputs, inputs: outputs.loss,
optimizer,
)
# check save model
def _test_moe_checkpoint(parallel): booster.save_model(model, "mixtral_model", shard=True)
dist.barrier()
if dist.get_rank() == 0: if dist.get_rank() == 0:
if os.path.exists("./tmp_ckpt1"): saved_model = MixtralForCausalLM.from_pretrained("mixtral_model").cuda()
shutil.rmtree("./tmp_ckpt1") check_model_equal(orig_model, saved_model)
if os.path.exists("./tmp_ckpt2"): saved_model.save_pretrained("mixtral_hf_model")
shutil.rmtree("./tmp_ckpt2")
dist.barrier() dist.barrier()
if parallel == None: # check load model
MOE_MANAGER.setup( new_model = MixtralForCausalLM(config).cuda()
parallel=None, new_optimizer = Adam(new_model.parameters(), lr=1e-3)
) new_model, new_optimizer, *_ = booster.boost(model=new_model, optimizer=new_optimizer)
elif parallel == "ep": booster.load_model(new_model, "mixtral_hf_model")
MOE_MANAGER.setup( check_model_equal(model, new_model)
parallel="EP",
)
elif parallel == "hybrid":
MOE_MANAGER.setup(
parallel="EP",
mode="fixed",
fixed_dp_size=1,
fixed_ep_size=2,
fixed_pp_size=2,
)
model1, booster1, optim1 = get_model(parallel)
model2, booster2, optim2 = get_model(parallel)
# param ckpt
# check not equal
try:
check_state_dict_equal(model1.state_dict(), model2.state_dict(), False)
raise AssertionError("state_dict should not be equal")
except:
pass
# shard
booster1.save_model(model1, "./tmp_ckpt1", shard=True, size_per_shard=1)
booster2.load_model(model2, "./tmp_ckpt1")
# check
check_state_dict_equal(model1.state_dict(), model2.state_dict(), False)
# optim ckpt # check save optimizer
criterion = lambda x: x.mean() optimizer.step()
data = torch.randint(0, 4, (2, 4)).cuda() snapshot = get_optimizer_snapshot(optimizer.unwrap())
label = torch.randint(0, 4, (2,)).cuda() booster.save_optimizer(optimizer, "mixtral_optim", shard=True)
if parallel == "hybrid":
kwargs = {"pipeline": True, "booster": booster1, "plugin": booster1.plugin}
else:
kwargs = {}
run_fwd_bwd(model1, data, label, criterion, optim1, **kwargs)
optim1.step()
optim1.zero_grad()
# shard
booster1.save_optimizer(optim1, "./tmp_ckpt2", shard=True, size_per_shard=1)
dist.barrier() dist.barrier()
booster2.load_optimizer(optim2, "./tmp_ckpt2") # reset optimizer state
# check for state in optimizer.unwrap().state.values():
check_state_dict_equal(optim1.optim.state_dict(), optim2.optim.state_dict(), False) for v in state.values():
if isinstance(v, torch.Tensor):
if dist.get_rank() == 0: v.zero_()
shutil.rmtree("./tmp_ckpt1") booster.load_optimizer(optimizer, "mixtral_optim")
shutil.rmtree("./tmp_ckpt2") loaded_snapshot = get_optimizer_snapshot(optimizer.unwrap())
check_optimizer_snapshot_equal(snapshot, loaded_snapshot)
def _run_dist(rank, world_size, port, parallel): def run_dist(rank: int, world_size: int, port: int):
colossalai.launch( colossalai.launch({}, rank, world_size, "localhost", port)
config=dict(), check_mixtral_moe_layer()
rank=rank,
world_size=world_size,
host="localhost",
port=port,
backend="nccl",
)
_test_moe_checkpoint(parallel)
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [4]) @pytest.mark.parametrize("world_size", [4])
@pytest.mark.parametrize("parallel", ["ep", "hybrid"]) def test_mixtral_moe_layer(world_size: int):
@rerun_if_address_is_in_use() spawn(run_dist, world_size)
def test_moe_checkpoint(world_size, parallel):
spawn(_run_dist, world_size, parallel=parallel)
if __name__ == "__main__": if __name__ == "__main__":
test_moe_checkpoint(world_size=4, parallel="hybrid") test_mixtral_moe_layer(4)

View File

@ -1,31 +0,0 @@
import copy
import torch
from colossal_moe.models.mixtral_layer import MixtralSparseMLP
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
class Config:
def __init__(self, hidden_size, intermediate_size, num_local_experts, num_experts_per_tok, hidden_act):
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_local_experts = num_local_experts
self.num_experts_per_tok = num_experts_per_tok
self.hidden_act = hidden_act
def test_moe_layer():
config = Config(hidden_size=4, intermediate_size=8, num_local_experts=32, num_experts_per_tok=2, hidden_act="silu")
mistral_moe = MixtralSparseMoeBlock(config).cuda()
colossal_moe = MixtralSparseMLP.from_native_module(copy.deepcopy(mistral_moe)).cuda()
data = torch.randn(2, 8, 4).cuda()
mistral_output = mistral_moe(data)[0]
colossal_output = colossal_moe(data)[0]
assert torch.allclose(
mistral_output, colossal_output
), f"mistral_output: {mistral_output}\ncolossal_output: {colossal_output}"
if __name__ == "__main__":
test_moe_layer()

View File

@ -2,22 +2,18 @@ import argparse
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from colossal_moe.models.mixtral_checkpoint import MixtralMoECheckpointIO from colossal_moe.models.mixtral_checkpoint import MixtralMoEHybridParallelCheckpointIO
from colossal_moe.models.mixtral_layer import replace_moe_layer
from colossal_moe.models.mixtral_policy import MixtralForCausalLMPolicy from colossal_moe.models.mixtral_policy import MixtralForCausalLMPolicy
from colossal_moe.utils import load_checkpoint, load_model, move_to_cuda, save_checkpoint from colossal_moe.utils import load_checkpoint, move_to_cuda, save_checkpoint
from torch.utils.data import Dataset from torch.utils.data import Dataset
from tqdm import tqdm from tqdm import tqdm
from transformers import AutoTokenizer from transformers import AutoTokenizer
from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM from transformers.models.mixtral import MixtralForCausalLM
import colossalai import colossalai
from colossalai.booster import Booster from colossalai.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.cluster import DistCoordinator from colossalai.cluster import DistCoordinator
from colossalai.moe import MOE_MANAGER, apply_load_balance
from colossalai.moe.layers import apply_load_balance
from colossalai.moe.manager import MOE_MANAGER
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.nn.optimizer import HybridAdam from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device from colossalai.utils import get_current_device
@ -153,45 +149,27 @@ def main():
coordinator = DistCoordinator() coordinator = DistCoordinator()
# Set plugin # Set plugin
booster_kwargs = {}
hybrid_dict = {
"tp_size": 1,
"custom_policy": MixtralForCausalLMPolicy(),
"enable_fused_normalization": args.use_layernorm_kernel,
"enable_jit_fused": args.use_kernel,
"precision": args.precision,
"zero_stage": args.zero_stage,
"checkpoint_io": MixtralMoECheckpointIO,
}
mgr_dict = {}
if args.plugin == "hybrid": if args.plugin == "hybrid":
plugin = MoeHybridParallelPlugin( plugin = MoeHybridParallelPlugin(
tp_size=1,
pp_size=args.pp_size, pp_size=args.pp_size,
ep_size=args.ep_size,
microbatch_size=args.microbatch_size, microbatch_size=args.microbatch_size,
**hybrid_dict, custom_policy=MixtralForCausalLMPolicy(),
) enable_fused_normalization=args.use_layernorm_kernel,
MOE_MANAGER.setup( enable_jit_fused=args.use_kernel,
parallel="EP", precision=args.precision,
mode="fixed", zero_stage=args.zero_stage,
fixed_dp_size=args.dp_size, checkpoint_io=MixtralMoEHybridParallelCheckpointIO,
fixed_ep_size=args.ep_size,
fixed_pp_size=args.pp_size,
**mgr_dict,
) )
else: else:
raise ValueError(f"Invalid plugin {args.plugin}") raise ValueError(f"Invalid plugin {args.plugin}")
coordinator.print_on_master(f"Set plugin as {plugin.__class__.__name__}") coordinator.print_on_master(f"Set plugin as {plugin.__class__.__name__}")
# Build Mixtral model # Build Mixtral model
config = MixtralConfig.from_pretrained(args.model_name) model = MixtralForCausalLM.from_pretrained(args.model_name)
config.use_cache = False coordinator.print_on_master(f"Finish init model")
config.num_local_experts = 1
model = MixtralForCausalLM(config)
model.num_experts = 8
model = model.to(torch.bfloat16) if args.precision == "bf16" else model.to(torch.float16)
model = model.to(get_current_device())
replace_moe_layer(model, enable_kernel=args.use_kernel)
coordinator.print_on_master(f"Finish init model with config:\n{config}")
# Enable gradient checkpointing # Enable gradient checkpointing
model.gradient_checkpointing_enable() model.gradient_checkpointing_enable()
@ -224,7 +202,7 @@ def main():
) )
# Set booster # Set booster
booster = Booster(plugin=plugin, **booster_kwargs) booster = Booster(plugin=plugin)
model, optimizer, _, dataloader, lr_scheduler = booster.boost( model, optimizer, _, dataloader, lr_scheduler = booster.boost(
model=model, model=model,
optimizer=optimizer, optimizer=optimizer,
@ -236,10 +214,7 @@ def main():
coordinator.print_on_master(f"Finish init booster") coordinator.print_on_master(f"Finish init booster")
# Load ckpt # Load ckpt
if args.load_checkpoint is None: if args.load_checkpoint is not None:
load_model(args.model_name, model, booster, optimizer)
coordinator.print_on_master(f"Finish load checkpoint")
else:
load_checkpoint(args.load_checkpoint, booster, model, optimizer, lr_scheduler) load_checkpoint(args.load_checkpoint, booster, model, optimizer, lr_scheduler)
coordinator.print_on_master(f"Finish load optimizer") coordinator.print_on_master(f"Finish load optimizer")
@ -286,13 +261,13 @@ def main():
optimizer.zero_grad() optimizer.zero_grad()
# Apply load balance # Apply load balance
if ( # if (
args.load_balance # args.load_balance
and args.load_balance_interval > 0 # and args.load_balance_interval > 0
and (step + 1) % args.load_balance_interval == 0 # and (step + 1) % args.load_balance_interval == 0
): # ):
coordinator.print_on_master(f"Apply load balance") # coordinator.print_on_master(f"Apply load balance")
apply_load_balance(model, optimizer) # apply_load_balance(model, optimizer)
# save ckeckpoint # save ckeckpoint
if (step + 1) % args.save_interval == 0: if (step + 1) % args.save_interval == 0:
coordinator.print_on_master(f"Saving model checkpoint to {args.output_path}") coordinator.print_on_master(f"Saving model checkpoint to {args.output_path}")

View File

@ -22,7 +22,7 @@ from colossalai.booster.plugin.hybrid_parallel_plugin import (
) )
from colossalai.cluster import ProcessGroupMesh from colossalai.cluster import ProcessGroupMesh
from colossalai.interface import ModelWrapper, OptimizerWrapper from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.moe import MoECheckpintIO from colossalai.moe import MOE_MANAGER, MoECheckpintIO
from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule
from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig from colossalai.shardformer import ShardConfig
@ -150,6 +150,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
self, self,
tp_size: int, tp_size: int,
pp_size: int, pp_size: int,
ep_size: int,
extra_dp_size: int = 1, extra_dp_size: int = 1,
precision: str = "fp16", precision: str = "fp16",
zero_stage: int = 0, zero_stage: int = 0,
@ -189,10 +190,26 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
if enable_sequence_parallelism: if enable_sequence_parallelism:
assert tp_size > 1, "Sequence parallelism must be enabled when using tensor parallelism" assert tp_size > 1, "Sequence parallelism must be enabled when using tensor parallelism"
assert (
dist.get_world_size() % (tp_size * pp_size) == 0
), f"world size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}"
assert (
dist.get_world_size() % (tp_size * pp_size * ep_size) == 0
), f"world size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size} * ep_size {ep_size}"
self.real_dp_size = dist.get_world_size() // (tp_size * pp_size * ep_size)
MOE_MANAGER.setup(
parallel="EP",
mode="fixed",
fixed_dp_size=self.real_dp_size,
fixed_ep_size=ep_size,
fixed_pp_size=pp_size,
use_ep_inside=use_ep_inside,
)
self.tp_size = tp_size self.tp_size = tp_size
self.pp_size = pp_size self.pp_size = pp_size
self.dp_size = dist.get_world_size() // (tp_size * pp_size) self.dp_size = dist.get_world_size() // (tp_size * pp_size)
self.ep_size = ep_size
self.moe_info = MOE_MANAGER.get_info(0)[1]
self.precision = precision self.precision = precision
self.zero_stage = zero_stage self.zero_stage = zero_stage
self.cpu_offload = cpu_offload self.cpu_offload = cpu_offload

View File

@ -9,7 +9,7 @@ from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from colossalai.interface import ModelWrapper from colossalai.interface import ModelWrapper
from .utils import has_index_file from .utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME, has_index_file
__all__ = ["CheckpointIO"] __all__ = ["CheckpointIO"]
@ -89,6 +89,14 @@ class CheckpointIO(ABC):
if index_file_exists: if index_file_exists:
self.load_sharded_model(model, index_file_path, strict) self.load_sharded_model(model, index_file_path, strict)
else:
path = Path(checkpoint, SAFE_WEIGHTS_NAME)
if path.is_file():
self.load_unsharded_model(model, str(path), strict)
else:
path = Path(checkpoint, WEIGHTS_NAME)
if path.is_file():
self.load_unsharded_model(model, str(path), strict)
else: else:
self.load_unsharded_model(model, checkpoint, strict) self.load_unsharded_model(model, checkpoint, strict)

View File

@ -1,4 +1,4 @@
from typing import Any, Optional, Tuple from typing import Any, List, Optional, Tuple
import torch import torch
import torch.distributed as dist import torch.distributed as dist
@ -329,3 +329,68 @@ class MoeOutGradScaler(torch.autograd.Function):
if ctx.ep_size != 1: if ctx.ep_size != 1:
grad = grad / ctx.ep_size grad = grad / ctx.ep_size
return grad, None return grad, None
def _all_to_all(
inputs: torch.Tensor,
input_split_sizes: Optional[List[int]] = None,
output_split_sizes: Optional[List[int]] = None,
group=None,
async_op: bool = False,
):
"""
Returns:
outputs: Tensor
handle: Optional[Work], if overlap is True
"""
outputs_shape = list(inputs.shape)
if output_split_sizes is not None:
outputs_shape[0] = sum(output_split_sizes)
outputs = torch.empty(outputs_shape, dtype=inputs.dtype, device=inputs.device)
inputs = inputs.contiguous()
outputs = outputs.contiguous()
handle = dist.all_to_all_single(
outputs, inputs, output_split_sizes, input_split_sizes, group=group, async_op=async_op
)
return outputs, handle
class AllToAllUneven(torch.autograd.Function):
@staticmethod
def forward(
ctx,
inputs,
input_split_sizes=None,
output_split_sizes=None,
group=None,
overlap: bool = False,
):
"""
Returns:
outputs: Tensor
handle: Optional[Work], if overlap is True
"""
ctx.input_split_sizes = input_split_sizes
ctx.output_split_sizes = output_split_sizes
ctx.group = group
return _all_to_all(inputs, input_split_sizes, output_split_sizes, group, overlap)
@staticmethod
def backward(ctx: Any, *grad_outputs):
return (
_all_to_all(grad_outputs[0], ctx.output_split_sizes, ctx.input_split_sizes, ctx.group, False)[0],
None,
None,
None,
None,
)
def all_to_all_uneven(
inputs: torch.Tensor,
input_split_sizes: Optional[List[int]] = None,
output_split_sizes: Optional[List[int]] = None,
group=None,
overlap: bool = False,
):
return AllToAllUneven.apply(inputs, input_split_sizes, output_split_sizes, group, overlap)

View File

@ -26,3 +26,5 @@ class MoeParallelInfo:
self.ep_group_ranks = self.pg.get_ranks_in_group(self.ep_group) self.ep_group_ranks = self.pg.get_ranks_in_group(self.ep_group)
self.dp_group = self.pg.get_group_along_axis(self.dp_axis) self.dp_group = self.pg.get_group_along_axis(self.dp_axis)
self.dp_group_ranks = self.pg.get_ranks_in_group(self.dp_group) self.dp_group_ranks = self.pg.get_ranks_in_group(self.dp_group)
self.ep_rank = self.pg.coordinate(self.ep_axis)
self.dp_rank = self.pg.coordinate(self.dp_axis)

View File

@ -666,10 +666,6 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
working_param.data.copy_(flatten(all_splited_param)[: working_param.numel()].reshape_as(working_param)) working_param.data.copy_(flatten(all_splited_param)[: working_param.numel()].reshape_as(working_param))
self.optim.param_groups[group_id]["params"] = self._master_param_groups_of_current_rank[group_id] self.optim.param_groups[group_id]["params"] = self._master_param_groups_of_current_rank[group_id]
def sync_moe_master_param(self):
for master_moe_param, working_moe_param in zip(self.master_moe_params, self.working_moe_params):
master_moe_param.data = working_moe_param.data.clone().to(torch.float32).detach()
def _compute_grad_norm(self, gradients: List[Tensor], norm_type: int = 2) -> float: def _compute_grad_norm(self, gradients: List[Tensor], norm_type: int = 2) -> float:
r""" r"""
Compute and return the gradient norm for gradient clipping. Compute and return the gradient norm for gradient clipping.
@ -915,9 +911,11 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
master_param.copy_(working_param.chunk(self.extra_dp_pg_size)[self.extra_dp_pg_rank]) master_param.copy_(working_param.chunk(self.extra_dp_pg_size)[self.extra_dp_pg_rank])
else: else:
master_param.copy_(working_param.chunk(self._world_size)[self._local_rank]) master_param.copy_(working_param.chunk(self._world_size)[self._local_rank])
for master_moe_param, working_moe_param in zip(self.master_moe_params, self.working_moe_params):
master_moe_param.copy_(working_moe_param)
def get_working_to_master_map(self) -> Dict[int, torch.Tensor]: def get_working_to_master_map(self) -> Dict[int, torch.Tensor]:
return self._param_store.working_to_master_param return self._param_store.working_to_master_param
def get_master_to_working_map(self) -> Dict[int, torch.Tensor]: def get_master_to_working_map(self) -> Dict[int, torch.Tensor]:
return self._param_store.master_to_working_param return {**self._param_store.master_to_working_param, **self.moe_master_to_working_map}