[MoE/ZeRO] Moe refactor with zero refactor (#5821)

* [moe] removed openmoe-coupled code and rectify mixstral code (#5471)

* [Feauture] MoE refractor; Intergration with Mixtral  (#5682)

* cherry pick from refractor-moe branch

* tests passed

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* support ep + zero

---------

Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* add mixtral auto policy & move pipeline forward code to modeling folder

* [moe refactor] modify kernel test without Route Class

* [moe refactor] add moe tensor test path environment variable to github workflow

* fix typos

* fix moe test bug due to the code rebase

* [moe refactor] fix moe zero test, and little bug in low level zero

* fix typo

* add moe tensor path to github workflow

* remove some useless code

* fix typo & unify global variable XX_AXIS logic without using -1

* fix typo & prettifier the code

* remove print code & support zero 2 test

* remove useless code

* reanme function

* fix typo

* fix typo

* Further improve the test code

* remove print code

* [moe refactor] change test model from fake moe model to mixtral moe layer and remove useless test

* [moe refactor] skip some unit test which will be refactored later

* [moe refactor] fix unit import error

* [moe refactor] fix circular import issues

* [moe refactor] remove debug code

* [moe refactor] update github workflow

* [moe/zero] refactor low level optimizer (#5767)

* [zero] refactor low level optimizer

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [Feature] MoE refactor with newest version of ZeRO (#5801)

* [zero] remove redundant members in BucketStore (#5802)

* [zero] align api with previous version

* [Moe/Zero] Update MoeHybridParallelPlugin with refactored ZeRO and Fix Zero bug (#5819)

* [moe refactor] update unit test with the refactored ZeRO and remove useless test

* move moe checkpoint to checkpoint folder and exchange global axis to class member

* update moe hybrid parallel plugin with newest version of zero & fix zero working/master params bug

* fix zero unit test

* Add an assertion to prevent users from using it incorrectly

* [hotfix]Solve the compatibility issue of zero refactor (#5823)

* [moe refactor] update unit test with the refactored ZeRO and remove useless test

* move moe checkpoint to checkpoint folder and exchange global axis to class member

* update moe hybrid parallel plugin with newest version of zero & fix zero working/master params bug

* fix zero unit test

* Add an assertion to prevent users from using it incorrectly

* Modify function parameter names to resolve compatibility issues

* [zero] fix missing hook removal (#5824)

* [MoE] Resolve .github conflict (#5829)

* [Fix/Example] Fix Llama Inference Loading Data Type (#5763)

* [fix/example] fix llama inference loading dtype

* revise loading dtype of benchmark llama3

* [release] update version (#5752)

* [release] update version

* [devops] update compatibility test

* [devops] update compatibility test

* [devops] update compatibility test

* [devops] update compatibility test

* [test] fix ddp plugin test

* [test] fix gptj and rpc test

* [devops] fix cuda ext compatibility

* [inference] fix flash decoding test

* [inference] fix flash decoding test

* fix (#5765)

* [test] Fix/fix testcase (#5770)

* [fix] branch for fix testcase;

* [fix] fix test_analyzer & test_auto_parallel;

* [fix] remove local change about moe;

* [fix] rm local change moe;

* [Hotfix] Add missing init file in inference.executor (#5774)

* [CI/tests] simplify some test case to reduce testing time (#5755)

* [ci/tests] simplify some test case to reduce testing time

* [ci/tests] continue to remove test case to reduce ci time cost

* restore some test config

* [ci/tests] continue to reduce ci time cost

* [misc] update dockerfile (#5776)

* [misc] update dockerfile

* [misc] update dockerfile

* [devops] fix docker ci (#5780)

* [Inference]Add Streaming LLM (#5745)

* Add Streaming LLM

* add some parameters to llama_generation.py

* verify streamingllm config

* add test_streamingllm.py

* modified according to the opinions of review

* add Citation

* change _block_tables tolist

* [hotfix] fix llama flash attention forward (#5777)

* [misc] Accelerate CI for zero and dist optim (#5758)

* remove fp16 from lamb

* remove d2h copy in checking states

---------

Co-authored-by: Edenzzzz <wtan45@wisc.edu>

* [Test/CI] remove test cases to reduce CI duration (#5753)

* [test] smaller gpt2 test case

* [test] reduce test cases: tests/test_zero/test_gemini/test_zeroddp_state_dict.py

* [test] reduce test cases: tests/test_zero/test_gemini/test_grad_accum.py

* [test] reduce test cases tests/test_zero/test_gemini/test_optim.py

* Revert "[test] smaller gpt2 test case"

Some tests might depend on the size of model (num of chunks)

This reverts commit df705a5210.

* [test] reduce test cases: tests/test_checkpoint_io/test_gemini_checkpoint_io.py

* [CI] smaller test model for two mwo the two modifid cases

* [CI] hardcode gpt model for tests/test_zero/test_gemini/test_search.py since we need a fixed answer there

* [hotfix] fix testcase in test_fx/test_tracer (#5779)

* [fix] branch for fix testcase;

* [fix] fix test_analyzer & test_auto_parallel;

* [fix] remove local change about moe;

* [fix] rm local change moe;

* [fix] fix test_deepfm_model & test_dlrf_model;

* [fix] fix test_hf_albert & test_hf_gpt;

* [gemini] optimize reduce scatter d2h copy (#5760)

* [gemini] optimize reduce scatter d2h copy

* [fix] fix missing reduce variable

* [refactor] remove legacy async reduce scatter code

* [gemini] missing sync

* Revert "[refactor] remove legacy async reduce scatter code"

This reverts commit 58ad76d466.

* [gemini] further optimize with async all reduce

* [fix] pass flag from manager to chunk

* Allow building cuda extension without a device. (#5535)

Added FORCE_CUDA environment variable support, to enable building extensions where a GPU device is not present but cuda libraries are.

* [misc] fix dist logger (#5782)

* [install]fix setup (#5786)

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [misc] update requirements (#5787)

* [shardformer] fix import (#5788)

* upgrade colossal-chat support tp_group>1, add sp for sft

* upgrade ppo dpo rm script

* run pre-commit

* moupdate ci tests, st ci test cases passed, tp failed in generation for ppo, sp is buggy

* fix training script

* fix ci

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix transformers version

* remove duplicated test

* fix datasets version

* remove models that require huggingface auth from ci

* remove local data path

* update ci

* remove baichuan from template test due to transformer version conflict

* merge

* Refactor modeling by adding attention backend

Signed-off-by: char-1ee <xingjianli59@gmail.com>

* Fix tests and naming

Signed-off-by: char-1ee <xingjianli59@gmail.com>

* Pass inference model shard configs for module init

Signed-off-by: char-1ee <xingjianli59@gmail.com>

* Clean up

Signed-off-by: char-1ee <xingjianli59@gmail.com>

* replace the customized dataloader setup with the build-in one

* replace the customized dataloader setup with the build-in one

* Remove flash attention backend

Signed-off-by: char-1ee <xingjianli59@gmail.com>

* fix readme

* Fix test import

Signed-off-by: char-1ee <xingjianli59@gmail.com>

* update sft trainning script

* [Inference]refactor baichuan (#5791)

* refactor baichuan

* remove unused code and add TODO for lazyinit

* [test] fix chatglm test kit (#5793)

* [shardformer] fix modeling of bloom and falcon (#5796)

* [test] fix qwen2 pytest distLarge (#5797)

* [Inference] Fix flash-attn import and add model test (#5794)

* Fix torch int32 dtype

Signed-off-by: char-1ee <xingjianli59@gmail.com>

* Fix flash-attn import

Signed-off-by: char-1ee <xingjianli59@gmail.com>

* Add generalized model test

Signed-off-by: char-1ee <xingjianli59@gmail.com>

* Remove exposed path to model

Signed-off-by: char-1ee <xingjianli59@gmail.com>

* Add default value for use_flash_attn

Signed-off-by: char-1ee <xingjianli59@gmail.com>

* Rename model test

Signed-off-by: char-1ee <xingjianli59@gmail.com>

---------

Signed-off-by: char-1ee <xingjianli59@gmail.com>

* [Gemini] Use async stream to prefetch and h2d data moving (#5781)

* use async stream to prefetch and h2d data moving

* Remove redundant code

* [gemini] quick fix on possible async operation (#5803)

* [gemini] quick fix on possible async operation

* [gemini] quick fix on possible async operation

* [shardformer] upgrade transformers to 4.39.3 (#5815)

* [shardformer]upgrade transformers for gpt2/gptj/whisper (#5807)

* [shardformer] fix modeling of gpt2 and gptj

* [shardformer] fix whisper modeling

* [misc] update requirements

---------

Co-authored-by: ver217 <lhx0217@gmail.com>

* [shardformer]upgrade transformers for mistral (#5808)

* upgrade transformers for mistral

* fix

* fix

* [shardformer]upgrade transformers for llama (#5809)

* update transformers

fix

* fix

* fix

* [inference] upgrade transformers (#5810)

* update transformers

fix

* fix

* fix

* fix

* fix

* [gemini] update transformers for gemini (#5814)

---------

Co-authored-by: ver217 <lhx0217@gmail.com>

* Support 4d parallel + flash attention (#5789)

* support tp + sp + pp

* remove comments

---------

Co-authored-by: Edenzzzz <wtan45@wisc.edu>

---------

Signed-off-by: char-1ee <xingjianli59@gmail.com>
Co-authored-by: Yuanheng Zhao <54058983+yuanheng-zhao@users.noreply.github.com>
Co-authored-by: Hongxin Liu <lhx0217@gmail.com>
Co-authored-by: flybird11111 <1829166702@qq.com>
Co-authored-by: duanjunwen <935724073@qq.com>
Co-authored-by: yuehuayingxueluo <867460659@qq.com>
Co-authored-by: Edenzzzz <wenxuan.tan@wisc.edu>
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: botbw <wang1570@e.ntu.edu.sg>
Co-authored-by: Charles Coulombe <ccoulombe@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: YeAnbang <anbangy2@outlook.com>
Co-authored-by: char-1ee <xingjianli59@gmail.com>
Co-authored-by: Runyu Lu <77330637+LRY89757@users.noreply.github.com>
Co-authored-by: YeAnbang <44796419+YeAnbang@users.noreply.github.com>
Co-authored-by: Guangyao Zhang <xjtu521@qq.com>

* [zero] fix hook bug

* [zero] add low level optimizer back (#5839)

* [zero] fix param & refactor

* [zero] add back original low level opt

* [zero] remove moe related

* [zero] pass zero tests

* [zero] refactor

* [chore] add del func back

* [zero] comments and naming (#5840)

* [zero] modify api (#5843)

* [zero] modify api

* [test] remove _grad_store access in tests

* [test] fix (#5857)

* [CI] skip openmoe CI check

* [CI] fox pre-commit

* [zero] remove redundant memebr init (#5862)

* [misc] remove useless code, modify the pg mesh implementation

* [misc] remove useless code, modify the pg mesh implementation

* [misc] use tempfile

* resolve conflict with main branch

* [misc] use tempfile in test_moe_checkpoint.py

* [misc] remove useless code, add assertion about sequence parallel, move logger into function

* [misc] remove useless code

---------

Signed-off-by: char-1ee <xingjianli59@gmail.com>
Co-authored-by: Frank Lee <somerlee.9@gmail.com>
Co-authored-by: Edenzzzz <wenxuan.tan@wisc.edu>
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: botbw <wang1570@e.ntu.edu.sg>
Co-authored-by: Yuanheng Zhao <54058983+yuanheng-zhao@users.noreply.github.com>
Co-authored-by: Hongxin Liu <lhx0217@gmail.com>
Co-authored-by: flybird11111 <1829166702@qq.com>
Co-authored-by: duanjunwen <935724073@qq.com>
Co-authored-by: yuehuayingxueluo <867460659@qq.com>
Co-authored-by: Charles Coulombe <ccoulombe@users.noreply.github.com>
Co-authored-by: YeAnbang <anbangy2@outlook.com>
Co-authored-by: char-1ee <xingjianli59@gmail.com>
Co-authored-by: Runyu Lu <77330637+LRY89757@users.noreply.github.com>
Co-authored-by: YeAnbang <44796419+YeAnbang@users.noreply.github.com>
Co-authored-by: Guangyao Zhang <xjtu521@qq.com>
This commit is contained in:
Haze188
2024-06-28 14:00:08 +08:00
committed by GitHub
parent 773d9f964a
commit 416580b314
69 changed files with 1780 additions and 3076 deletions

View File

@@ -1,629 +0,0 @@
import copy
import logging
import os
from pathlib import Path
from shutil import rmtree
from typing import Dict, Iterator, Optional, OrderedDict, Tuple
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed import ProcessGroup
from colossalai.checkpoint_io import CheckpointIndexFile
from colossalai.checkpoint_io.hybrid_parallel_checkpoint_io import HybridParallelCheckpointIO
from colossalai.checkpoint_io.index_file import CheckpointIndexFile
from colossalai.checkpoint_io.utils import (
StateDictSharder,
gather_distributed_param,
get_model_base_filenames,
get_optimizer_base_filenames,
load_shard_state_dict,
load_states_into_optimizer,
save_config_file,
save_param_groups,
save_state_dict_shards,
search_tp_partition_dim,
sharded_optimizer_loading_epilogue,
)
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.moe import MOE_MANAGER
from colossalai.tensor.moe_tensor.api import is_moe_tensor
try:
from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX
except ImportError:
_EXTRA_STATE_KEY_SUFFIX = "_extra_state"
class MixtralMoEHybridParallelCheckpointIO(HybridParallelCheckpointIO):
def __init__(
self,
dp_group: ProcessGroup,
pp_group: ProcessGroup,
tp_group: ProcessGroup,
zero_stage: int,
verbose: bool = True,
) -> None:
super().__init__(dp_group, pp_group, tp_group, zero_stage, verbose)
moe_info = MOE_MANAGER.parallel_info_dict[MOE_MANAGER.ep_size]
self.ep_group = moe_info.ep_group
self.ep_size = moe_info.ep_size
self.ep_rank = moe_info.ep_rank
self.real_dp_rank = moe_info.dp_rank
@staticmethod
def _model_sharder(
model: nn.Module,
prefix: str = "",
keep_vars: bool = False,
size_per_shard: int = 1024,
param_name_pattern: Optional[str] = None,
) -> Iterator[Tuple[OrderedDict, int]]:
# An internel method that breaks state_dict of model into shards within limited size.
state_dict_sharder = StateDictSharder(size_per_shard)
# Save parameters.
for name, param in model.named_parameters():
if param is None:
continue
if param_name_pattern is not None and param_name_pattern not in name:
continue
# Gather tensor pieces when using tensor parallel.
param_ = gather_distributed_param(param, keep_vars=False)
block, block_size = state_dict_sharder.append_param(prefix + name, param_)
if block is not None:
yield block, block_size
# Save buffers.
for name, buf in model.named_buffers():
if buf is not None and name not in model._non_persistent_buffers_set:
buffer = buf if keep_vars else buf.detach()
block, block_size = state_dict_sharder.append_param(prefix + name, buffer)
if block is not None:
yield block, block_size
# Save extra states.
extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
if (
getattr(model.__class__, "get_extra_state", torch.nn.Module.get_extra_state)
is not torch.nn.Module.get_extra_state
):
extra_state = model.get_extra_state()
block, block_size = state_dict_sharder.append_param(extra_state_key, extra_state)
if block is not None:
yield block, block_size
# Return the last block in sharder.
yield state_dict_sharder.current_block, state_dict_sharder.current_block_size
def save_sharded_model(
self,
model: ModelWrapper,
checkpoint: str,
gather_dtensor: bool = True,
prefix: Optional[str] = None,
size_per_shard: int = 1024,
use_safetensors: bool = False,
) -> None:
"""
Save sharded model checkpoint under the given checkpointing path.
The following files will be created under the path:
- An index file (pytorch_model.bin.index.json) containing a map between model params/buffers and file names.
- Multiple files that store state tensors of models.
If pipeline parallelism is used, the filenames are in the form of "pytorch_model.<prefix>-stage-000XX-shard-000XX.bin".
If pipeline parallelism is not used, "pytorch_model.<prefix>-000XX.bin"
Args:
model (nn.Module): Model on local device to be saved.
checkpoint (str): Checkpointing path which should be a directory path.
gather_dtensor (bool, optional): Whether to gather_dtensor, currently not used. Defaults to True.
prefix (str, optional): Perfix of file to save. Defaults to None.
size_per_shard (int, optional): Size per shard in MB. Defaults to 1024.
use_safetensors (bool, optional): Whether to use safe tensors. Defaults to False.
"""
assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
model = model.unwrap()
if os.path.isfile(checkpoint):
logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
return
Path(checkpoint).mkdir(parents=True, exist_ok=True)
if self.real_dp_rank != 0:
dist.barrier()
return
# ep_rank 0 saves all the parameters and buffers.
# other ep_ranks save only experts
ep_param_pattern = "experts." if self.ep_rank != 0 else None
# Then collect the sharded parameters & buffers along tp_group.
# Only devices with tp_rank == 0 are responsible for model saving.
state_dict_shard = MixtralMoEHybridParallelCheckpointIO._model_sharder(
model, size_per_shard=size_per_shard, param_name_pattern=ep_param_pattern
)
weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors)
index_file = CheckpointIndexFile(checkpoint)
control_saving = self.tp_rank == 0
if self.pp_size == 1 and self.ep_size == 1:
# When pipeline is not used, save the model shards as in general checkpointIO
total_size = save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint,
index_file=index_file,
base_filename=weights_name,
is_master=control_saving,
use_safetensors=use_safetensors,
)
if control_saving:
index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
save_config_file(model, checkpoint)
if self.verbose and self.coordinator.is_master():
logging.info(
f"The model is split into checkpoint shards. "
f"You can find where each parameters has been saved in the "
f"index located at {save_index_file}."
)
dist.barrier()
else:
# When pipeline is used, each stage produces its own shard files and index files.
# Index files belonging to each stage are saved under a temporary folder ./tmp_index_files/
# After all the state_dicts have been saved, the master rank integrates all the index files into one final index file and deletes the tmp folder.
final_index_file_path = copy.deepcopy(save_index_file)
tmp_index_file_folder = os.path.join(checkpoint, "tmp_index_files")
Path(tmp_index_file_folder).mkdir(parents=True, exist_ok=True)
# Manage filenames of sharded weights and index file for each pipeline stage.
weights_name = weights_name.replace(".bin", f"-stage-{self.pp_rank+1:05d}-{self.ep_rank+1:05d}-shard.bin")
weights_name = weights_name.replace(
".safetensors", f"-stage-{self.pp_rank+1:05d}-{self.ep_rank+1:05d}-shard.safetensors"
)
save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank+1:05d}-{self.ep_rank+1:05d}.json")
save_index_file = os.path.join("tmp_index_files", save_index_file)
total_size = save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint,
index_file=index_file,
base_filename=weights_name,
is_master=control_saving,
use_safetensors=use_safetensors,
use_pp_format=True,
)
if control_saving:
index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
else:
dist.barrier()
return
dist.barrier()
# The global master rank integrates the index files and clean the folder.
if self.coordinator.is_master():
final_index_file = CheckpointIndexFile(checkpoint)
final_index_file.append_meta_data("total_size", 0)
for filename in os.listdir(tmp_index_file_folder):
stage_index_file = CheckpointIndexFile.from_file(os.path.join(tmp_index_file_folder, filename))
final_index_file.metadata["total_size"] += stage_index_file.metadata["total_size"]
for weight, weight_filename in stage_index_file.weight_map.items():
final_index_file.append_weight_map(weight, weight_filename)
final_index_file.write_index_file(final_index_file_path)
save_config_file(model, checkpoint)
rmtree(tmp_index_file_folder)
if self.verbose and self.coordinator.is_master():
logging.info(
f"The model is split into checkpoint shards. "
f"You can find where each parameters has been saved in the "
f"index located at {final_index_file_path}."
)
@staticmethod
def gather_from_sharded_optimizer_state(
state: OrderedDict,
param: torch.Tensor,
original_shape: torch.Size,
dp_group: ProcessGroup,
tp_group: ProcessGroup,
use_zero: bool,
inplace: bool,
is_moe_param: bool,
device: torch.device = torch.device("cpu"),
) -> OrderedDict:
"""
With given parameter and its optimizer states, gather the complete optimizer state for saving.
Args:
state (OrderedDict): Optimizer states of given parameter, might be distributed among tp/dp group if using TP/Zero.
param (torch.Tensor): The given parameter. It should be working_param when using Zero.
original_shape (torch.Size): The size of parameter before sharding.
dp_group (ProcessGroup): The process group of data parallel.
tp_group (ProcessGroup): The process group of tensor parallel.
use_zero (bool): Whether Zero is used.
inplace (bool): If set to True, will update the values of argument 'state' in place. Else will make a copy of state.
device (torch.device): The destination device of loaded optimizer states. Defaults to torch.device('cpu').
Returns:
OrderedDict: The complete optimizer state of given parameter.
"""
dp_size = dist.get_world_size(dp_group)
tp_size = dist.get_world_size(tp_group)
current_shape = param.shape
state_ = state if inplace else copy.deepcopy(state)
for k, v in state_.items():
if isinstance(v, torch.Tensor) and k != "step":
# First gather Zero shards.
if use_zero and not is_moe_param:
v = v.cuda()
gather_tensor = [torch.zeros_like(v) for _ in range(dp_size)]
dist.all_gather(gather_tensor, v, group=dp_group)
v = torch.stack(gather_tensor).view(-1)[: param.numel()].reshape_as(param)
# Then gather TP shards.
partition_dim = search_tp_partition_dim(current_shape, original_shape, tp_size)
if partition_dim is not None:
gather_tensor = [torch.zeros_like(v) for _ in range(tp_size)]
dist.all_gather(gather_tensor, v, group=tp_group)
v = torch.cat(gather_tensor, dim=partition_dim)
state_[k] = v.detach().clone().to(device)
return state_
@staticmethod
def _optimizer_sharder(
optimizer: OptimizerWrapper,
use_zero: bool,
dp_group: ProcessGroup,
tp_group: ProcessGroup,
size_per_shard: int = 1024,
only_moe_param: bool = False,
):
# An internel method that breaks state_dict of optimizer into shards within limited size.
state_dict_sharder = StateDictSharder(size_per_shard)
param_info = optimizer.param_info
master_to_working_map = optimizer.get_master_to_working_map()
for param, state in optimizer.optim.state.items():
if param is None:
continue
if master_to_working_map is not None:
working_param = master_to_working_map[id(param)]
else:
working_param = param
param_id = param_info["param2id"][id(working_param)]
original_shape = param_info["param2shape"][id(working_param)]
state_ = MixtralMoEHybridParallelCheckpointIO.gather_from_sharded_optimizer_state(
state,
working_param,
original_shape=original_shape,
dp_group=dp_group,
tp_group=tp_group,
use_zero=use_zero,
inplace=False,
is_moe_param=is_moe_tensor(working_param),
)
if only_moe_param and not is_moe_tensor(working_param):
continue
block, block_size = state_dict_sharder.append_optim_state(param_id, state_)
if block is not None:
yield block, block_size
# Return the last block in sharder.
yield state_dict_sharder.current_block, state_dict_sharder.current_block_size
def save_sharded_optimizer(
self,
optimizer: OptimizerWrapper,
checkpoint: str,
gather_dtensor: bool = True,
prefix: Optional[str] = None,
size_per_shard: int = 1024,
):
"""
Save sharded optimizer checkpoint under the given checkpointing path.
The following files will be created under the path:
- An index file (pytorch_optim.bin.index.json) containing a map between optimizer states and file names
- A group file (pytorch_optim_group.bin) recording information of param_groups
- Multiple files that store state tensors of optimizers.
If pipeline parallelism is used, the filenames are in the form of "pytorch_optim.<prefix>-stage-000XX-shard-000XX.bin".
If pipeline parallelism is not used, "pytorch_optim.<prefix>-000XX.bin"
Args:
optimizer (OptimizerWrapper): Optimizer to save sharded state_dict
checkpoint (str): Path to save optimizer state_dict
gather_dtensor (bool): Whether to gather_dtensor, not used
prefix (str): Perfix of file to save
size_per_shard (int): Max file size of each file shard that store state tensors
"""
assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before saving!"
if os.path.isfile(checkpoint):
logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
return
Path(checkpoint).mkdir(parents=True, exist_ok=True)
# Devices along the same dp_group share the same copies of states when zero is not used.
# In this case only let the device with dp_rank == 0 save the model.
if not self.use_zero and self.real_dp_rank != 0:
dist.barrier()
return
# Then collect the sharded states along dp_group(if using zero)/tp_group.
# Only devices with (dp_rank == 0 and tp_rank == 0) are responsible for states saving.
state_dict_shard = MixtralMoEHybridParallelCheckpointIO._optimizer_sharder(
optimizer,
use_zero=self.use_zero,
dp_group=self.dp_group,
tp_group=self.tp_group,
size_per_shard=size_per_shard,
only_moe_param=self.ep_rank != 0,
)
states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix)
index_file = CheckpointIndexFile(checkpoint)
control_saving = self.real_dp_rank == 0 and self.tp_rank == 0
if self.pp_size == 1 and self.ep_size == 1:
# When pipeline is not used, save the optimizer shards as in general checkpointIO
total_size = save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint,
index_file=index_file,
base_filename=states_name,
is_master=control_saving,
)
if control_saving:
# Store param groups.
index_file.append_meta_data("param_groups", param_group_file)
group_file_path = os.path.join(checkpoint, param_group_file)
param_groups = [
{**group, "params": group_info["params"]}
for group, group_info in zip(optimizer.param_groups, optimizer.param_info["param_groups"])
]
save_param_groups({"param_groups": param_groups}, group_file_path)
# Store index file.
index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
if self.verbose and self.coordinator.is_master():
logging.info(
f"The optimizer is going to be split to checkpoint shards. "
f"You can find where each parameters has been saved in the "
f"index located at {save_index_file}."
)
dist.barrier()
else:
# When pipeline is used, each stage produces its own shard files and index files.
# Index files belonging to each stage are saved under a temporary folder ./tmp_index_files/
# After all the state_dicts have been saved, the master rank integrates all the index files into one final index file and deletes the tmp folder.
final_index_file_path = copy.deepcopy(save_index_file)
tmp_index_file_folder = os.path.join(checkpoint, "tmp_index_files")
Path(tmp_index_file_folder).mkdir(parents=True, exist_ok=True)
# Manage filenames of sharded weights and index file for each pipeline stage.
states_name = states_name.replace(".bin", f"-stage-{self.pp_rank+1:05d}-{self.ep_rank+1:05d}-shard.bin")
save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank+1:05d}-{self.ep_rank+1:05d}.json")
save_index_file = os.path.join("tmp_index_files", save_index_file)
total_size = save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint,
index_file=index_file,
base_filename=states_name,
is_master=control_saving,
use_pp_format=True,
)
if control_saving:
index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
else:
dist.barrier()
return
dist.barrier()
# The global master rank integrates the index files and clean the folder.
if self.coordinator.is_master():
final_index_file = CheckpointIndexFile(checkpoint)
final_index_file.append_meta_data("total_size", 0)
for filename in os.listdir(tmp_index_file_folder):
stage_index_file = CheckpointIndexFile.from_file(os.path.join(tmp_index_file_folder, filename))
final_index_file.metadata["total_size"] += stage_index_file.metadata["total_size"]
for param_id, state_filename in stage_index_file.weight_map.items():
final_index_file.append_weight_map(param_id, state_filename)
# Store param groups.
final_index_file.append_meta_data("param_groups", param_group_file)
group_file_path = os.path.join(checkpoint, param_group_file)
param_groups = [
{**group, "params": group_info["params"]}
for group, group_info in zip(optimizer.param_groups, optimizer.param_info["param_groups"])
]
save_param_groups({"param_groups": param_groups}, group_file_path)
final_index_file.write_index_file(final_index_file_path)
rmtree(tmp_index_file_folder)
if self.verbose and self.coordinator.is_master():
logging.info(
f"The model is split into checkpoint shards. "
f"You can find where each parameters has been saved in the "
f"index located at {final_index_file_path}."
)
def load_sharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint_index_file: str, prefix: str = ""):
"""
Load sharded optimizer with the given path to index file of checkpoint folder.
Args:
optimizer (OptimizerWrapper): The optimizer to be loaded.
checkpoint_index_file (str): Path to the index file of checkpointing folder.
prefix (str): Not used.
"""
assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!"
def _get_param_id_from_optimizer_param(
param: torch.Tensor, master_to_working_map: Optional[Dict[int, torch.Tensor]] = None
):
if master_to_working_map is not None:
working_param = master_to_working_map[id(param)]
else:
working_param = param
return optimizer.param_info["param2id"][id(working_param)]
# id_map is a mapping from param ids kept by current pipeline, to their corresponding parameter objects.
# When Zero is used, the mapped parameter objects should be fp32 master parameters.
# IDs should be obtained through saved param2id mapping earlier saved in optimizer.param_info.
id_map = {}
master_to_working_map = optimizer.get_master_to_working_map()
for pg in optimizer.optim.param_groups:
for param in pg["params"]:
param_id = _get_param_id_from_optimizer_param(param, master_to_working_map)
id_map[param_id] = param
# Read checkpoint index file.
ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file)
ckpt_root_path = ckpt_index_file.root_path
weight_map = ckpt_index_file.weight_map
weight_map = {int(k): v for k, v in weight_map.items()} # convert saved id from str to int
# Load param_groups
param_group_path = ckpt_index_file.get_param_group_filename()
if param_group_path is None:
raise RuntimeError(
f"Invalid index file path {checkpoint_index_file} for an optimizer. \
Lacking param group file under current directory."
)
saved_groups = torch.load(param_group_path)
updated_groups = []
for old_pg, saved_pg in zip(optimizer.optim.param_groups, saved_groups):
# obtain updated param group
new_pg = copy.deepcopy(saved_pg)
new_pg["params"] = old_pg["params"] # The parameters in the same group shouln't change.
updated_groups.append(new_pg)
# ep param groups
if len(optimizer.optim.param_groups) == len(saved_groups) + 1:
new_pg = copy.deepcopy(saved_pg)
new_pg["params"] = optimizer.optim.param_groups[-1]["params"]
updated_groups.append(new_pg)
optimizer.optim.__dict__.update({"param_groups": updated_groups})
# Load saved states to optimizer.
# Keep a record of loaded files so that file will not be repeatedly loaded.
loaded_file = set()
for pg in optimizer.optim.param_groups:
for param in pg["params"]:
if param is None:
continue
param_id = _get_param_id_from_optimizer_param(param, master_to_working_map)
if param_id not in weight_map:
continue
filename = weight_map[param_id]
# If this param's states has been loaded before, directly return.
if filename in loaded_file:
continue
file_path = os.path.join(ckpt_root_path, filename)
state_dict = load_shard_state_dict(Path(file_path), use_safetensors=False)
load_states_into_optimizer(optimizer.optim, state_dict, id_map, strict=True)
loaded_file.add(filename)
# Then shard the loaded optimizer states if using tp/zero.
for param, state in optimizer.optim.state.items():
device = param.device
if master_to_working_map is not None:
working_param = master_to_working_map[id(param)]
else:
working_param = param
original_shape = optimizer.param_info["param2shape"][id(working_param)]
sharded_state = self.shard_from_complete_optimizer_state(
state,
current_shape=working_param.shape,
original_shape=original_shape,
device=device,
inplace=True,
is_moe_param=is_moe_tensor(working_param),
)
optimizer.optim.state[param] = sharded_state
sharded_optimizer_loading_epilogue(optimizer.optim)
if self.verbose and self.coordinator.is_master():
logging.info(f"The optimizer has been successfully loaded from sharded checkpoint: {ckpt_root_path}.")
def shard_from_complete_optimizer_state(
self,
state: OrderedDict,
current_shape: torch.Size,
original_shape: torch.Size,
device: torch.device,
inplace: bool,
is_moe_param: bool,
) -> OrderedDict:
"""
With complete optimizer states of a specific parameter loaded from checkpoint,
slice out the sharded optimizer states kept by current device.
Args:
state (OrderedDict): Complete optimizer states of a given parameter, loaded from checkpoint.
current_shape (torch.Size): The size of parameter after sharding.
original_shape (torch.Size): The size of parameter before sharding.
device (torch.device): The destination device of loaded optimizer states.
inplace (bool): If set to True, will update the values of argument 'state' in place. Else will make a copy of state.
Returns:
OrderedDict: The sharded optimizer state of the given parameter.
"""
state_ = state if inplace else copy.deepcopy(state)
for k, v in state_.items():
if isinstance(v, torch.Tensor) and k != "step":
# Shard state along tensor parallel group.
partition_dim = search_tp_partition_dim(current_shape, original_shape, self.tp_size)
if partition_dim is not None:
slice_size = current_shape[partition_dim]
v = v.split(slice_size, dim=partition_dim)[self.tp_rank]
# Shard state along data parallel group when using Zero.
if self.use_zero and not is_moe_param:
padding_size = (self.dp_size - v.numel() % self.dp_size) % self.dp_size
with torch.no_grad():
v = v.flatten()
if padding_size > 0:
v = torch.nn.functional.pad(v, [0, padding_size])
slice_size = v.numel() // self.dp_size
v = v.split(slice_size, dim=0)[self.dp_rank]
state_[k] = v.detach().clone().to(device)
return state_
def save_unsharded_model(self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
raise NotImplementedError
def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool):
raise NotImplementedError
def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, strict: bool = False):
raise NotImplementedError

View File

@@ -1,92 +0,0 @@
import torch
import torch.distributed as dist
import torch.nn.functional as F
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
from colossalai.lazy import LazyInitContext
from colossalai.moe import MOE_MANAGER
from colossalai.moe._operation import MoeInGradScaler, MoeOutGradScaler, all_to_all_uneven
from colossalai.shardformer.shard.utils import set_tensors_to_none
from colossalai.tensor.moe_tensor.api import set_moe_tensor_info
class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
def __init__(self, config):
super().__init__(config)
self.setup_ep()
def setup_ep(self):
_, moe_info = MOE_MANAGER.get_info(self.num_experts)
ep_group = moe_info.ep_group
self.ep_size = dist.get_world_size(ep_group) if ep_group is not None else 1
self.ep_rank = dist.get_rank(ep_group) if ep_group is not None else 0
assert self.num_experts % self.ep_size == 0
self.ep_group = ep_group
self.num_experts_per_ep = self.num_experts // self.ep_size
self.expert_start_idx = self.ep_rank * self.num_experts_per_ep
held_experts = self.experts[self.expert_start_idx : self.expert_start_idx + self.num_experts_per_ep]
set_tensors_to_none(self.experts, exclude=set(held_experts))
for p in self.experts.parameters():
set_moe_tensor_info(p, moe_info)
@staticmethod
def from_native_module(module: MixtralSparseMoeBlock, *args, **kwargs) -> "EPMixtralSparseMoeBlock":
LazyInitContext.materialize(module)
module.__class__ = EPMixtralSparseMoeBlock
module.setup_ep()
return module
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
batch_size, sequence_length, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
# router_logits: (batch * sequence_length, n_experts)
router_logits = self.gate(hidden_states)
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
# we cast back to the input dtype
routing_weights = routing_weights.to(hidden_states.dtype)
selected_experts = selected_experts.t().reshape(-1)
selected_experts_idx = selected_experts.argsort()
dispatch_states = hidden_states.repeat(self.top_k, 1)[selected_experts_idx]
input_split_sizes = selected_experts.bincount(minlength=self.num_experts)
output_split_sizes = torch.zeros_like(input_split_sizes)
dist.all_to_all_single(output_split_sizes, input_split_sizes, group=self.ep_group)
input_split_list = input_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
output_split_list = output_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
output_states, _ = all_to_all_uneven(dispatch_states, input_split_list, output_split_list, self.ep_group)
# compute expert output
output_states = MoeInGradScaler.apply(output_states, self.ep_size)
if output_states.size(0) > 0:
if self.num_experts_per_ep == 1:
# no need to split
expert = self.experts[self.expert_start_idx]
output_states = expert.act_fn(expert.w1(output_states)) * expert.w3(output_states)
output_states = expert.w2(output_states)
else:
output_states_splits = output_states.split(output_split_sizes.tolist())
output_states_list = []
for i, split_states in enumerate(output_states_splits):
if split_states.size(0) == 0:
continue
expert = self.experts[self.expert_start_idx + i % self.num_experts_per_ep]
split_states = expert.act_fn(expert.w1(split_states)) * expert.w3(split_states)
split_states = expert.w2(split_states)
output_states_list.append(split_states)
output_states = torch.cat(output_states_list)
output_states = MoeOutGradScaler.apply(output_states, self.ep_size)
dispatch_states, _ = all_to_all_uneven(output_states, output_split_list, input_split_list, self.ep_group)
recover_experts_idx = torch.empty_like(selected_experts_idx)
recover_experts_idx[selected_experts_idx] = torch.arange(
selected_experts_idx.size(0), device=selected_experts_idx.device
)
dispatch_states = dispatch_states[recover_experts_idx]
k_hidden_states = dispatch_states.chunk(self.top_k)
output_states = k_hidden_states[0] * routing_weights[:, 0, None]
for i in range(1, self.top_k):
output_states += k_hidden_states[i] * routing_weights[:, i, None]
output_states = output_states.reshape(batch_size, sequence_length, hidden_dim)
return output_states, router_logits

View File

@@ -1,557 +0,0 @@
from functools import partial
from typing import Callable, Dict, List, Optional, Union
import torch
import torch.nn as nn
from torch import Tensor
from torch.nn import CrossEntropyLoss, Module
from transformers.models.mixtral.modeling_mixtral import (
MixtralDecoderLayer,
MixtralForCausalLM,
MixtralModel,
MoeCausalLMOutputWithPast,
_prepare_4d_causal_attention_mask,
load_balancing_loss_func,
)
from transformers.utils import logging
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
from colossalai.shardformer.shard import ShardConfig
from .mixtral_layer import EPMixtralSparseMoeBlock
__all__ = ["MixtralPolicy", "MixtralForCausalLMPolicy"]
class MixtralPolicy(Policy):
def config_sanity_check(self):
pass
def preprocess(self):
if self.shard_config.enable_tensor_parallelism:
# Resize embedding
vocab_size = self.model.config.vocab_size
world_size = self.shard_config.tensor_parallel_size
if vocab_size % world_size != 0:
new_vocab_size = vocab_size + world_size - vocab_size % world_size
self.model.resize_token_embeddings(new_vocab_size)
return self.model
def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
policy = {}
if self.shard_config.enable_sequence_parallelism:
self.shard_config.enable_sequence_parallelism = False
raise NotImplementedError(
"Mixtral dosen't support sequence parallelism now, will ignore the sequence parallelism flag."
)
if self.shard_config.enable_tensor_parallelism:
raise NotImplementedError("Tensor parallelism is not supported for Mixtral model now.")
# expert parallel
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(
suffix="block_sparse_moe",
target_module=EPMixtralSparseMoeBlock,
)
],
policy=policy,
target_key=MixtralDecoderLayer,
)
# optimization configuration
if self.shard_config.enable_fused_normalization:
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(
suffix="input_layernorm",
target_module=FusedRMSNorm,
),
SubModuleReplacementDescription(
suffix="post_attention_layernorm",
target_module=FusedRMSNorm,
),
],
policy=policy,
target_key=MixtralDecoderLayer,
)
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(
suffix="norm",
target_module=FusedRMSNorm,
),
policy=policy,
target_key=MixtralModel,
)
if self.shard_config.enable_flash_attention:
raise NotImplementedError("Flash attention has already been replaced in mixtral.")
return policy
def postprocess(self):
return self.model
def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None:
"""If under pipeline parallel setting, replacing the original forward method of huggingface
to customized forward method, and add this changing to policy."""
if self.pipeline_stage_manager:
stage_manager = self.pipeline_stage_manager
if self.model.__class__.__name__ == "MixtralModel":
module = self.model
else:
module = self.model.model
layers_per_stage = stage_manager.distribute_layers(len(module.layers))
stage_index = stage_manager.get_stage_index(layers_per_stage)
method_replacement = {"forward": partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)}
self.append_or_create_method_replacement(
description=method_replacement, policy=policy, target_key=model_cls
)
return
def get_held_layers(self) -> List[Module]:
"""Get pipeline layers for current stage."""
assert self.pipeline_stage_manager is not None
if self.model.__class__.__name__ == "MixtralModel":
module = self.model
else:
module = self.model.model
stage_manager = self.pipeline_stage_manager
held_layers = []
layers_per_stage = stage_manager.distribute_layers(len(module.layers))
if stage_manager.is_first_stage():
held_layers.append(module.embed_tokens)
start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
held_layers.extend(module.layers[start_idx:end_idx])
if stage_manager.is_last_stage():
held_layers.append(module.norm)
return held_layers
class MixtralModelPolicy(MixtralPolicy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
policy = super().module_policy()
if self.pipeline_stage_manager:
# set None as default
self.set_pipeline_forward(
model_cls=MixtralModel,
new_forward=MixtralPipelineForwards.mixtral_model_forward,
policy=policy,
)
return policy
def get_held_layers(self) -> List[Module]:
"""Get pipeline layers for current stage."""
held_layers = super().get_held_layers()
return held_layers
def get_shared_params(self) -> List[Dict[int, Tensor]]:
"""No shared params in llama model"""
return []
class MixtralForCausalLMPolicy(MixtralPolicy):
def module_policy(self):
policy = super().module_policy()
if self.shard_config.enable_tensor_parallelism:
# add a new item for casual lm
new_item = {
MixtralForCausalLM: ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="lm_head",
target_module=Linear1D_Col,
kwargs=dict(gather_output=True),
)
]
)
}
policy.update(new_item)
if self.pipeline_stage_manager:
# set None as default
self.set_pipeline_forward(
model_cls=MixtralForCausalLM,
new_forward=MixtralPipelineForwards.mixtral_for_causal_lm_forward,
policy=policy,
)
return policy
def get_held_layers(self) -> List[Module]:
"""Get pipeline layers for current stage."""
stage_manager = self.pipeline_stage_manager
held_layers = super().get_held_layers()
if stage_manager.is_last_stage():
held_layers.append(self.model.lm_head)
return held_layers
def get_shared_params(self) -> List[Dict[int, Tensor]]:
llama_model = self.model.model
if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1:
if (
id(llama_model.embed_tokens.weight) == id(self.model.lm_head.weight)
and self.pipeline_stage_manager.num_stages > 1
):
# tie weights
return [
{
0: llama_model.embed_tokens.weight,
self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight,
}
]
return []
class MixtralPipelineForwards:
"""
This class serves as a micro library for forward function substitution of Llama models
under pipeline setting.
"""
@staticmethod
def mixtral_model_forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
output_router_logits: Optional[bool] = None,
return_dict: Optional[bool] = None,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
past_router_logits: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
shard_config: ShardConfig = None,
):
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
Returns:
Example:
```python
>>> from transformers import AutoTokenizer, MixtralForCausalLM
>>> model = MixtralForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
>>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
>>> prompt = "Hey, are you conscious? Can you talk to me?"
>>> inputs = tokenizer(prompt, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
```"""
logger = logging.get_logger(__name__)
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_router_logits = (
output_router_logits if output_router_logits is not None else self.config.output_router_logits
)
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# retrieve input_ids and inputs_embeds
if stage_manager.is_first_stage():
# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
elif input_ids is not None:
batch_size, seq_length = input_ids.shape
elif inputs_embeds is not None:
batch_size, seq_length, _ = inputs_embeds.shape
else:
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
device = input_ids.device if input_ids is not None else inputs_embeds.device
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
hidden_states = inputs_embeds
else:
input_shape = hidden_states.shape[:-1]
batch_size, seq_length = input_shape
device = hidden_states.device
seq_length_with_past = seq_length
past_key_values_length = 0
# TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future.
if output_attentions:
logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
output_attentions = False
if output_hidden_states:
logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
output_hidden_states = False
if use_cache:
logger.warning_once("use_cache=True is not supported for pipeline models at the moment.")
use_cache = False
if past_key_values is not None:
past_key_values_length = past_key_values[0][0].shape[2]
seq_length_with_past = seq_length_with_past + past_key_values_length
if position_ids is None:
position_ids = torch.arange(
past_key_values_length,
seq_length + past_key_values_length,
dtype=torch.long,
device=device,
)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
else:
position_ids = position_ids.view(-1, seq_length).long()
# embed positions, for the first stage, hidden_states is the input embeddings,
# for the other stages, hidden_states is the output of the previous stage
if self._use_flash_attention_2:
# 2d mask is passed through the layers
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
else:
# 4d mask is passed through the layers
attention_mask = _prepare_4d_causal_attention_mask(
attention_mask,
(batch_size, seq_length),
hidden_states,
past_key_values_length,
sliding_window=self.config.sliding_window,
)
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
all_router_logits = () if output_router_logits else None
next_decoder_cache = None
start_idx, end_idx = stage_index[0], stage_index[1]
for idx, decoder_layer in enumerate(self.layers[start_idx:end_idx], start=start_idx):
if output_hidden_states:
all_hidden_states += (hidden_states,)
past_key_value = past_key_values[idx] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(decoder_layer),
hidden_states,
attention_mask,
position_ids,
None,
output_attentions,
output_router_logits,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask,
position_ids,
past_key_value,
output_attentions,
output_router_logits,
use_cache,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache = (layer_outputs[2 if output_attentions else 1],)
if output_attentions:
all_self_attns += (layer_outputs[1],)
if output_router_logits:
all_router_logits += (layer_outputs[-1],)
if stage_manager.is_last_stage():
hidden_states = self.norm(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = next_decoder_cache if use_cache else None
if output_router_logits and past_router_logits is not None:
all_router_logits = past_router_logits + all_router_logits
if stage_manager.is_last_stage():
return tuple(
v
for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits]
if v is not None
)
# always return dict for imediate stage
return {
"hidden_states": hidden_states,
"past_router_logits": all_router_logits,
}
@staticmethod
def mixtral_for_causal_lm_forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
output_router_logits: Optional[bool] = None,
return_dict: Optional[bool] = None,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
past_router_logits: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
shard_config: ShardConfig = None,
):
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
Returns:
Example:
```python
>>> from transformers import AutoTokenizer, MixtralForCausalLM
>>> model = MixtralForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
>>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
>>> prompt = "Hey, are you conscious? Can you talk to me?"
>>> inputs = tokenizer(prompt, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
```"""
logger = logging.get_logger(__name__)
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_router_logits = (
output_router_logits if output_router_logits is not None else self.config.output_router_logits
)
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future.
if output_attentions:
logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
output_attentions = False
if output_hidden_states:
logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
output_hidden_states = False
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = MixtralPipelineForwards.mixtral_model_forward(
self.model,
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
output_router_logits=output_router_logits,
return_dict=return_dict,
stage_manager=stage_manager,
hidden_states=hidden_states,
stage_index=stage_index,
past_router_logits=past_router_logits,
)
past_key_values = None
if stage_manager.is_last_stage():
hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
logits = logits.float()
loss = None
if labels is not None:
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
aux_loss = None
if output_router_logits:
aux_loss = load_balancing_loss_func(outputs[-1], self.num_experts, self.num_experts_per_tok)
if labels is not None:
loss += self.router_aux_loss_coef * aux_loss
if not return_dict:
output = (logits,) + outputs[1:]
if output_router_logits:
output = (aux_loss,) + output
return (loss,) + output if loss is not None else output
return MoeCausalLMOutputWithPast(
loss=loss,
aux_loss=aux_loss,
logits=logits,
past_key_values=None,
hidden_states=outputs[0],
attentions=None,
router_logits=outputs[-1],
)
else:
out = {}
hidden_states = outputs.get("hidden_states")
out["hidden_states"] = hidden_states
if output_router_logits:
out["past_router_logits"] = outputs["past_router_logits"]
return out

View File

@@ -2,8 +2,6 @@ import argparse
import torch
import torch.distributed as dist
from colossal_moe.models.mixtral_checkpoint import MixtralMoEHybridParallelCheckpointIO
from colossal_moe.models.mixtral_policy import MixtralForCausalLMPolicy
from transformers import AutoTokenizer
from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM
@@ -70,8 +68,6 @@ def main():
ep_size=ep_size,
zero_stage=1,
precision=args.precision,
custom_policy=MixtralForCausalLMPolicy(),
checkpoint_io=MixtralMoEHybridParallelCheckpointIO,
enable_fused_normalization=args.use_layernorm_kernel,
enable_jit_fused=args.use_kernel,
)

View File

@@ -1,5 +1,6 @@
NUM_GPU=2
MODEL="mistralai/Mixtral-8x7B-v0.1"
# MODEL="mistralai/Mixtral-8x7B-v0.1"
MODEL="mistralai/Mixtral-8x7B-Instruct-v0.1"
# ep
torchrun --standalone --nproc_per_node $NUM_GPU infer.py \

View File

@@ -1,63 +0,0 @@
from copy import deepcopy
import pytest
import torch
import torch.distributed as dist
from colossal_moe.models.mixtral_layer import EPMixtralSparseMoeBlock
from torch.testing import assert_close
from transformers.models.mixtral.configuration_mixtral import MixtralConfig
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
import colossalai
from colossalai.moe import MOE_MANAGER
from colossalai.testing.utils import spawn
tokens, n_experts = 7, 4
hidden_size = 8
top_k = 2
def check_mixtral_moe_layer():
torch.cuda.set_device(dist.get_rank())
MOE_MANAGER.setup(
parallel="EP", mode="fixed", fixed_dp_size=1, fixed_ep_size=dist.get_world_size(), fixed_pp_size=1
)
config = MixtralConfig(
hidden_size=hidden_size,
intermediate_size=hidden_size * 2,
num_local_experts=n_experts,
num_experts_per_tok=top_k,
)
torch.manual_seed(0)
orig_model = MixtralSparseMoeBlock(config).cuda()
x = torch.rand(1, tokens, hidden_size, requires_grad=True).cuda()
orig_output, orig_logits = orig_model(x)
model = deepcopy(orig_model)
model = EPMixtralSparseMoeBlock.from_native_module(model)
ep_output, ep_logits = model(x)
assert_close(orig_logits, ep_logits)
assert_close(orig_output, ep_output)
orig_loss = orig_output.mean()
orig_loss.backward()
ep_loss = ep_output.mean()
ep_loss.backward()
assert_close(orig_loss, ep_loss)
name_to_p = {n: p for n, p in orig_model.named_parameters()}
for n, ep_p in model.named_parameters():
p = name_to_p[n]
if ep_p.grad is not None:
assert_close(p.grad, ep_p.grad)
def run_dist(rank: int, world_size: int, port: int):
colossalai.launch(rank, world_size, "localhost", port)
check_mixtral_moe_layer()
@pytest.mark.parametrize("world_size", [2, 4])
def test_mixtral_moe_layer(world_size: int):
spawn(run_dist, world_size)
if __name__ == "__main__":
test_mixtral_moe_layer(2)

View File

@@ -1,146 +0,0 @@
from copy import deepcopy
import pytest
import torch
import torch.distributed as dist
from colossal_moe.models.mixtral_checkpoint import MixtralMoEHybridParallelCheckpointIO
from colossal_moe.models.mixtral_policy import MixtralForCausalLMPolicy
from torch.optim import Adam
from transformers.models.mixtral.configuration_mixtral import MixtralConfig
from transformers.models.mixtral.modeling_mixtral import MixtralForCausalLM
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.testing.utils import spawn
tokens, n_experts = 7, 4
hidden_size = 8
top_k = 2
def check_model_equal(model1, model2):
assert set(model1.state_dict().keys()) == set(model2.state_dict().keys())
for p1, p2 in zip(model1.parameters(), model2.parameters()):
assert torch.equal(p1.half(), p2.half())
def get_optimizer_snapshot(optim):
state = {id(k): deepcopy(v) for k, v in optim.state.items()}
param_groups = []
for group in optim.param_groups:
params = [id(p) for p in group["params"]]
new_group = {"params": params}
for k, v in group.items():
if k != "params":
new_group[k] = v
param_groups.append(new_group)
return {
"state": state,
"param_groups": param_groups,
}
def check_optimizer_snapshot_equal(snapshot1, snapshot2):
# check param_groups
assert len(snapshot1["param_groups"]) == len(snapshot2["param_groups"])
for group1, group2 in zip(snapshot1["param_groups"], snapshot2["param_groups"]):
assert set(group1.keys()) == set(group2.keys())
for k in group1.keys():
assert group1[k] == group2[k]
# check state
assert set(snapshot1["state"].keys()) == set(
snapshot2["state"].keys()
), f"{snapshot1['state'].keys()}, {snapshot2['state'].keys()}"
for pid in snapshot1["state"].keys():
state1, state2 = snapshot1["state"][pid], snapshot2["state"][pid]
assert set(state1.keys()) == set(state2.keys())
for k in state1.keys():
if isinstance(state1[k], torch.Tensor):
assert torch.equal(state1[k], state2[k]), f"{k}, {state1[k]}, {state2[k]}"
else:
assert state1[k] == state2[k]
def check_mixtral_moe_layer():
torch.cuda.set_device(dist.get_rank())
config = MixtralConfig(
hidden_size=hidden_size,
intermediate_size=hidden_size * 2,
num_local_experts=n_experts,
num_experts_per_tok=top_k,
num_attention_heads=2,
num_key_value_heads=2,
)
torch.manual_seed(0)
input_ids = torch.randint(0, 100, (2, tokens)).cuda()
orig_model = MixtralForCausalLM(config).cuda()
model = deepcopy(orig_model)
optimizer = Adam(model.parameters(), lr=1e-3)
plugin = MoeHybridParallelPlugin(
tp_size=1,
pp_size=2,
ep_size=2,
custom_policy=MixtralForCausalLMPolicy(),
checkpoint_io=MixtralMoEHybridParallelCheckpointIO,
microbatch_size=1,
zero_stage=1,
)
booster = Booster(plugin=plugin)
model, optimizer, *_ = booster.boost(model=model, optimizer=optimizer)
# initialize grads
data_iter = iter(
[{"input_ids": input_ids, "attention_mask": torch.ones_like(input_ids), "labels": input_ids.clone()}]
)
booster.execute_pipeline(
data_iter,
model,
lambda outputs, inputs: outputs.loss,
optimizer,
)
# check save model
booster.save_model(model, "mixtral_model", shard=True)
dist.barrier()
if dist.get_rank() == 0:
saved_model = MixtralForCausalLM.from_pretrained("mixtral_model").cuda()
check_model_equal(orig_model, saved_model)
saved_model.save_pretrained("mixtral_hf_model")
dist.barrier()
# check load model
new_model = MixtralForCausalLM(config).cuda()
new_optimizer = Adam(new_model.parameters(), lr=1e-3)
new_model, new_optimizer, *_ = booster.boost(model=new_model, optimizer=new_optimizer)
booster.load_model(new_model, "mixtral_hf_model")
check_model_equal(model, new_model)
# check save optimizer
optimizer.step()
for group in optimizer.param_groups:
group["lr"] = 0.1
snapshot = get_optimizer_snapshot(optimizer.unwrap())
booster.save_optimizer(optimizer, "mixtral_optim", shard=True)
dist.barrier()
# reset optimizer state
for state in optimizer.unwrap().state.values():
for v in state.values():
if isinstance(v, torch.Tensor):
v.zero_()
booster.load_optimizer(optimizer, "mixtral_optim")
loaded_snapshot = get_optimizer_snapshot(optimizer.unwrap())
check_optimizer_snapshot_equal(snapshot, loaded_snapshot)
def run_dist(rank: int, world_size: int, port: int):
colossalai.launch(rank, world_size, "localhost", port)
check_mixtral_moe_layer()
@pytest.mark.parametrize("world_size", [4])
def test_mixtral_moe_layer(world_size: int):
spawn(run_dist, world_size)
if __name__ == "__main__":
test_mixtral_moe_layer(4)

View File

@@ -2,13 +2,11 @@ import argparse
import torch
import torch.distributed as dist
from colossal_moe.models.mixtral_checkpoint import MixtralMoEHybridParallelCheckpointIO
from colossal_moe.models.mixtral_policy import MixtralForCausalLMPolicy
from colossal_moe.utils import load_checkpoint, move_to_cuda, save_checkpoint
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import AutoTokenizer
from transformers.models.mixtral import MixtralForCausalLM
from utils import load_checkpoint, move_to_cuda, save_checkpoint
import colossalai
from colossalai.booster import Booster
@@ -155,12 +153,10 @@ def main():
pp_size=args.pp_size,
ep_size=args.ep_size,
microbatch_size=args.microbatch_size,
custom_policy=MixtralForCausalLMPolicy(),
enable_fused_normalization=args.use_layernorm_kernel,
enable_jit_fused=args.use_kernel,
precision=args.precision,
zero_stage=args.zero_stage,
checkpoint_io=MixtralMoEHybridParallelCheckpointIO,
)
else:

View File

@@ -20,6 +20,7 @@ resp = llm(prompt='What do you call a three-ton kangaroo?', **gen_config)
print(resp) # super-heavyweight awesome-natured yawning Australian creature!
"""
import json
from typing import Any, Mapping