mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-01 09:07:51 +00:00
Merge pull request #5372 from hpcaitech/exp/mixtral
This commit is contained in:
BIN
applications/ColossalMoE/README.md
Normal file
BIN
applications/ColossalMoE/README.md
Normal file
Binary file not shown.
0
applications/ColossalMoE/colossal_moe/__init__.py
Normal file
0
applications/ColossalMoE/colossal_moe/__init__.py
Normal file
@@ -0,0 +1,629 @@
|
||||
import copy
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from shutil import rmtree
|
||||
from typing import Dict, Iterator, Optional, OrderedDict, Tuple
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
from colossalai.checkpoint_io import CheckpointIndexFile
|
||||
from colossalai.checkpoint_io.hybrid_parallel_checkpoint_io import HybridParallelCheckpointIO
|
||||
from colossalai.checkpoint_io.index_file import CheckpointIndexFile
|
||||
from colossalai.checkpoint_io.utils import (
|
||||
StateDictSharder,
|
||||
gather_distributed_param,
|
||||
get_model_base_filenames,
|
||||
get_optimizer_base_filenames,
|
||||
load_shard_state_dict,
|
||||
load_states_into_optimizer,
|
||||
save_config_file,
|
||||
save_param_groups,
|
||||
save_state_dict_shards,
|
||||
search_tp_partition_dim,
|
||||
sharded_optimizer_loading_epilogue,
|
||||
)
|
||||
from colossalai.interface import ModelWrapper, OptimizerWrapper
|
||||
from colossalai.moe import MOE_MANAGER
|
||||
from colossalai.tensor.moe_tensor.api import is_moe_tensor
|
||||
|
||||
try:
|
||||
from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX
|
||||
except ImportError:
|
||||
_EXTRA_STATE_KEY_SUFFIX = "_extra_state"
|
||||
|
||||
|
||||
class MixtralMoEHybridParallelCheckpointIO(HybridParallelCheckpointIO):
|
||||
def __init__(
|
||||
self,
|
||||
dp_group: ProcessGroup,
|
||||
pp_group: ProcessGroup,
|
||||
tp_group: ProcessGroup,
|
||||
zero_stage: int,
|
||||
verbose: bool = True,
|
||||
) -> None:
|
||||
super().__init__(dp_group, pp_group, tp_group, zero_stage, verbose)
|
||||
moe_info = MOE_MANAGER.parallel_info_dict[MOE_MANAGER.ep_size]
|
||||
self.ep_group = moe_info.ep_group
|
||||
self.ep_size = moe_info.ep_size
|
||||
self.ep_rank = moe_info.ep_rank
|
||||
self.real_dp_rank = moe_info.dp_rank
|
||||
|
||||
@staticmethod
|
||||
def _model_sharder(
|
||||
model: nn.Module,
|
||||
prefix: str = "",
|
||||
keep_vars: bool = False,
|
||||
size_per_shard: int = 1024,
|
||||
param_name_pattern: Optional[str] = None,
|
||||
) -> Iterator[Tuple[OrderedDict, int]]:
|
||||
# An internel method that breaks state_dict of model into shards within limited size.
|
||||
|
||||
state_dict_sharder = StateDictSharder(size_per_shard)
|
||||
|
||||
# Save parameters.
|
||||
for name, param in model.named_parameters():
|
||||
if param is None:
|
||||
continue
|
||||
if param_name_pattern is not None and param_name_pattern not in name:
|
||||
continue
|
||||
# Gather tensor pieces when using tensor parallel.
|
||||
param_ = gather_distributed_param(param, keep_vars=False)
|
||||
block, block_size = state_dict_sharder.append_param(prefix + name, param_)
|
||||
if block is not None:
|
||||
yield block, block_size
|
||||
|
||||
# Save buffers.
|
||||
for name, buf in model.named_buffers():
|
||||
if buf is not None and name not in model._non_persistent_buffers_set:
|
||||
buffer = buf if keep_vars else buf.detach()
|
||||
block, block_size = state_dict_sharder.append_param(prefix + name, buffer)
|
||||
if block is not None:
|
||||
yield block, block_size
|
||||
|
||||
# Save extra states.
|
||||
extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
|
||||
if (
|
||||
getattr(model.__class__, "get_extra_state", torch.nn.Module.get_extra_state)
|
||||
is not torch.nn.Module.get_extra_state
|
||||
):
|
||||
extra_state = model.get_extra_state()
|
||||
block, block_size = state_dict_sharder.append_param(extra_state_key, extra_state)
|
||||
if block is not None:
|
||||
yield block, block_size
|
||||
|
||||
# Return the last block in sharder.
|
||||
yield state_dict_sharder.current_block, state_dict_sharder.current_block_size
|
||||
|
||||
def save_sharded_model(
|
||||
self,
|
||||
model: ModelWrapper,
|
||||
checkpoint: str,
|
||||
gather_dtensor: bool = True,
|
||||
prefix: Optional[str] = None,
|
||||
size_per_shard: int = 1024,
|
||||
use_safetensors: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Save sharded model checkpoint under the given checkpointing path.
|
||||
The following files will be created under the path:
|
||||
- An index file (pytorch_model.bin.index.json) containing a map between model params/buffers and file names.
|
||||
- Multiple files that store state tensors of models.
|
||||
If pipeline parallelism is used, the filenames are in the form of "pytorch_model.<prefix>-stage-000XX-shard-000XX.bin".
|
||||
If pipeline parallelism is not used, "pytorch_model.<prefix>-000XX.bin"
|
||||
|
||||
|
||||
Args:
|
||||
model (nn.Module): Model on local device to be saved.
|
||||
checkpoint (str): Checkpointing path which should be a directory path.
|
||||
gather_dtensor (bool, optional): Whether to gather_dtensor, currently not used. Defaults to True.
|
||||
prefix (str, optional): Perfix of file to save. Defaults to None.
|
||||
size_per_shard (int, optional): Size per shard in MB. Defaults to 1024.
|
||||
use_safetensors (bool, optional): Whether to use safe tensors. Defaults to False.
|
||||
"""
|
||||
|
||||
assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
|
||||
model = model.unwrap()
|
||||
|
||||
if os.path.isfile(checkpoint):
|
||||
logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
|
||||
return
|
||||
|
||||
Path(checkpoint).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if self.real_dp_rank != 0:
|
||||
dist.barrier()
|
||||
return
|
||||
|
||||
# ep_rank 0 saves all the parameters and buffers.
|
||||
# other ep_ranks save only experts
|
||||
ep_param_pattern = "experts." if self.ep_rank != 0 else None
|
||||
|
||||
# Then collect the sharded parameters & buffers along tp_group.
|
||||
# Only devices with tp_rank == 0 are responsible for model saving.
|
||||
state_dict_shard = MixtralMoEHybridParallelCheckpointIO._model_sharder(
|
||||
model, size_per_shard=size_per_shard, param_name_pattern=ep_param_pattern
|
||||
)
|
||||
weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors)
|
||||
index_file = CheckpointIndexFile(checkpoint)
|
||||
control_saving = self.tp_rank == 0
|
||||
|
||||
if self.pp_size == 1 and self.ep_size == 1:
|
||||
# When pipeline is not used, save the model shards as in general checkpointIO
|
||||
total_size = save_state_dict_shards(
|
||||
sharded_state_dict=state_dict_shard,
|
||||
checkpoint=checkpoint,
|
||||
index_file=index_file,
|
||||
base_filename=weights_name,
|
||||
is_master=control_saving,
|
||||
use_safetensors=use_safetensors,
|
||||
)
|
||||
if control_saving:
|
||||
index_file.append_meta_data("total_size", total_size)
|
||||
index_file.write_index_file(save_index_file)
|
||||
save_config_file(model, checkpoint)
|
||||
if self.verbose and self.coordinator.is_master():
|
||||
logging.info(
|
||||
f"The model is split into checkpoint shards. "
|
||||
f"You can find where each parameters has been saved in the "
|
||||
f"index located at {save_index_file}."
|
||||
)
|
||||
|
||||
dist.barrier()
|
||||
else:
|
||||
# When pipeline is used, each stage produces its own shard files and index files.
|
||||
# Index files belonging to each stage are saved under a temporary folder ./tmp_index_files/
|
||||
# After all the state_dicts have been saved, the master rank integrates all the index files into one final index file and deletes the tmp folder.
|
||||
|
||||
final_index_file_path = copy.deepcopy(save_index_file)
|
||||
tmp_index_file_folder = os.path.join(checkpoint, "tmp_index_files")
|
||||
Path(tmp_index_file_folder).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Manage filenames of sharded weights and index file for each pipeline stage.
|
||||
weights_name = weights_name.replace(".bin", f"-stage-{self.pp_rank+1:05d}-{self.ep_rank+1:05d}-shard.bin")
|
||||
weights_name = weights_name.replace(
|
||||
".safetensors", f"-stage-{self.pp_rank+1:05d}-{self.ep_rank+1:05d}-shard.safetensors"
|
||||
)
|
||||
save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank+1:05d}-{self.ep_rank+1:05d}.json")
|
||||
save_index_file = os.path.join("tmp_index_files", save_index_file)
|
||||
|
||||
total_size = save_state_dict_shards(
|
||||
sharded_state_dict=state_dict_shard,
|
||||
checkpoint=checkpoint,
|
||||
index_file=index_file,
|
||||
base_filename=weights_name,
|
||||
is_master=control_saving,
|
||||
use_safetensors=use_safetensors,
|
||||
use_pp_format=True,
|
||||
)
|
||||
if control_saving:
|
||||
index_file.append_meta_data("total_size", total_size)
|
||||
index_file.write_index_file(save_index_file)
|
||||
else:
|
||||
dist.barrier()
|
||||
return
|
||||
|
||||
dist.barrier()
|
||||
|
||||
# The global master rank integrates the index files and clean the folder.
|
||||
if self.coordinator.is_master():
|
||||
final_index_file = CheckpointIndexFile(checkpoint)
|
||||
final_index_file.append_meta_data("total_size", 0)
|
||||
|
||||
for filename in os.listdir(tmp_index_file_folder):
|
||||
stage_index_file = CheckpointIndexFile.from_file(os.path.join(tmp_index_file_folder, filename))
|
||||
final_index_file.metadata["total_size"] += stage_index_file.metadata["total_size"]
|
||||
for weight, weight_filename in stage_index_file.weight_map.items():
|
||||
final_index_file.append_weight_map(weight, weight_filename)
|
||||
|
||||
final_index_file.write_index_file(final_index_file_path)
|
||||
save_config_file(model, checkpoint)
|
||||
rmtree(tmp_index_file_folder)
|
||||
if self.verbose and self.coordinator.is_master():
|
||||
logging.info(
|
||||
f"The model is split into checkpoint shards. "
|
||||
f"You can find where each parameters has been saved in the "
|
||||
f"index located at {final_index_file_path}."
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def gather_from_sharded_optimizer_state(
|
||||
state: OrderedDict,
|
||||
param: torch.Tensor,
|
||||
original_shape: torch.Size,
|
||||
dp_group: ProcessGroup,
|
||||
tp_group: ProcessGroup,
|
||||
use_zero: bool,
|
||||
inplace: bool,
|
||||
is_moe_param: bool,
|
||||
device: torch.device = torch.device("cpu"),
|
||||
) -> OrderedDict:
|
||||
"""
|
||||
With given parameter and its optimizer states, gather the complete optimizer state for saving.
|
||||
|
||||
Args:
|
||||
state (OrderedDict): Optimizer states of given parameter, might be distributed among tp/dp group if using TP/Zero.
|
||||
param (torch.Tensor): The given parameter. It should be working_param when using Zero.
|
||||
original_shape (torch.Size): The size of parameter before sharding.
|
||||
dp_group (ProcessGroup): The process group of data parallel.
|
||||
tp_group (ProcessGroup): The process group of tensor parallel.
|
||||
use_zero (bool): Whether Zero is used.
|
||||
inplace (bool): If set to True, will update the values of argument 'state' in place. Else will make a copy of state.
|
||||
device (torch.device): The destination device of loaded optimizer states. Defaults to torch.device('cpu').
|
||||
|
||||
Returns:
|
||||
OrderedDict: The complete optimizer state of given parameter.
|
||||
"""
|
||||
dp_size = dist.get_world_size(dp_group)
|
||||
tp_size = dist.get_world_size(tp_group)
|
||||
current_shape = param.shape
|
||||
state_ = state if inplace else copy.deepcopy(state)
|
||||
|
||||
for k, v in state_.items():
|
||||
if isinstance(v, torch.Tensor) and k != "step":
|
||||
# First gather Zero shards.
|
||||
if use_zero and not is_moe_param:
|
||||
v = v.cuda()
|
||||
gather_tensor = [torch.zeros_like(v) for _ in range(dp_size)]
|
||||
dist.all_gather(gather_tensor, v, group=dp_group)
|
||||
v = torch.stack(gather_tensor).view(-1)[: param.numel()].reshape_as(param)
|
||||
|
||||
# Then gather TP shards.
|
||||
partition_dim = search_tp_partition_dim(current_shape, original_shape, tp_size)
|
||||
if partition_dim is not None:
|
||||
gather_tensor = [torch.zeros_like(v) for _ in range(tp_size)]
|
||||
dist.all_gather(gather_tensor, v, group=tp_group)
|
||||
v = torch.cat(gather_tensor, dim=partition_dim)
|
||||
|
||||
state_[k] = v.detach().clone().to(device)
|
||||
|
||||
return state_
|
||||
|
||||
@staticmethod
|
||||
def _optimizer_sharder(
|
||||
optimizer: OptimizerWrapper,
|
||||
use_zero: bool,
|
||||
dp_group: ProcessGroup,
|
||||
tp_group: ProcessGroup,
|
||||
size_per_shard: int = 1024,
|
||||
only_moe_param: bool = False,
|
||||
):
|
||||
# An internel method that breaks state_dict of optimizer into shards within limited size.
|
||||
|
||||
state_dict_sharder = StateDictSharder(size_per_shard)
|
||||
param_info = optimizer.param_info
|
||||
master_to_working_map = optimizer.get_master_to_working_map()
|
||||
|
||||
for param, state in optimizer.optim.state.items():
|
||||
if param is None:
|
||||
continue
|
||||
|
||||
if master_to_working_map is not None:
|
||||
working_param = master_to_working_map[id(param)]
|
||||
else:
|
||||
working_param = param
|
||||
|
||||
param_id = param_info["param2id"][id(working_param)]
|
||||
original_shape = param_info["param2shape"][id(working_param)]
|
||||
state_ = MixtralMoEHybridParallelCheckpointIO.gather_from_sharded_optimizer_state(
|
||||
state,
|
||||
working_param,
|
||||
original_shape=original_shape,
|
||||
dp_group=dp_group,
|
||||
tp_group=tp_group,
|
||||
use_zero=use_zero,
|
||||
inplace=False,
|
||||
is_moe_param=is_moe_tensor(working_param),
|
||||
)
|
||||
|
||||
if only_moe_param and not is_moe_tensor(working_param):
|
||||
continue
|
||||
block, block_size = state_dict_sharder.append_optim_state(param_id, state_)
|
||||
if block is not None:
|
||||
yield block, block_size
|
||||
|
||||
# Return the last block in sharder.
|
||||
yield state_dict_sharder.current_block, state_dict_sharder.current_block_size
|
||||
|
||||
def save_sharded_optimizer(
|
||||
self,
|
||||
optimizer: OptimizerWrapper,
|
||||
checkpoint: str,
|
||||
gather_dtensor: bool = True,
|
||||
prefix: Optional[str] = None,
|
||||
size_per_shard: int = 1024,
|
||||
):
|
||||
"""
|
||||
Save sharded optimizer checkpoint under the given checkpointing path.
|
||||
The following files will be created under the path:
|
||||
- An index file (pytorch_optim.bin.index.json) containing a map between optimizer states and file names
|
||||
- A group file (pytorch_optim_group.bin) recording information of param_groups
|
||||
- Multiple files that store state tensors of optimizers.
|
||||
If pipeline parallelism is used, the filenames are in the form of "pytorch_optim.<prefix>-stage-000XX-shard-000XX.bin".
|
||||
If pipeline parallelism is not used, "pytorch_optim.<prefix>-000XX.bin"
|
||||
|
||||
Args:
|
||||
optimizer (OptimizerWrapper): Optimizer to save sharded state_dict
|
||||
checkpoint (str): Path to save optimizer state_dict
|
||||
gather_dtensor (bool): Whether to gather_dtensor, not used
|
||||
prefix (str): Perfix of file to save
|
||||
size_per_shard (int): Max file size of each file shard that store state tensors
|
||||
"""
|
||||
assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before saving!"
|
||||
if os.path.isfile(checkpoint):
|
||||
logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
|
||||
return
|
||||
|
||||
Path(checkpoint).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Devices along the same dp_group share the same copies of states when zero is not used.
|
||||
# In this case only let the device with dp_rank == 0 save the model.
|
||||
if not self.use_zero and self.real_dp_rank != 0:
|
||||
dist.barrier()
|
||||
return
|
||||
|
||||
# Then collect the sharded states along dp_group(if using zero)/tp_group.
|
||||
# Only devices with (dp_rank == 0 and tp_rank == 0) are responsible for states saving.
|
||||
state_dict_shard = MixtralMoEHybridParallelCheckpointIO._optimizer_sharder(
|
||||
optimizer,
|
||||
use_zero=self.use_zero,
|
||||
dp_group=self.dp_group,
|
||||
tp_group=self.tp_group,
|
||||
size_per_shard=size_per_shard,
|
||||
only_moe_param=self.ep_rank != 0,
|
||||
)
|
||||
states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix)
|
||||
index_file = CheckpointIndexFile(checkpoint)
|
||||
control_saving = self.real_dp_rank == 0 and self.tp_rank == 0
|
||||
|
||||
if self.pp_size == 1 and self.ep_size == 1:
|
||||
# When pipeline is not used, save the optimizer shards as in general checkpointIO
|
||||
total_size = save_state_dict_shards(
|
||||
sharded_state_dict=state_dict_shard,
|
||||
checkpoint=checkpoint,
|
||||
index_file=index_file,
|
||||
base_filename=states_name,
|
||||
is_master=control_saving,
|
||||
)
|
||||
|
||||
if control_saving:
|
||||
# Store param groups.
|
||||
index_file.append_meta_data("param_groups", param_group_file)
|
||||
group_file_path = os.path.join(checkpoint, param_group_file)
|
||||
param_groups = [
|
||||
{**group, "params": group_info["params"]}
|
||||
for group, group_info in zip(optimizer.param_groups, optimizer.param_info["param_groups"])
|
||||
]
|
||||
save_param_groups({"param_groups": param_groups}, group_file_path)
|
||||
# Store index file.
|
||||
index_file.append_meta_data("total_size", total_size)
|
||||
index_file.write_index_file(save_index_file)
|
||||
if self.verbose and self.coordinator.is_master():
|
||||
logging.info(
|
||||
f"The optimizer is going to be split to checkpoint shards. "
|
||||
f"You can find where each parameters has been saved in the "
|
||||
f"index located at {save_index_file}."
|
||||
)
|
||||
|
||||
dist.barrier()
|
||||
else:
|
||||
# When pipeline is used, each stage produces its own shard files and index files.
|
||||
# Index files belonging to each stage are saved under a temporary folder ./tmp_index_files/
|
||||
# After all the state_dicts have been saved, the master rank integrates all the index files into one final index file and deletes the tmp folder.
|
||||
|
||||
final_index_file_path = copy.deepcopy(save_index_file)
|
||||
tmp_index_file_folder = os.path.join(checkpoint, "tmp_index_files")
|
||||
Path(tmp_index_file_folder).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Manage filenames of sharded weights and index file for each pipeline stage.
|
||||
states_name = states_name.replace(".bin", f"-stage-{self.pp_rank+1:05d}-{self.ep_rank+1:05d}-shard.bin")
|
||||
save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank+1:05d}-{self.ep_rank+1:05d}.json")
|
||||
save_index_file = os.path.join("tmp_index_files", save_index_file)
|
||||
|
||||
total_size = save_state_dict_shards(
|
||||
sharded_state_dict=state_dict_shard,
|
||||
checkpoint=checkpoint,
|
||||
index_file=index_file,
|
||||
base_filename=states_name,
|
||||
is_master=control_saving,
|
||||
use_pp_format=True,
|
||||
)
|
||||
|
||||
if control_saving:
|
||||
index_file.append_meta_data("total_size", total_size)
|
||||
index_file.write_index_file(save_index_file)
|
||||
else:
|
||||
dist.barrier()
|
||||
return
|
||||
|
||||
dist.barrier()
|
||||
|
||||
# The global master rank integrates the index files and clean the folder.
|
||||
if self.coordinator.is_master():
|
||||
final_index_file = CheckpointIndexFile(checkpoint)
|
||||
final_index_file.append_meta_data("total_size", 0)
|
||||
|
||||
for filename in os.listdir(tmp_index_file_folder):
|
||||
stage_index_file = CheckpointIndexFile.from_file(os.path.join(tmp_index_file_folder, filename))
|
||||
final_index_file.metadata["total_size"] += stage_index_file.metadata["total_size"]
|
||||
for param_id, state_filename in stage_index_file.weight_map.items():
|
||||
final_index_file.append_weight_map(param_id, state_filename)
|
||||
|
||||
# Store param groups.
|
||||
final_index_file.append_meta_data("param_groups", param_group_file)
|
||||
group_file_path = os.path.join(checkpoint, param_group_file)
|
||||
param_groups = [
|
||||
{**group, "params": group_info["params"]}
|
||||
for group, group_info in zip(optimizer.param_groups, optimizer.param_info["param_groups"])
|
||||
]
|
||||
save_param_groups({"param_groups": param_groups}, group_file_path)
|
||||
|
||||
final_index_file.write_index_file(final_index_file_path)
|
||||
rmtree(tmp_index_file_folder)
|
||||
|
||||
if self.verbose and self.coordinator.is_master():
|
||||
logging.info(
|
||||
f"The model is split into checkpoint shards. "
|
||||
f"You can find where each parameters has been saved in the "
|
||||
f"index located at {final_index_file_path}."
|
||||
)
|
||||
|
||||
def load_sharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint_index_file: str, prefix: str = ""):
|
||||
"""
|
||||
Load sharded optimizer with the given path to index file of checkpoint folder.
|
||||
|
||||
Args:
|
||||
optimizer (OptimizerWrapper): The optimizer to be loaded.
|
||||
checkpoint_index_file (str): Path to the index file of checkpointing folder.
|
||||
prefix (str): Not used.
|
||||
"""
|
||||
assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!"
|
||||
|
||||
def _get_param_id_from_optimizer_param(
|
||||
param: torch.Tensor, master_to_working_map: Optional[Dict[int, torch.Tensor]] = None
|
||||
):
|
||||
if master_to_working_map is not None:
|
||||
working_param = master_to_working_map[id(param)]
|
||||
else:
|
||||
working_param = param
|
||||
return optimizer.param_info["param2id"][id(working_param)]
|
||||
|
||||
# id_map is a mapping from param ids kept by current pipeline, to their corresponding parameter objects.
|
||||
# When Zero is used, the mapped parameter objects should be fp32 master parameters.
|
||||
# IDs should be obtained through saved param2id mapping earlier saved in optimizer.param_info.
|
||||
id_map = {}
|
||||
master_to_working_map = optimizer.get_master_to_working_map()
|
||||
for pg in optimizer.optim.param_groups:
|
||||
for param in pg["params"]:
|
||||
param_id = _get_param_id_from_optimizer_param(param, master_to_working_map)
|
||||
id_map[param_id] = param
|
||||
|
||||
# Read checkpoint index file.
|
||||
ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file)
|
||||
ckpt_root_path = ckpt_index_file.root_path
|
||||
weight_map = ckpt_index_file.weight_map
|
||||
weight_map = {int(k): v for k, v in weight_map.items()} # convert saved id from str to int
|
||||
|
||||
# Load param_groups
|
||||
param_group_path = ckpt_index_file.get_param_group_filename()
|
||||
if param_group_path is None:
|
||||
raise RuntimeError(
|
||||
f"Invalid index file path {checkpoint_index_file} for an optimizer. \
|
||||
Lacking param group file under current directory."
|
||||
)
|
||||
saved_groups = torch.load(param_group_path)
|
||||
|
||||
updated_groups = []
|
||||
for old_pg, saved_pg in zip(optimizer.optim.param_groups, saved_groups):
|
||||
# obtain updated param group
|
||||
new_pg = copy.deepcopy(saved_pg)
|
||||
new_pg["params"] = old_pg["params"] # The parameters in the same group shouln't change.
|
||||
updated_groups.append(new_pg)
|
||||
# ep param groups
|
||||
if len(optimizer.optim.param_groups) == len(saved_groups) + 1:
|
||||
new_pg = copy.deepcopy(saved_pg)
|
||||
new_pg["params"] = optimizer.optim.param_groups[-1]["params"]
|
||||
updated_groups.append(new_pg)
|
||||
optimizer.optim.__dict__.update({"param_groups": updated_groups})
|
||||
|
||||
# Load saved states to optimizer.
|
||||
# Keep a record of loaded files so that file will not be repeatedly loaded.
|
||||
loaded_file = set()
|
||||
for pg in optimizer.optim.param_groups:
|
||||
for param in pg["params"]:
|
||||
if param is None:
|
||||
continue
|
||||
param_id = _get_param_id_from_optimizer_param(param, master_to_working_map)
|
||||
if param_id not in weight_map:
|
||||
continue
|
||||
filename = weight_map[param_id]
|
||||
|
||||
# If this param's states has been loaded before, directly return.
|
||||
if filename in loaded_file:
|
||||
continue
|
||||
|
||||
file_path = os.path.join(ckpt_root_path, filename)
|
||||
state_dict = load_shard_state_dict(Path(file_path), use_safetensors=False)
|
||||
load_states_into_optimizer(optimizer.optim, state_dict, id_map, strict=True)
|
||||
loaded_file.add(filename)
|
||||
|
||||
# Then shard the loaded optimizer states if using tp/zero.
|
||||
for param, state in optimizer.optim.state.items():
|
||||
device = param.device
|
||||
if master_to_working_map is not None:
|
||||
working_param = master_to_working_map[id(param)]
|
||||
else:
|
||||
working_param = param
|
||||
original_shape = optimizer.param_info["param2shape"][id(working_param)]
|
||||
sharded_state = self.shard_from_complete_optimizer_state(
|
||||
state,
|
||||
current_shape=working_param.shape,
|
||||
original_shape=original_shape,
|
||||
device=device,
|
||||
inplace=True,
|
||||
is_moe_param=is_moe_tensor(working_param),
|
||||
)
|
||||
optimizer.optim.state[param] = sharded_state
|
||||
|
||||
sharded_optimizer_loading_epilogue(optimizer.optim)
|
||||
if self.verbose and self.coordinator.is_master():
|
||||
logging.info(f"The optimizer has been successfully loaded from sharded checkpoint: {ckpt_root_path}.")
|
||||
|
||||
def shard_from_complete_optimizer_state(
|
||||
self,
|
||||
state: OrderedDict,
|
||||
current_shape: torch.Size,
|
||||
original_shape: torch.Size,
|
||||
device: torch.device,
|
||||
inplace: bool,
|
||||
is_moe_param: bool,
|
||||
) -> OrderedDict:
|
||||
"""
|
||||
With complete optimizer states of a specific parameter loaded from checkpoint,
|
||||
slice out the sharded optimizer states kept by current device.
|
||||
|
||||
Args:
|
||||
state (OrderedDict): Complete optimizer states of a given parameter, loaded from checkpoint.
|
||||
current_shape (torch.Size): The size of parameter after sharding.
|
||||
original_shape (torch.Size): The size of parameter before sharding.
|
||||
device (torch.device): The destination device of loaded optimizer states.
|
||||
inplace (bool): If set to True, will update the values of argument 'state' in place. Else will make a copy of state.
|
||||
|
||||
Returns:
|
||||
OrderedDict: The sharded optimizer state of the given parameter.
|
||||
"""
|
||||
state_ = state if inplace else copy.deepcopy(state)
|
||||
|
||||
for k, v in state_.items():
|
||||
if isinstance(v, torch.Tensor) and k != "step":
|
||||
# Shard state along tensor parallel group.
|
||||
partition_dim = search_tp_partition_dim(current_shape, original_shape, self.tp_size)
|
||||
if partition_dim is not None:
|
||||
slice_size = current_shape[partition_dim]
|
||||
v = v.split(slice_size, dim=partition_dim)[self.tp_rank]
|
||||
|
||||
# Shard state along data parallel group when using Zero.
|
||||
if self.use_zero and not is_moe_param:
|
||||
padding_size = (self.dp_size - v.numel() % self.dp_size) % self.dp_size
|
||||
with torch.no_grad():
|
||||
v = v.flatten()
|
||||
if padding_size > 0:
|
||||
v = torch.nn.functional.pad(v, [0, padding_size])
|
||||
slice_size = v.numel() // self.dp_size
|
||||
v = v.split(slice_size, dim=0)[self.dp_rank]
|
||||
|
||||
state_[k] = v.detach().clone().to(device)
|
||||
|
||||
return state_
|
||||
|
||||
def save_unsharded_model(self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
|
||||
raise NotImplementedError
|
||||
|
||||
def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool):
|
||||
raise NotImplementedError
|
||||
|
||||
def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, strict: bool = False):
|
||||
raise NotImplementedError
|
@@ -0,0 +1,92 @@
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn.functional as F
|
||||
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
|
||||
|
||||
from colossalai.lazy import LazyInitContext
|
||||
from colossalai.moe import MOE_MANAGER
|
||||
from colossalai.moe._operation import MoeInGradScaler, MoeOutGradScaler, all_to_all_uneven
|
||||
from colossalai.shardformer.shard.utils import set_tensors_to_none
|
||||
from colossalai.tensor.moe_tensor.api import set_moe_tensor_info
|
||||
|
||||
|
||||
class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.setup_ep()
|
||||
|
||||
def setup_ep(self):
|
||||
_, moe_info = MOE_MANAGER.get_info(self.num_experts)
|
||||
ep_group = moe_info.ep_group
|
||||
self.ep_size = dist.get_world_size(ep_group) if ep_group is not None else 1
|
||||
self.ep_rank = dist.get_rank(ep_group) if ep_group is not None else 0
|
||||
assert self.num_experts % self.ep_size == 0
|
||||
self.ep_group = ep_group
|
||||
self.num_experts_per_ep = self.num_experts // self.ep_size
|
||||
self.expert_start_idx = self.ep_rank * self.num_experts_per_ep
|
||||
held_experts = self.experts[self.expert_start_idx : self.expert_start_idx + self.num_experts_per_ep]
|
||||
set_tensors_to_none(self.experts, exclude=set(held_experts))
|
||||
for p in self.experts.parameters():
|
||||
set_moe_tensor_info(p, moe_info)
|
||||
|
||||
@staticmethod
|
||||
def from_native_module(module: MixtralSparseMoeBlock, *args, **kwargs) -> "EPMixtralSparseMoeBlock":
|
||||
LazyInitContext.materialize(module)
|
||||
module.__class__ = EPMixtralSparseMoeBlock
|
||||
module.setup_ep()
|
||||
return module
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
batch_size, sequence_length, hidden_dim = hidden_states.shape
|
||||
hidden_states = hidden_states.view(-1, hidden_dim)
|
||||
# router_logits: (batch * sequence_length, n_experts)
|
||||
router_logits = self.gate(hidden_states)
|
||||
|
||||
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
|
||||
routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
|
||||
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
|
||||
# we cast back to the input dtype
|
||||
routing_weights = routing_weights.to(hidden_states.dtype)
|
||||
|
||||
selected_experts = selected_experts.t().reshape(-1)
|
||||
selected_experts_idx = selected_experts.argsort()
|
||||
dispatch_states = hidden_states.repeat(self.top_k, 1)[selected_experts_idx]
|
||||
input_split_sizes = selected_experts.bincount(minlength=self.num_experts)
|
||||
output_split_sizes = torch.zeros_like(input_split_sizes)
|
||||
dist.all_to_all_single(output_split_sizes, input_split_sizes, group=self.ep_group)
|
||||
|
||||
input_split_list = input_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
|
||||
output_split_list = output_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
|
||||
output_states, _ = all_to_all_uneven(dispatch_states, input_split_list, output_split_list, self.ep_group)
|
||||
# compute expert output
|
||||
output_states = MoeInGradScaler.apply(output_states, self.ep_size)
|
||||
if output_states.size(0) > 0:
|
||||
if self.num_experts_per_ep == 1:
|
||||
# no need to split
|
||||
expert = self.experts[self.expert_start_idx]
|
||||
output_states = expert.act_fn(expert.w1(output_states)) * expert.w3(output_states)
|
||||
output_states = expert.w2(output_states)
|
||||
else:
|
||||
output_states_splits = output_states.split(output_split_sizes.tolist())
|
||||
output_states_list = []
|
||||
for i, split_states in enumerate(output_states_splits):
|
||||
if split_states.size(0) == 0:
|
||||
continue
|
||||
expert = self.experts[self.expert_start_idx + i % self.num_experts_per_ep]
|
||||
split_states = expert.act_fn(expert.w1(split_states)) * expert.w3(split_states)
|
||||
split_states = expert.w2(split_states)
|
||||
output_states_list.append(split_states)
|
||||
output_states = torch.cat(output_states_list)
|
||||
output_states = MoeOutGradScaler.apply(output_states, self.ep_size)
|
||||
dispatch_states, _ = all_to_all_uneven(output_states, output_split_list, input_split_list, self.ep_group)
|
||||
recover_experts_idx = torch.empty_like(selected_experts_idx)
|
||||
recover_experts_idx[selected_experts_idx] = torch.arange(
|
||||
selected_experts_idx.size(0), device=selected_experts_idx.device
|
||||
)
|
||||
dispatch_states = dispatch_states[recover_experts_idx]
|
||||
k_hidden_states = dispatch_states.chunk(self.top_k)
|
||||
output_states = k_hidden_states[0] * routing_weights[:, 0, None]
|
||||
for i in range(1, self.top_k):
|
||||
output_states += k_hidden_states[i] * routing_weights[:, i, None]
|
||||
output_states = output_states.reshape(batch_size, sequence_length, hidden_dim)
|
||||
return output_states, router_logits
|
557
applications/ColossalMoE/colossal_moe/models/mixtral_policy.py
Normal file
557
applications/ColossalMoE/colossal_moe/models/mixtral_policy.py
Normal file
@@ -0,0 +1,557 @@
|
||||
from functools import partial
|
||||
from typing import Callable, Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch import Tensor
|
||||
from torch.nn import CrossEntropyLoss, Module
|
||||
from transformers.models.mixtral.modeling_mixtral import (
|
||||
MixtralDecoderLayer,
|
||||
MixtralForCausalLM,
|
||||
MixtralModel,
|
||||
MoeCausalLMOutputWithPast,
|
||||
_prepare_4d_causal_attention_mask,
|
||||
load_balancing_loss_func,
|
||||
)
|
||||
from transformers.utils import logging
|
||||
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col
|
||||
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
|
||||
from colossalai.shardformer.shard import ShardConfig
|
||||
|
||||
from .mixtral_layer import EPMixtralSparseMoeBlock
|
||||
|
||||
__all__ = ["MixtralPolicy", "MixtralForCausalLMPolicy"]
|
||||
|
||||
|
||||
class MixtralPolicy(Policy):
|
||||
def config_sanity_check(self):
|
||||
pass
|
||||
|
||||
def preprocess(self):
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
# Resize embedding
|
||||
vocab_size = self.model.config.vocab_size
|
||||
world_size = self.shard_config.tensor_parallel_size
|
||||
|
||||
if vocab_size % world_size != 0:
|
||||
new_vocab_size = vocab_size + world_size - vocab_size % world_size
|
||||
self.model.resize_token_embeddings(new_vocab_size)
|
||||
|
||||
return self.model
|
||||
|
||||
def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
|
||||
policy = {}
|
||||
|
||||
if self.shard_config.enable_sequence_parallelism:
|
||||
self.shard_config.enable_sequence_parallelism = False
|
||||
raise NotImplementedError(
|
||||
"Mixtral dosen't support sequence parallelism now, will ignore the sequence parallelism flag."
|
||||
)
|
||||
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
raise NotImplementedError("Tensor parallelism is not supported for Mixtral model now.")
|
||||
|
||||
# expert parallel
|
||||
self.append_or_create_submodule_replacement(
|
||||
description=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="block_sparse_moe",
|
||||
target_module=EPMixtralSparseMoeBlock,
|
||||
)
|
||||
],
|
||||
policy=policy,
|
||||
target_key=MixtralDecoderLayer,
|
||||
)
|
||||
|
||||
# optimization configuration
|
||||
if self.shard_config.enable_fused_normalization:
|
||||
self.append_or_create_submodule_replacement(
|
||||
description=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="input_layernorm",
|
||||
target_module=FusedRMSNorm,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="post_attention_layernorm",
|
||||
target_module=FusedRMSNorm,
|
||||
),
|
||||
],
|
||||
policy=policy,
|
||||
target_key=MixtralDecoderLayer,
|
||||
)
|
||||
|
||||
self.append_or_create_submodule_replacement(
|
||||
description=SubModuleReplacementDescription(
|
||||
suffix="norm",
|
||||
target_module=FusedRMSNorm,
|
||||
),
|
||||
policy=policy,
|
||||
target_key=MixtralModel,
|
||||
)
|
||||
|
||||
if self.shard_config.enable_flash_attention:
|
||||
raise NotImplementedError("Flash attention has already been replaced in mixtral.")
|
||||
|
||||
return policy
|
||||
|
||||
def postprocess(self):
|
||||
return self.model
|
||||
|
||||
def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None:
|
||||
"""If under pipeline parallel setting, replacing the original forward method of huggingface
|
||||
to customized forward method, and add this changing to policy."""
|
||||
if self.pipeline_stage_manager:
|
||||
stage_manager = self.pipeline_stage_manager
|
||||
if self.model.__class__.__name__ == "MixtralModel":
|
||||
module = self.model
|
||||
else:
|
||||
module = self.model.model
|
||||
|
||||
layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages)
|
||||
stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
|
||||
method_replacement = {"forward": partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)}
|
||||
self.append_or_create_method_replacement(
|
||||
description=method_replacement, policy=policy, target_key=model_cls
|
||||
)
|
||||
|
||||
return
|
||||
|
||||
def get_held_layers(self) -> List[Module]:
|
||||
"""Get pipeline layers for current stage."""
|
||||
assert self.pipeline_stage_manager is not None
|
||||
|
||||
if self.model.__class__.__name__ == "MixtralModel":
|
||||
module = self.model
|
||||
else:
|
||||
module = self.model.model
|
||||
stage_manager = self.pipeline_stage_manager
|
||||
|
||||
held_layers = []
|
||||
layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages)
|
||||
if stage_manager.is_first_stage():
|
||||
held_layers.append(module.embed_tokens)
|
||||
start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
|
||||
held_layers.extend(module.layers[start_idx:end_idx])
|
||||
if stage_manager.is_last_stage():
|
||||
held_layers.append(module.norm)
|
||||
|
||||
return held_layers
|
||||
|
||||
|
||||
class MixtralModelPolicy(MixtralPolicy):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def module_policy(self):
|
||||
policy = super().module_policy()
|
||||
if self.pipeline_stage_manager:
|
||||
# set None as default
|
||||
self.set_pipeline_forward(
|
||||
model_cls=MixtralModel,
|
||||
new_forward=MixtralPipelineForwards.mixtral_model_forward,
|
||||
policy=policy,
|
||||
)
|
||||
return policy
|
||||
|
||||
def get_held_layers(self) -> List[Module]:
|
||||
"""Get pipeline layers for current stage."""
|
||||
held_layers = super().get_held_layers()
|
||||
return held_layers
|
||||
|
||||
def get_shared_params(self) -> List[Dict[int, Tensor]]:
|
||||
"""No shared params in llama model"""
|
||||
return []
|
||||
|
||||
|
||||
class MixtralForCausalLMPolicy(MixtralPolicy):
|
||||
def module_policy(self):
|
||||
policy = super().module_policy()
|
||||
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
# add a new item for casual lm
|
||||
new_item = {
|
||||
MixtralForCausalLM: ModulePolicyDescription(
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="lm_head",
|
||||
target_module=Linear1D_Col,
|
||||
kwargs=dict(gather_output=True),
|
||||
)
|
||||
]
|
||||
)
|
||||
}
|
||||
policy.update(new_item)
|
||||
|
||||
if self.pipeline_stage_manager:
|
||||
# set None as default
|
||||
self.set_pipeline_forward(
|
||||
model_cls=MixtralForCausalLM,
|
||||
new_forward=MixtralPipelineForwards.mixtral_for_causal_lm_forward,
|
||||
policy=policy,
|
||||
)
|
||||
|
||||
return policy
|
||||
|
||||
def get_held_layers(self) -> List[Module]:
|
||||
"""Get pipeline layers for current stage."""
|
||||
stage_manager = self.pipeline_stage_manager
|
||||
held_layers = super().get_held_layers()
|
||||
if stage_manager.is_last_stage():
|
||||
held_layers.append(self.model.lm_head)
|
||||
return held_layers
|
||||
|
||||
def get_shared_params(self) -> List[Dict[int, Tensor]]:
|
||||
llama_model = self.model.model
|
||||
if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1:
|
||||
if (
|
||||
id(llama_model.embed_tokens.weight) == id(self.model.lm_head.weight)
|
||||
and self.pipeline_stage_manager.num_stages > 1
|
||||
):
|
||||
# tie weights
|
||||
return [
|
||||
{
|
||||
0: llama_model.embed_tokens.weight,
|
||||
self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight,
|
||||
}
|
||||
]
|
||||
return []
|
||||
|
||||
|
||||
class MixtralPipelineForwards:
|
||||
"""
|
||||
This class serves as a micro library for forward function substitution of Llama models
|
||||
under pipeline setting.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def mixtral_model_forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
output_router_logits: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
stage_manager: Optional[PipelineStageManager] = None,
|
||||
hidden_states: Optional[torch.FloatTensor] = None,
|
||||
past_router_logits: Optional[torch.FloatTensor] = None,
|
||||
stage_index: Optional[List[int]] = None,
|
||||
shard_config: ShardConfig = None,
|
||||
):
|
||||
r"""
|
||||
Args:
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||
|
||||
Returns:
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoTokenizer, MixtralForCausalLM
|
||||
|
||||
>>> model = MixtralForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
|
||||
|
||||
>>> prompt = "Hey, are you conscious? Can you talk to me?"
|
||||
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
||||
|
||||
>>> # Generate
|
||||
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
||||
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
|
||||
```"""
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_router_logits = (
|
||||
output_router_logits if output_router_logits is not None else self.config.output_router_logits
|
||||
)
|
||||
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
# retrieve input_ids and inputs_embeds
|
||||
if stage_manager.is_first_stage():
|
||||
# retrieve input_ids and inputs_embeds
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
batch_size, seq_length = input_ids.shape
|
||||
elif inputs_embeds is not None:
|
||||
batch_size, seq_length, _ = inputs_embeds.shape
|
||||
else:
|
||||
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
|
||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
hidden_states = inputs_embeds
|
||||
else:
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
batch_size, seq_length = input_shape
|
||||
device = hidden_states.device
|
||||
|
||||
seq_length_with_past = seq_length
|
||||
past_key_values_length = 0
|
||||
|
||||
# TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future.
|
||||
if output_attentions:
|
||||
logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
|
||||
output_attentions = False
|
||||
if output_hidden_states:
|
||||
logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
|
||||
output_hidden_states = False
|
||||
if use_cache:
|
||||
logger.warning_once("use_cache=True is not supported for pipeline models at the moment.")
|
||||
use_cache = False
|
||||
|
||||
if past_key_values is not None:
|
||||
past_key_values_length = past_key_values[0][0].shape[2]
|
||||
seq_length_with_past = seq_length_with_past + past_key_values_length
|
||||
|
||||
if position_ids is None:
|
||||
position_ids = torch.arange(
|
||||
past_key_values_length,
|
||||
seq_length + past_key_values_length,
|
||||
dtype=torch.long,
|
||||
device=device,
|
||||
)
|
||||
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
|
||||
else:
|
||||
position_ids = position_ids.view(-1, seq_length).long()
|
||||
|
||||
# embed positions, for the first stage, hidden_states is the input embeddings,
|
||||
# for the other stages, hidden_states is the output of the previous stage
|
||||
if self._use_flash_attention_2:
|
||||
# 2d mask is passed through the layers
|
||||
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
|
||||
else:
|
||||
# 4d mask is passed through the layers
|
||||
attention_mask = _prepare_4d_causal_attention_mask(
|
||||
attention_mask,
|
||||
(batch_size, seq_length),
|
||||
hidden_states,
|
||||
past_key_values_length,
|
||||
sliding_window=self.config.sliding_window,
|
||||
)
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
if use_cache:
|
||||
logger.warning_once(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
# decoder layers
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attns = () if output_attentions else None
|
||||
all_router_logits = () if output_router_logits else None
|
||||
next_decoder_cache = None
|
||||
|
||||
start_idx, end_idx = stage_index[0], stage_index[1]
|
||||
for idx, decoder_layer in enumerate(self.layers[start_idx:end_idx], start=start_idx):
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
# None for past_key_value
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
layer_outputs = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(decoder_layer),
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
None,
|
||||
output_attentions,
|
||||
output_router_logits,
|
||||
)
|
||||
else:
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
past_key_value,
|
||||
output_attentions,
|
||||
output_router_logits,
|
||||
use_cache,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if use_cache:
|
||||
next_decoder_cache = (layer_outputs[2 if output_attentions else 1],)
|
||||
if output_attentions:
|
||||
all_self_attns += (layer_outputs[1],)
|
||||
if output_router_logits:
|
||||
all_router_logits += (layer_outputs[-1],)
|
||||
|
||||
if stage_manager.is_last_stage():
|
||||
hidden_states = self.norm(hidden_states)
|
||||
|
||||
# add hidden states from the last decoder layer
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
next_cache = next_decoder_cache if use_cache else None
|
||||
|
||||
if output_router_logits and past_router_logits is not None:
|
||||
all_router_logits = past_router_logits + all_router_logits
|
||||
if stage_manager.is_last_stage():
|
||||
return tuple(
|
||||
v
|
||||
for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits]
|
||||
if v is not None
|
||||
)
|
||||
# always return dict for imediate stage
|
||||
return {
|
||||
"hidden_states": hidden_states,
|
||||
"past_router_logits": all_router_logits,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def mixtral_for_causal_lm_forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
output_router_logits: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
stage_manager: Optional[PipelineStageManager] = None,
|
||||
hidden_states: Optional[torch.FloatTensor] = None,
|
||||
past_router_logits: Optional[torch.FloatTensor] = None,
|
||||
stage_index: Optional[List[int]] = None,
|
||||
shard_config: ShardConfig = None,
|
||||
):
|
||||
r"""
|
||||
Args:
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||
|
||||
Returns:
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoTokenizer, MixtralForCausalLM
|
||||
|
||||
>>> model = MixtralForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
|
||||
|
||||
>>> prompt = "Hey, are you conscious? Can you talk to me?"
|
||||
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
||||
|
||||
>>> # Generate
|
||||
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
||||
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
|
||||
```"""
|
||||
logger = logging.get_logger(__name__)
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_router_logits = (
|
||||
output_router_logits if output_router_logits is not None else self.config.output_router_logits
|
||||
)
|
||||
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
# TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future.
|
||||
if output_attentions:
|
||||
logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
|
||||
output_attentions = False
|
||||
if output_hidden_states:
|
||||
logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
|
||||
output_hidden_states = False
|
||||
|
||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||
outputs = MixtralPipelineForwards.mixtral_model_forward(
|
||||
self.model,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
output_router_logits=output_router_logits,
|
||||
return_dict=return_dict,
|
||||
stage_manager=stage_manager,
|
||||
hidden_states=hidden_states,
|
||||
stage_index=stage_index,
|
||||
past_router_logits=past_router_logits,
|
||||
)
|
||||
past_key_values = None
|
||||
|
||||
if stage_manager.is_last_stage():
|
||||
hidden_states = outputs[0]
|
||||
logits = self.lm_head(hidden_states)
|
||||
logits = logits.float()
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
# Shift so that tokens < n predict n
|
||||
shift_logits = logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
# Flatten the tokens
|
||||
loss_fct = CrossEntropyLoss()
|
||||
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
||||
shift_labels = shift_labels.view(-1)
|
||||
# Enable model parallelism
|
||||
shift_labels = shift_labels.to(shift_logits.device)
|
||||
loss = loss_fct(shift_logits, shift_labels)
|
||||
|
||||
aux_loss = None
|
||||
if output_router_logits:
|
||||
aux_loss = load_balancing_loss_func(outputs[-1], self.num_experts, self.num_experts_per_tok)
|
||||
if labels is not None:
|
||||
loss += self.router_aux_loss_coef * aux_loss
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
if output_router_logits:
|
||||
output = (aux_loss,) + output
|
||||
return (loss,) + output if loss is not None else output
|
||||
|
||||
return MoeCausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
aux_loss=aux_loss,
|
||||
logits=logits,
|
||||
past_key_values=None,
|
||||
hidden_states=outputs[0],
|
||||
attentions=None,
|
||||
router_logits=outputs[-1],
|
||||
)
|
||||
else:
|
||||
out = {}
|
||||
hidden_states = outputs.get("hidden_states")
|
||||
out["hidden_states"] = hidden_states
|
||||
if output_router_logits:
|
||||
out["past_router_logits"] = outputs["past_router_logits"]
|
||||
return out
|
84
applications/ColossalMoE/colossal_moe/utils.py
Normal file
84
applications/ColossalMoE/colossal_moe/utils.py
Normal file
@@ -0,0 +1,84 @@
|
||||
import json
|
||||
import os
|
||||
from typing import Any, Dict, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch.optim.lr_scheduler import _LRScheduler
|
||||
from torch.optim.optimizer import Optimizer
|
||||
|
||||
from colossalai.booster import Booster
|
||||
from colossalai.cluster import DistCoordinator
|
||||
|
||||
|
||||
def move_to_cuda(batch, device):
|
||||
return {k: v.to(device) for k, v in batch.items()}
|
||||
|
||||
|
||||
def load_json(file_path: Union[str, os.PathLike]) -> Dict[str, Any]:
|
||||
"""
|
||||
Load file in JSON format
|
||||
"""
|
||||
with open(file=file_path, mode="r", encoding="utf-8") as fp:
|
||||
return json.load(fp)
|
||||
|
||||
|
||||
def save_json(data: Dict[str, Any], file_path: Union[str, os.PathLike]) -> None:
|
||||
"""
|
||||
Save as JSON format
|
||||
"""
|
||||
with open(file=file_path, mode="w", encoding="utf-8") as fp:
|
||||
json.dump(data, fp=fp, ensure_ascii=False, indent=4)
|
||||
|
||||
|
||||
def save_checkpoint(
|
||||
save_dir: Union[str, os.PathLike],
|
||||
booster: Booster,
|
||||
model: torch.nn.Module,
|
||||
optimizer: Optimizer,
|
||||
lr_scheduler: _LRScheduler,
|
||||
epoch: int,
|
||||
step: int,
|
||||
batch_size: int,
|
||||
coordinator: DistCoordinator,
|
||||
) -> None:
|
||||
"""
|
||||
Save model checkpoint, optimizer, LR scheduler and intermedidate running states.
|
||||
"""
|
||||
|
||||
save_dir = os.path.join(save_dir, f"epoch-{epoch}_step-{step}")
|
||||
os.makedirs(os.path.join(save_dir, "modeling"), exist_ok=True)
|
||||
|
||||
booster.save_model(model, os.path.join(save_dir, "modeling"), shard=True)
|
||||
booster.save_optimizer(optimizer, os.path.join(save_dir, "optimizer"), shard=True)
|
||||
booster.save_lr_scheduler(lr_scheduler, os.path.join(save_dir, "lr_scheduler"))
|
||||
running_states = {
|
||||
"epoch": epoch,
|
||||
"step": step,
|
||||
"sample_start_index": step * batch_size,
|
||||
}
|
||||
if coordinator.is_master():
|
||||
save_json(running_states, os.path.join(save_dir, "running_states.json"))
|
||||
|
||||
|
||||
def load_checkpoint(
|
||||
load_dir: Union[str, os.PathLike],
|
||||
booster: Booster,
|
||||
model: torch.nn.Module,
|
||||
optimizer: Optimizer,
|
||||
lr_scheduler: _LRScheduler,
|
||||
) -> Tuple[int, int, int]:
|
||||
"""
|
||||
Load model checkpoint, optimizer, LR scheduler and intermedidate running states.
|
||||
"""
|
||||
|
||||
# Update booster params states.
|
||||
booster.load_model(model, os.path.join(load_dir, "modeling"))
|
||||
booster.load_optimizer(optimizer=optimizer, checkpoint=os.path.join(load_dir, "optimizer"))
|
||||
booster.load_lr_scheduler(lr_scheduler=lr_scheduler, checkpoint=os.path.join(load_dir, "lr_scheduler"))
|
||||
|
||||
running_states = load_json(file_path=os.path.join(load_dir, "running_states.json"))
|
||||
return (
|
||||
running_states["epoch"],
|
||||
running_states["step"],
|
||||
running_states["sample_start_index"],
|
||||
)
|
111
applications/ColossalMoE/infer.py
Normal file
111
applications/ColossalMoE/infer.py
Normal file
@@ -0,0 +1,111 @@
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from colossal_moe.models.mixtral_checkpoint import MixtralMoEHybridParallelCheckpointIO
|
||||
from colossal_moe.models.mixtral_policy import MixtralForCausalLMPolicy
|
||||
from transformers import AutoTokenizer
|
||||
from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM
|
||||
|
||||
import colossalai
|
||||
from colossalai.booster import Booster
|
||||
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
|
||||
from colossalai.cluster import DistCoordinator
|
||||
|
||||
|
||||
def parse_args():
|
||||
# basic settings
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--model_name",
|
||||
type=str,
|
||||
default="mistralai/Mixtral-8x7B-v0.1",
|
||||
help="Path to pretrained model or model identifier from huggingface.co/models.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--plugin",
|
||||
type=str,
|
||||
default="ep",
|
||||
choices=["ep"],
|
||||
help="Parallel methos.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--precision",
|
||||
type=str,
|
||||
default="bf16",
|
||||
choices=["fp32", "bf16", "fp16"],
|
||||
help="The mixed precision training.",
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.")
|
||||
|
||||
# kernel
|
||||
parser.add_argument(
|
||||
"--use_kernel",
|
||||
action="store_true",
|
||||
help="Use kernel optim. Need to install flash attention and triton to enable all kernel optimizations. Skip if not installed.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_layernorm_kernel",
|
||||
action="store_true",
|
||||
help="Use layernorm kernel. Need to install apex. Raise error if not installed.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
|
||||
# Launch ColossalAI
|
||||
colossalai.launch_from_torch(config={}, seed=args.seed)
|
||||
coordinator = DistCoordinator()
|
||||
|
||||
config = MixtralConfig.from_pretrained(args.model_name)
|
||||
ep_size = min(dist.get_world_size(), config.num_local_experts)
|
||||
# Set plugin
|
||||
if args.plugin == "ep":
|
||||
plugin = MoeHybridParallelPlugin(
|
||||
tp_size=1,
|
||||
pp_size=1,
|
||||
ep_size=ep_size,
|
||||
zero_stage=1,
|
||||
precision=args.precision,
|
||||
custom_policy=MixtralForCausalLMPolicy(),
|
||||
checkpoint_io=MixtralMoEHybridParallelCheckpointIO,
|
||||
enable_fused_normalization=args.use_layernorm_kernel,
|
||||
enable_jit_fused=args.use_kernel,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Invalid plugin {args.plugin}")
|
||||
coordinator.print_on_master(f"Set plugin as {plugin.__class__.__name__}")
|
||||
|
||||
# Build mixtral model
|
||||
model = MixtralForCausalLM.from_pretrained(args.model_name)
|
||||
coordinator.print_on_master(f"Finish load model")
|
||||
|
||||
# Prepare tokenizer and dataloader
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.model_name)
|
||||
|
||||
# Set booster
|
||||
booster = Booster(plugin=plugin)
|
||||
model, _, _, _, _ = booster.boost(model=model)
|
||||
coordinator.print_on_master(f"Finish init booster")
|
||||
|
||||
model.eval()
|
||||
|
||||
if coordinator.rank == 0:
|
||||
text = ["Hello my name is"]
|
||||
else:
|
||||
text = ["What's the largest country in the world?", "How many people live in China?", "帮我续写这首诗:离离原上草"]
|
||||
tokenizer.pad_token = tokenizer.unk_token
|
||||
inputs = tokenizer(text, return_tensors="pt", padding=True).to(torch.cuda.current_device())
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = model.module.generate(**inputs, max_new_tokens=20)
|
||||
outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
||||
print(f"[{coordinator.rank}] {outputs}")
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
7
applications/ColossalMoE/infer.sh
Normal file
7
applications/ColossalMoE/infer.sh
Normal file
@@ -0,0 +1,7 @@
|
||||
NUM_GPU=2
|
||||
MODEL="mistralai/Mixtral-8x7B-v0.1"
|
||||
|
||||
# ep
|
||||
torchrun --standalone --nproc_per_node $NUM_GPU infer.py \
|
||||
--model_name $MODEL \
|
||||
--plugin "ep" \
|
5
applications/ColossalMoE/requirements.txt
Normal file
5
applications/ColossalMoE/requirements.txt
Normal file
@@ -0,0 +1,5 @@
|
||||
colossalai >= 0.3.3
|
||||
torch >= 1.8.1
|
||||
transformers == 4.36.0
|
||||
sentencepiece
|
||||
datasets
|
43
applications/ColossalMoE/setup.py
Normal file
43
applications/ColossalMoE/setup.py
Normal file
@@ -0,0 +1,43 @@
|
||||
from setuptools import find_packages, setup
|
||||
|
||||
|
||||
def fetch_requirements(path):
|
||||
with open(path, "r") as fd:
|
||||
return [r.strip() for r in fd.readlines()]
|
||||
|
||||
|
||||
def fetch_readme():
|
||||
with open("README.md", encoding="utf-8") as f:
|
||||
return f.read()
|
||||
|
||||
|
||||
def fetch_version():
|
||||
with open("version.txt", "r") as f:
|
||||
return f.read().strip()
|
||||
|
||||
|
||||
setup(
|
||||
name="colossal_moe",
|
||||
version=fetch_version(),
|
||||
packages=find_packages(
|
||||
exclude=(
|
||||
"tests",
|
||||
"benchmarks",
|
||||
"*.egg-info",
|
||||
)
|
||||
),
|
||||
description="Colossal-AI MoE",
|
||||
long_description=fetch_readme(),
|
||||
long_description_content_type="text/markdown",
|
||||
license="Apache Software License 2.0",
|
||||
url="https://github.com/hpcaitech",
|
||||
install_requires=fetch_requirements("requirements.txt"),
|
||||
python_requires=">=3.6",
|
||||
classifiers=[
|
||||
"Programming Language :: Python :: 3",
|
||||
"License :: OSI Approved :: Apache Software License",
|
||||
"Environment :: GPU :: NVIDIA CUDA",
|
||||
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
||||
"Topic :: System :: Distributed Computing",
|
||||
],
|
||||
)
|
0
applications/ColossalMoE/tests/__init__.py
Normal file
0
applications/ColossalMoE/tests/__init__.py
Normal file
63
applications/ColossalMoE/tests/test_mixtral_layer.py
Normal file
63
applications/ColossalMoE/tests/test_mixtral_layer.py
Normal file
@@ -0,0 +1,63 @@
|
||||
from copy import deepcopy
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from colossal_moe.models.mixtral_layer import EPMixtralSparseMoeBlock
|
||||
from torch.testing import assert_close
|
||||
from transformers.models.mixtral.configuration_mixtral import MixtralConfig
|
||||
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
|
||||
|
||||
import colossalai
|
||||
from colossalai.moe import MOE_MANAGER
|
||||
from colossalai.testing.utils import spawn
|
||||
|
||||
tokens, n_experts = 7, 4
|
||||
hidden_size = 8
|
||||
top_k = 2
|
||||
|
||||
|
||||
def check_mixtral_moe_layer():
|
||||
torch.cuda.set_device(dist.get_rank())
|
||||
MOE_MANAGER.setup(
|
||||
parallel="EP", mode="fixed", fixed_dp_size=1, fixed_ep_size=dist.get_world_size(), fixed_pp_size=1
|
||||
)
|
||||
config = MixtralConfig(
|
||||
hidden_size=hidden_size,
|
||||
intermediate_size=hidden_size * 2,
|
||||
num_local_experts=n_experts,
|
||||
num_experts_per_tok=top_k,
|
||||
)
|
||||
torch.manual_seed(0)
|
||||
orig_model = MixtralSparseMoeBlock(config).cuda()
|
||||
x = torch.rand(1, tokens, hidden_size, requires_grad=True).cuda()
|
||||
orig_output, orig_logits = orig_model(x)
|
||||
model = deepcopy(orig_model)
|
||||
model = EPMixtralSparseMoeBlock.from_native_module(model)
|
||||
ep_output, ep_logits = model(x)
|
||||
assert_close(orig_logits, ep_logits)
|
||||
assert_close(orig_output, ep_output)
|
||||
orig_loss = orig_output.mean()
|
||||
orig_loss.backward()
|
||||
ep_loss = ep_output.mean()
|
||||
ep_loss.backward()
|
||||
assert_close(orig_loss, ep_loss)
|
||||
name_to_p = {n: p for n, p in orig_model.named_parameters()}
|
||||
for n, ep_p in model.named_parameters():
|
||||
p = name_to_p[n]
|
||||
if ep_p.grad is not None:
|
||||
assert_close(p.grad, ep_p.grad)
|
||||
|
||||
|
||||
def run_dist(rank: int, world_size: int, port: int):
|
||||
colossalai.launch({}, rank, world_size, "localhost", port)
|
||||
check_mixtral_moe_layer()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("world_size", [2, 4])
|
||||
def test_mixtral_moe_layer(world_size: int):
|
||||
spawn(run_dist, world_size)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_mixtral_moe_layer(2)
|
146
applications/ColossalMoE/tests/test_moe_checkpoint.py
Normal file
146
applications/ColossalMoE/tests/test_moe_checkpoint.py
Normal file
@@ -0,0 +1,146 @@
|
||||
from copy import deepcopy
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from colossal_moe.models.mixtral_checkpoint import MixtralMoEHybridParallelCheckpointIO
|
||||
from colossal_moe.models.mixtral_policy import MixtralForCausalLMPolicy
|
||||
from torch.optim import Adam
|
||||
from transformers.models.mixtral.configuration_mixtral import MixtralConfig
|
||||
from transformers.models.mixtral.modeling_mixtral import MixtralForCausalLM
|
||||
|
||||
import colossalai
|
||||
from colossalai.booster import Booster
|
||||
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
|
||||
from colossalai.testing.utils import spawn
|
||||
|
||||
tokens, n_experts = 7, 4
|
||||
hidden_size = 8
|
||||
top_k = 2
|
||||
|
||||
|
||||
def check_model_equal(model1, model2):
|
||||
assert set(model1.state_dict().keys()) == set(model2.state_dict().keys())
|
||||
for p1, p2 in zip(model1.parameters(), model2.parameters()):
|
||||
assert torch.equal(p1.half(), p2.half())
|
||||
|
||||
|
||||
def get_optimizer_snapshot(optim):
|
||||
state = {id(k): deepcopy(v) for k, v in optim.state.items()}
|
||||
param_groups = []
|
||||
for group in optim.param_groups:
|
||||
params = [id(p) for p in group["params"]]
|
||||
new_group = {"params": params}
|
||||
for k, v in group.items():
|
||||
if k != "params":
|
||||
new_group[k] = v
|
||||
param_groups.append(new_group)
|
||||
return {
|
||||
"state": state,
|
||||
"param_groups": param_groups,
|
||||
}
|
||||
|
||||
|
||||
def check_optimizer_snapshot_equal(snapshot1, snapshot2):
|
||||
# check param_groups
|
||||
assert len(snapshot1["param_groups"]) == len(snapshot2["param_groups"])
|
||||
for group1, group2 in zip(snapshot1["param_groups"], snapshot2["param_groups"]):
|
||||
assert set(group1.keys()) == set(group2.keys())
|
||||
for k in group1.keys():
|
||||
assert group1[k] == group2[k]
|
||||
# check state
|
||||
assert set(snapshot1["state"].keys()) == set(
|
||||
snapshot2["state"].keys()
|
||||
), f"{snapshot1['state'].keys()}, {snapshot2['state'].keys()}"
|
||||
for pid in snapshot1["state"].keys():
|
||||
state1, state2 = snapshot1["state"][pid], snapshot2["state"][pid]
|
||||
assert set(state1.keys()) == set(state2.keys())
|
||||
for k in state1.keys():
|
||||
if isinstance(state1[k], torch.Tensor):
|
||||
assert torch.equal(state1[k], state2[k]), f"{k}, {state1[k]}, {state2[k]}"
|
||||
else:
|
||||
assert state1[k] == state2[k]
|
||||
|
||||
|
||||
def check_mixtral_moe_layer():
|
||||
torch.cuda.set_device(dist.get_rank())
|
||||
config = MixtralConfig(
|
||||
hidden_size=hidden_size,
|
||||
intermediate_size=hidden_size * 2,
|
||||
num_local_experts=n_experts,
|
||||
num_experts_per_tok=top_k,
|
||||
num_attention_heads=2,
|
||||
num_key_value_heads=2,
|
||||
)
|
||||
torch.manual_seed(0)
|
||||
input_ids = torch.randint(0, 100, (2, tokens)).cuda()
|
||||
orig_model = MixtralForCausalLM(config).cuda()
|
||||
model = deepcopy(orig_model)
|
||||
optimizer = Adam(model.parameters(), lr=1e-3)
|
||||
plugin = MoeHybridParallelPlugin(
|
||||
tp_size=1,
|
||||
pp_size=2,
|
||||
ep_size=2,
|
||||
custom_policy=MixtralForCausalLMPolicy(),
|
||||
checkpoint_io=MixtralMoEHybridParallelCheckpointIO,
|
||||
microbatch_size=1,
|
||||
zero_stage=1,
|
||||
)
|
||||
booster = Booster(plugin=plugin)
|
||||
model, optimizer, *_ = booster.boost(model=model, optimizer=optimizer)
|
||||
# initialize grads
|
||||
data_iter = iter(
|
||||
[{"input_ids": input_ids, "attention_mask": torch.ones_like(input_ids), "labels": input_ids.clone()}]
|
||||
)
|
||||
booster.execute_pipeline(
|
||||
data_iter,
|
||||
model,
|
||||
lambda outputs, inputs: outputs.loss,
|
||||
optimizer,
|
||||
)
|
||||
|
||||
# check save model
|
||||
booster.save_model(model, "mixtral_model", shard=True)
|
||||
dist.barrier()
|
||||
if dist.get_rank() == 0:
|
||||
saved_model = MixtralForCausalLM.from_pretrained("mixtral_model").cuda()
|
||||
check_model_equal(orig_model, saved_model)
|
||||
saved_model.save_pretrained("mixtral_hf_model")
|
||||
dist.barrier()
|
||||
|
||||
# check load model
|
||||
new_model = MixtralForCausalLM(config).cuda()
|
||||
new_optimizer = Adam(new_model.parameters(), lr=1e-3)
|
||||
new_model, new_optimizer, *_ = booster.boost(model=new_model, optimizer=new_optimizer)
|
||||
booster.load_model(new_model, "mixtral_hf_model")
|
||||
check_model_equal(model, new_model)
|
||||
|
||||
# check save optimizer
|
||||
optimizer.step()
|
||||
for group in optimizer.param_groups:
|
||||
group["lr"] = 0.1
|
||||
snapshot = get_optimizer_snapshot(optimizer.unwrap())
|
||||
booster.save_optimizer(optimizer, "mixtral_optim", shard=True)
|
||||
dist.barrier()
|
||||
# reset optimizer state
|
||||
for state in optimizer.unwrap().state.values():
|
||||
for v in state.values():
|
||||
if isinstance(v, torch.Tensor):
|
||||
v.zero_()
|
||||
booster.load_optimizer(optimizer, "mixtral_optim")
|
||||
loaded_snapshot = get_optimizer_snapshot(optimizer.unwrap())
|
||||
check_optimizer_snapshot_equal(snapshot, loaded_snapshot)
|
||||
|
||||
|
||||
def run_dist(rank: int, world_size: int, port: int):
|
||||
colossalai.launch({}, rank, world_size, "localhost", port)
|
||||
check_mixtral_moe_layer()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("world_size", [4])
|
||||
def test_mixtral_moe_layer(world_size: int):
|
||||
spawn(run_dist, world_size)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_mixtral_moe_layer(4)
|
295
applications/ColossalMoE/train.py
Normal file
295
applications/ColossalMoE/train.py
Normal file
@@ -0,0 +1,295 @@
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from colossal_moe.models.mixtral_checkpoint import MixtralMoEHybridParallelCheckpointIO
|
||||
from colossal_moe.models.mixtral_policy import MixtralForCausalLMPolicy
|
||||
from colossal_moe.utils import load_checkpoint, move_to_cuda, save_checkpoint
|
||||
from torch.utils.data import Dataset
|
||||
from tqdm import tqdm
|
||||
from transformers import AutoTokenizer
|
||||
from transformers.models.mixtral import MixtralForCausalLM
|
||||
|
||||
import colossalai
|
||||
from colossalai.booster import Booster
|
||||
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
|
||||
from colossalai.cluster import DistCoordinator
|
||||
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
|
||||
from colossalai.nn.optimizer import HybridAdam
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def get_global_loss(loss, booster):
|
||||
global_loss = loss.clone().detach()
|
||||
dist.all_reduce(tensor=global_loss, op=dist.ReduceOp.SUM, group=booster.plugin.dp_group)
|
||||
global_loss.div_(booster.plugin.dp_size)
|
||||
return global_loss
|
||||
|
||||
|
||||
class RandomDataset(Dataset):
|
||||
def __init__(self, num_samples: int = 1000, max_length: int = 2048, vocab_size: int = 100, tokenizer=None):
|
||||
self.num_samples = num_samples
|
||||
self.max_length = max_length
|
||||
self.input_ids = torch.randint(0, vocab_size, (num_samples, max_length), device=get_current_device())
|
||||
self.attention_mask = torch.ones_like(self.input_ids)
|
||||
|
||||
def __len__(self):
|
||||
return self.num_samples
|
||||
|
||||
def __getitem__(self, idx):
|
||||
return {
|
||||
"input_ids": self.input_ids[idx],
|
||||
"attention_mask": self.attention_mask[idx],
|
||||
"labels": self.input_ids[idx],
|
||||
}
|
||||
|
||||
|
||||
def parse_args():
|
||||
# basic settings
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--model_name",
|
||||
type=str,
|
||||
default="mistralai/Mixtral-8x7B-v0.1",
|
||||
help="Path to pretrained model or model identifier from huggingface.co/models.",
|
||||
)
|
||||
parser.add_argument("--load_checkpoint", type=str, default=None, help="Load checkpoint")
|
||||
parser.add_argument(
|
||||
"--plugin",
|
||||
type=str,
|
||||
default="hybrid",
|
||||
choices=["hybrid"],
|
||||
help="Parallel methods.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_path",
|
||||
type=str,
|
||||
default="./outputs",
|
||||
help="The path of your saved model after finetuning.",
|
||||
)
|
||||
parser.add_argument("--num_epoch", type=int, default=1, help="Number of epochs.")
|
||||
parser.add_argument(
|
||||
"--batch_size",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Batch size (per dp group) for the training dataloader.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save_interval",
|
||||
type=int,
|
||||
default=1000,
|
||||
help=" The interval (steps) of saving checkpoints.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--precision",
|
||||
type=str,
|
||||
default="bf16",
|
||||
choices=["fp32", "bf16", "fp16"],
|
||||
help="The mixed precision training.",
|
||||
)
|
||||
parser.add_argument("--max_length", type=int, default=2048, help="Max sequence length.")
|
||||
parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.")
|
||||
|
||||
# optim
|
||||
parser.add_argument("--lr", type=float, default=1e-5, help="Learning rate.")
|
||||
parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.")
|
||||
|
||||
# lr scheduler
|
||||
parser.add_argument("--num_epochs", type=int, default=1, help="Number of training epochs")
|
||||
parser.add_argument("--warmup_steps", type=int, default=None, help="Warmup steps")
|
||||
|
||||
# zero stage for all plugins
|
||||
parser.add_argument("--zero_stage", type=int, default=2, help="zero stage.")
|
||||
# hybrid plugin
|
||||
parser.add_argument("--pp_size", type=int, default=2, help="pp size for hybrid plugin")
|
||||
parser.add_argument("--dp_size", type=int, default=1, help="dp size for hybrid plugin")
|
||||
parser.add_argument("--ep_size", type=int, default=2, help="ep size for hybrid plugin")
|
||||
parser.add_argument("--microbatch_size", type=int, default=1, help="Microbatch size in pipeline for hybrid plugin")
|
||||
|
||||
# kernel
|
||||
parser.add_argument(
|
||||
"--use_kernel",
|
||||
action="store_true",
|
||||
help="Use kernel optim. Need to install flash attention and triton to enable all kernel optimizations. Skip if not installed.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_layernorm_kernel",
|
||||
action="store_true",
|
||||
help="Use layernorm kernel. Need to install apex. Raise error if not installed.",
|
||||
)
|
||||
|
||||
# load balance
|
||||
parser.add_argument(
|
||||
"--load_balance", action="store_true", help="Expert load balance. Defaults to False. Recommend to enable."
|
||||
)
|
||||
parser.add_argument("--load_balance_interval", type=int, default=1000, help="Expert load balance interval.")
|
||||
# communicate overlap
|
||||
parser.add_argument(
|
||||
"--comm_overlap",
|
||||
action="store_true",
|
||||
help="Use communication overlap for MoE. Recommended to enable for muiti-node training.",
|
||||
)
|
||||
# hierarchical all-to-all
|
||||
parser.add_argument(
|
||||
"--hierarchical_alltoall",
|
||||
action="store_true",
|
||||
help="Use hierarchical all-to-all for MoE. Recommended to enable for muiti-node training.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
|
||||
# Launch ColossalAI
|
||||
colossalai.launch_from_torch(config={}, seed=args.seed)
|
||||
coordinator = DistCoordinator()
|
||||
|
||||
# Set plugin
|
||||
if args.plugin == "hybrid":
|
||||
plugin = MoeHybridParallelPlugin(
|
||||
tp_size=1,
|
||||
pp_size=args.pp_size,
|
||||
ep_size=args.ep_size,
|
||||
microbatch_size=args.microbatch_size,
|
||||
custom_policy=MixtralForCausalLMPolicy(),
|
||||
enable_fused_normalization=args.use_layernorm_kernel,
|
||||
enable_jit_fused=args.use_kernel,
|
||||
precision=args.precision,
|
||||
zero_stage=args.zero_stage,
|
||||
checkpoint_io=MixtralMoEHybridParallelCheckpointIO,
|
||||
)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Invalid plugin {args.plugin}")
|
||||
coordinator.print_on_master(f"Set plugin as {plugin.__class__.__name__}")
|
||||
|
||||
# Build Mixtral model
|
||||
model = MixtralForCausalLM.from_pretrained(args.model_name)
|
||||
coordinator.print_on_master(f"Finish init model")
|
||||
|
||||
# Enable gradient checkpointing
|
||||
model.gradient_checkpointing_enable()
|
||||
|
||||
# Prepare tokenizer and dataloader
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.model_name)
|
||||
dataset = RandomDataset(num_samples=100, tokenizer=tokenizer)
|
||||
collate_fn = None
|
||||
dataloader = plugin.prepare_dataloader(
|
||||
dataset, batch_size=args.batch_size, shuffle=True, drop_last=True, collate_fn=collate_fn
|
||||
)
|
||||
|
||||
# Set optimizer
|
||||
optimizer = HybridAdam(
|
||||
model_params=model.parameters(),
|
||||
lr=args.lr,
|
||||
betas=(0.9, 0.95),
|
||||
weight_decay=args.weight_decay,
|
||||
adamw_mode=True,
|
||||
)
|
||||
|
||||
# Set lr scheduler
|
||||
lr_scheduler = CosineAnnealingWarmupLR(
|
||||
optimizer=optimizer,
|
||||
total_steps=args.num_epochs * len(dataloader),
|
||||
warmup_steps=args.warmup_steps
|
||||
if args.warmup_steps is not None
|
||||
else int(args.num_epochs * len(dataloader) * 0.025),
|
||||
eta_min=0.1 * args.lr,
|
||||
)
|
||||
|
||||
# Set booster
|
||||
booster = Booster(plugin=plugin)
|
||||
model, optimizer, _, dataloader, lr_scheduler = booster.boost(
|
||||
model=model,
|
||||
optimizer=optimizer,
|
||||
lr_scheduler=lr_scheduler,
|
||||
dataloader=dataloader,
|
||||
)
|
||||
use_pipeline = isinstance(booster.plugin, MoeHybridParallelPlugin) and booster.plugin.pp_size > 1
|
||||
is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage()
|
||||
coordinator.print_on_master(f"Finish init booster")
|
||||
|
||||
# Load ckpt
|
||||
if args.load_checkpoint is not None:
|
||||
load_checkpoint(args.load_checkpoint, booster, model, optimizer, lr_scheduler)
|
||||
coordinator.print_on_master(f"Finish load optimizer")
|
||||
|
||||
# Start finetuning
|
||||
coordinator.print_on_master(f"Start finetuning")
|
||||
for epoch in range(args.num_epoch):
|
||||
model.train()
|
||||
train_dataloader_iter = iter(dataloader)
|
||||
total_len = len(train_dataloader_iter)
|
||||
with tqdm(
|
||||
range(total_len),
|
||||
desc=f"Epoch [{epoch + 1}/{args.num_epoch}]",
|
||||
disable=not coordinator.is_master() if use_pipeline == False else not is_pp_last_stage,
|
||||
) as pbar:
|
||||
for step in pbar:
|
||||
if use_pipeline:
|
||||
# Forward pass
|
||||
outputs = booster.execute_pipeline(
|
||||
train_dataloader_iter,
|
||||
model,
|
||||
lambda x, y: x.loss,
|
||||
optimizer,
|
||||
return_loss=True,
|
||||
return_outputs=True,
|
||||
)
|
||||
# Backward and optimize
|
||||
if is_pp_last_stage:
|
||||
loss = outputs["loss"]
|
||||
global_loss = get_global_loss(loss, booster)
|
||||
if coordinator._local_rank == "0":
|
||||
pbar.set_postfix({"Loss": global_loss.item()})
|
||||
else:
|
||||
# Forward pass
|
||||
data = next(train_dataloader_iter)
|
||||
data = move_to_cuda(data, torch.cuda.current_device())
|
||||
outputs = model(**data)
|
||||
loss = outputs["loss"]
|
||||
# Backward
|
||||
booster.backward(loss, optimizer)
|
||||
pbar.set_postfix({"loss": loss.item()})
|
||||
|
||||
optimizer.step()
|
||||
lr_scheduler.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
# Apply load balance
|
||||
# if (
|
||||
# args.load_balance
|
||||
# and args.load_balance_interval > 0
|
||||
# and (step + 1) % args.load_balance_interval == 0
|
||||
# ):
|
||||
# coordinator.print_on_master(f"Apply load balance")
|
||||
# apply_load_balance(model, optimizer)
|
||||
# save ckeckpoint
|
||||
if (step + 1) % args.save_interval == 0:
|
||||
coordinator.print_on_master(f"Saving model checkpoint to {args.output_path}")
|
||||
save_checkpoint(
|
||||
args.output_path,
|
||||
booster,
|
||||
model,
|
||||
optimizer,
|
||||
lr_scheduler,
|
||||
epoch,
|
||||
step,
|
||||
args.batch_size,
|
||||
coordinator,
|
||||
)
|
||||
|
||||
# save checkpoint at the end of each epochs
|
||||
booster.save_model(model, args.output_path, shard=True, size_per_shard=5120)
|
||||
coordinator.print_on_master(f"Saving model checkpoint to {args.output_path}")
|
||||
|
||||
# Finish training
|
||||
coordinator.print_on_master(f"Finish training")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
19
applications/ColossalMoE/train.sh
Normal file
19
applications/ColossalMoE/train.sh
Normal file
@@ -0,0 +1,19 @@
|
||||
NUM_GPU=8
|
||||
MODEL="mistralai/Mixtral-8x7B-v0.1"
|
||||
SEQ_LENGTH=2048
|
||||
BATCH_SIZE=1
|
||||
LR=0.00001
|
||||
|
||||
# hybrid
|
||||
# torchrun --standalone --nproc_per_node $NUM_GPU \
|
||||
colossalai run --nproc_per_node $NUM_GPU --hostfile "hostfile" \
|
||||
train.py \
|
||||
--num_epoch 1 \
|
||||
--model_name $MODEL \
|
||||
--plugin "hybrid" \
|
||||
--batch_size $BATCH_SIZE \
|
||||
--lr $LR \
|
||||
--zero_stage 1 \
|
||||
--pp_size 2 \
|
||||
--dp_size 1 \
|
||||
--ep_size 8 \
|
1
applications/ColossalMoE/version.txt
Normal file
1
applications/ColossalMoE/version.txt
Normal file
@@ -0,0 +1 @@
|
||||
1.0.0
|
Reference in New Issue
Block a user