From 89049b0d899477a3b31f02b31fde1a839e31c6fc Mon Sep 17 00:00:00 2001 From: Camille Zhong <44392324+Camille7777@users.noreply.github.com> Date: Mon, 15 Apr 2024 18:06:18 +0800 Subject: [PATCH 01/28] [doc] fix ColossalMoE readme (#5599) * fix readme * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- applications/ColossalMoE/README.md | Bin 6475 -> 1023 bytes 1 file changed, 0 insertions(+), 0 deletions(-) diff --git a/applications/ColossalMoE/README.md b/applications/ColossalMoE/README.md index ba864d1dff8b2c52b3a5c45261f37586ecdb5bc1..c3c214789f54c90a4b10be96b63619d8c4a42c6f 100644 GIT binary patch delta 7 OcmX?Y^q+mhe`Wv==mWL@ literal 6475 zcmeH`!EW0y42E~sQxM3Z8G)p7IE34}rn4Hi zm2iEn2Elbiq3TQRRE@E^TnZHdNievO zfl(i5vj*zbKD)TCVC}iKn<|OYw{{$0n4KcH`Hm zNg*_!;A^v3z}uU}!@J}zkA^;x{02E94s>V7ev3ZYd3f3c2+EB{!j?v$_d4<49^r-t z|9^*cc0ogWQ+|a&Ay5bu0);>!PzV$Pg+L)t2owT^Kp{{F6as}nAy5bu0)HdG{s14p BSs?%b From 3788fefc7a24ae4da18a29ca69e6d3b1473d306c Mon Sep 17 00:00:00 2001 From: Hongxin Liu Date: Tue, 16 Apr 2024 17:49:21 +0800 Subject: [PATCH 02/28] [zero] support multiple (partial) backward passes (#5596) * [zero] support multiple (partial) backward passes * [misc] update requirements --- .../low_level/bookkeeping/bucket_store.py | 2 + colossalai/zero/low_level/low_level_optim.py | 67 +++++++++++++++---- requirements/requirements.txt | 2 +- 3 files changed, 56 insertions(+), 15 deletions(-) diff --git a/colossalai/zero/low_level/bookkeeping/bucket_store.py b/colossalai/zero/low_level/bookkeeping/bucket_store.py index f395fc60e..2ebc704f7 100644 --- a/colossalai/zero/low_level/bookkeeping/bucket_store.py +++ b/colossalai/zero/low_level/bookkeeping/bucket_store.py @@ -11,7 +11,9 @@ from .base_store import BaseStore class BucketStore(BaseStore): def __init__(self, torch_pg: ProcessGroup): super().__init__(torch_pg) + self.reset_all() + def reset_all(self) -> None: # init self.current_group_id = 0 self._num_elements_in_bucket = 0 diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py index bbbaf13b5..cbcf72697 100644 --- a/colossalai/zero/low_level/low_level_optim.py +++ b/colossalai/zero/low_level/low_level_optim.py @@ -40,7 +40,13 @@ class LowLevelZeroFP16MixedPrecisionMixin(FP16MixedPrecisionMixin): max_scale: float = 2**32, ) -> None: super().__init__( - initial_scale, min_scale, growth_factor, backoff_factor, growth_interval, hysteresis, max_scale + initial_scale, + min_scale, + growth_factor, + backoff_factor, + growth_interval, + hysteresis, + max_scale, ) self.num_working_param_groups = num_working_param_groups self.grad_store = grad_store @@ -273,11 +279,10 @@ class LowLevelZeroOptimizer(OptimizerWrapper): # Backward Reduction Hook # ########################### - def _grad_handler(self, param, group_id, grad): + def _grad_handler(self, group_id, param): # if run with no_sync context, would not sync grad when backward if self.require_grad_sync: self._add_to_bucket(param, group_id) - return grad def _attach_reduction_hook(self): # we iterate over the working params @@ -286,7 +291,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper): param_group = self._working_param_groups[group_id] for param in param_group: if param.requires_grad: - param.register_hook(partial(self._grad_handler, param, group_id)) + param.register_post_accumulate_grad_hook(partial(self._grad_handler, group_id)) ####################### # Reduction Functions # @@ -415,7 +420,10 @@ class LowLevelZeroOptimizer(OptimizerWrapper): recieved_grad = torch.zeros_like(flat_grads_list[0]) dist.reduce_scatter(recieved_grad, flat_grads_list, group=self.dp_pg) self._update_partitoned_grad( - non_moe_grad_in_bucket_current_rank, recieved_grad, group_id, 1 + non_moe_grad_in_bucket_current_rank, + recieved_grad, + group_id, + 1, ) if len(moe_grad_list) > 0: @@ -423,7 +431,11 @@ class LowLevelZeroOptimizer(OptimizerWrapper): moe_flat_grads.split(len(moe_flat_grads) // self.moe_extra_dp_pg_size) ) recieved_grad = torch.zeros_like(flat_grads_list[0]) - dist.reduce_scatter(recieved_grad, flat_grads_list, group=self.moe_extra_dp_pg) + dist.reduce_scatter( + recieved_grad, + flat_grads_list, + group=self.moe_extra_dp_pg, + ) param_slice = self._world_size // self.moe_extra_dp_pg_size recieved_grad = list(recieved_grad.split(len(recieved_grad) // param_slice)) for split_recieved_grad in recieved_grad: @@ -444,14 +456,25 @@ class LowLevelZeroOptimizer(OptimizerWrapper): self._add_grad(grad, self._world_size, group_id, param_id, rank) def _update_partitoned_grad( - self, origin_grad_list: List, flat_grad: torch.Tensor, group_id: int, partition_num: int + self, + origin_grad_list: List, + flat_grad: torch.Tensor, + group_id: int, + partition_num: int, ) -> None: sync_tensor(flat_grad, origin_grad_list) for grad in origin_grad_list: param_id = self._bucket_store.get_param_id_of_grad(grad) self._add_grad(grad, partition_num, group_id, param_id) - def _add_grad(self, grad: torch.Tensor, partition_num: int, group_id: int, param_id: int, rank: int = 0) -> None: + def _add_grad( + self, + grad: torch.Tensor, + partition_num: int, + group_id: int, + param_id: int, + rank: int = 0, + ) -> None: if len(self._grad_store.get_partitioned_gradients_by_param_id(group_id, param_id)) < partition_num: self._grad_store.append_gradients_by_param_id(grad, group_id, param_id) else: @@ -534,6 +557,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper): if param.grad is not None: param.grad.detach() param.grad.zero_() + self._bucket_store.reset_all() #################### # Update Parameter # @@ -655,14 +679,20 @@ class LowLevelZeroOptimizer(OptimizerWrapper): for _ in range(self.moe_extra_dp_pg_size) ] dist.all_gather( - all_splited_param, splited_param.to(device).to(self._dtype), group=self.moe_extra_dp_pg + all_splited_param, + splited_param.to(device).to(self._dtype), + group=self.moe_extra_dp_pg, ) else: all_splited_param = [ torch.zeros(splited_param.shape, device=device, dtype=self._dtype) for _ in range(self._world_size) ] - dist.all_gather(all_splited_param, splited_param.to(device).to(self._dtype), group=self.dp_pg) + dist.all_gather( + all_splited_param, + splited_param.to(device).to(self._dtype), + group=self.dp_pg, + ) working_param.data.copy_(flatten(all_splited_param)[: working_param.numel()].reshape_as(working_param)) self.optim.param_groups[group_id]["params"] = self._master_param_groups_of_current_rank[group_id] @@ -685,7 +715,9 @@ class LowLevelZeroOptimizer(OptimizerWrapper): if norm_type == inf: total_norm = max(grad.data.abs().max() for grad in gradients) total_norm_cuda = torch.tensor( - [float(total_norm)], device=get_accelerator().get_current_device(), dtype=torch.float + [float(total_norm)], + device=get_accelerator().get_current_device(), + dtype=torch.float, ) dist.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=self.dp_pg) total_norm = total_norm_cuda.item() @@ -698,10 +730,14 @@ class LowLevelZeroOptimizer(OptimizerWrapper): # Sum across all model parallel GPUs. total_norm_exponentiated_cuda = torch.tensor( - [float(total_norm_exponentiated)], device=get_accelerator().get_current_device(), dtype=torch.float + [float(total_norm_exponentiated)], + device=get_accelerator().get_current_device(), + dtype=torch.float, ) torch.distributed.all_reduce( - total_norm_exponentiated_cuda, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg + total_norm_exponentiated_cuda, + op=torch.distributed.ReduceOp.SUM, + group=self.dp_pg, ) total_norm = total_norm_exponentiated_cuda.item() ** (1.0 / norm_type) @@ -920,5 +956,8 @@ class LowLevelZeroOptimizer(OptimizerWrapper): def get_master_to_working_map(self) -> Dict[int, torch.Tensor]: if hasattr(self, "moe_master_to_working_map"): - return {**self._param_store.master_to_working_param, **self.moe_master_to_working_map} + return { + **self._param_store.master_to_working_param, + **self.moe_master_to_working_map, + } return self._param_store.master_to_working_param diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 095617d76..fd97f5c5a 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -8,7 +8,7 @@ click fabric contexttimer ninja -torch>=1.12 +torch>=2.1.0 safetensors einops pydantic From a0ad587c24545b82b5412553be2dfe17bbfbda26 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Thu, 18 Apr 2024 16:10:18 +0800 Subject: [PATCH 03/28] [shardformer] refactor embedding resize (#5603) * [branch rebase] rebase main to Feature/resize_embedding (#5554) * fix * [release] update version (#5411) * [hotfix] fix typo s/keywrods/keywords etc. (#5429) * [devops] fix compatibility (#5444) * [devops] fix compatibility * [hotfix] update compatibility test on pr * [devops] fix compatibility * [devops] record duration during comp test * [test] decrease test duration * fix falcon * [shardformer] fix gathering output when using tensor parallelism (#5431) * fix * padding vocab_size when using pipeline parallellism padding vocab_size when using pipeline parallellism fix fix * fix * fix fix fix * fix gather output * fix * fix * fix fix resize embedding fix resize embedding * fix resize embedding fix * revert * revert * revert * [doc] release Open-Sora 1.0 with model weights (#5468) * [doc] release Open-Sora 1.0 with model weights * [doc] release Open-Sora 1.0 with model weights * [doc] release Open-Sora 1.0 with model weights * [doc] update open-sora demo (#5479) * [doc] update open-sora demo * [doc] update open-sora demo * [doc] update open-sora demo * [example] add grok-1 inference (#5485) * [misc] add submodule * remove submodule * [example] support grok-1 tp inference * [example] add grok-1 inference script * [example] refactor code * [example] add grok-1 readme * [exmaple] add test ci * [exmaple] update readme --------- Co-authored-by: Hongxin Liu Co-authored-by: digger yu Co-authored-by: binmakeswell * [CI] run pre-commit (#5577) * fix * [release] update version (#5411) * [hotfix] fix typo s/keywrods/keywords etc. (#5429) * [devops] fix compatibility (#5444) * [devops] fix compatibility * [hotfix] update compatibility test on pr * [devops] fix compatibility * [devops] record duration during comp test * [test] decrease test duration * fix falcon * [shardformer] fix gathering output when using tensor parallelism (#5431) * fix * padding vocab_size when using pipeline parallellism padding vocab_size when using pipeline parallellism fix fix * fix * fix fix fix * fix gather output * fix * fix * fix fix resize embedding fix resize embedding * fix resize embedding fix * revert * revert * revert * [doc] release Open-Sora 1.0 with model weights (#5468) * [doc] release Open-Sora 1.0 with model weights * [doc] release Open-Sora 1.0 with model weights * [doc] release Open-Sora 1.0 with model weights * [doc] update open-sora demo (#5479) * [doc] update open-sora demo * [doc] update open-sora demo * [doc] update open-sora demo * [example] add grok-1 inference (#5485) * [misc] add submodule * remove submodule * [example] support grok-1 tp inference * [example] add grok-1 inference script * [example] refactor code * [example] add grok-1 readme * [exmaple] add test ci * [exmaple] update readme * run pre-commit --------- Co-authored-by: Hongxin Liu Co-authored-by: digger yu Co-authored-by: binmakeswell * [rebase] rebase main to resize-embedding (#5581) * [release] grok-1 314b inference (#5490) * [release] grok-1 inference * [release] grok-1 inference * [release] grok-1 inference * [example] update Grok-1 inference (#5495) * revise grok-1 example * remove unused arg in scripts * prevent re-installing torch * update readme * revert modifying colossalai requirements * add perf * trivial * add tokenizer url * [hotfix] set return_outputs=False in examples and polish code (#5404) * fix: simplify merge_batch * fix: use return_outputs=False to eliminate extra memory consumption * feat: add return_outputs warning * style: remove `return_outputs=False` as it is the default value * [release] grok-1 inference benchmark (#5500) * [release] grok-1 inference benchmark * [release] grok-1 inference benchmark * [release] grok-1 inference benchmark * [release] grok-1 inference benchmark * [release] grok-1 inference benchmark * [shardformer]Fix lm parallel. (#5480) * fix * padding vocab_size when using pipeline parallellism padding vocab_size when using pipeline parallellism fix fix * fix * fix fix fix * fix gather output * fix * fix * fix fix resize embedding fix resize embedding * fix resize embedding fix * revert * revert * revert * fix lm forward distribution * fix * test ci * fix * [fix] fix grok-1 example typo (#5506) * [devops] fix example test ci (#5504) * Fix ColoTensorSpec for py11 (#5440) * fixed layout converter caching and updated tester * Empty-Commit * [shardformer] update colo attention to support custom mask (#5510) * [feature] refactor colo attention (#5462) * [extension] update api * [feature] add colo attention * [feature] update sdpa * [feature] update npu attention * [feature] update flash-attn * [test] add flash attn test * [test] update flash attn test * [shardformer] update modeling to fit colo attention (#5465) * [misc] refactor folder structure * [shardformer] update llama flash-attn * [shardformer] fix llama policy * [devops] update tensornvme install * [test] update llama test * [shardformer] update colo attn kernel dispatch * [shardformer] update blip2 * [shardformer] update chatglm * [shardformer] update gpt2 * [shardformer] update gptj * [shardformer] update opt * [shardformer] update vit * [shardformer] update colo attention mask prep * [shardformer] update whisper * [test] fix shardformer tests (#5514) * [test] fix shardformer tests * [test] fix shardformer tests * [format] applied code formatting on changed files in pull request 5510 (#5517) Co-authored-by: github-actions * [shardformer] fix pipeline forward error if custom layer distribution is used (#5189) * Use self.[distribute_layers|get_stage_index] to exploit custom layer distribution * Change static methods for t5 layer distribution to member functions * Change static methods for whisper layer distribution to member functions * Replace whisper policy usage with self one * Fix test case to use non-static layer distribution methods * fix: fix typo --------- Co-authored-by: Wenhao Chen * [Fix] Grok-1 use tokenizer from the same pretrained path (#5532) * [fix] use tokenizer from the same pretrained path * trust remote code * [ColossalChat] Update RLHF V2 (#5286) * Add dpo. Fix sft, ppo, lora. Refactor all * fix and tested ppo * 2 nd round refactor * add ci tests * fix ci * fix ci * fix readme, style * fix readme style * fix style, fix benchmark * reproduce benchmark result, remove useless files * rename to ColossalChat * use new image * fix ci workflow * fix ci * use local model/tokenizer for ci tests * fix ci * fix ci * fix ci * fix ci timeout * fix rm progress bar. fix ci timeout * fix ci * fix ci typo * remove 3d plugin from ci temporary * test environment * cannot save optimizer * support chat template * fix readme * fix path * test ci locally * restore build_or_pr * fix ci data path * fix benchmark * fix ci, move ci tests to 3080, disable fast tokenizer * move ci to 85 * support flash attention 2 * add all-in-one data preparation script. Fix colossal-llama2-chat chat template * add hardware requirements * move ci test data * fix save_model, add unwrap * fix missing bos * fix missing bos; support grad accumulation with gemini * fix ci * fix ci * fix ci * fix llama2 chat template config * debug sft * debug sft * fix colossalai version requirement * fix ci * add sanity check to prevent NaN loss * fix requirements * add dummy data generation script * add dummy data generation script * add dummy data generation script * add dummy data generation script * update readme * update readme * update readme and ignore * fix logger bug * support parallel_output * modify data preparation logic * fix tokenization * update lr * fix inference * run pre-commit --------- Co-authored-by: Tong Li * [shardformer, pipeline] add `gradient_checkpointing_ratio` and heterogenous shard policy for llama (#5508) * feat: add `GradientCheckpointConfig` and `PipelineGradientCheckpointConfig` * feat: apply `GradientCheckpointConfig` to policy and llama_forward * feat: move `distribute_layer` and `get_stage_index` to PipelineStageManager * fix: add optional args for `distribute_layer` and `get_stage_index` * fix: fix changed API calls * test: update llama tests * style: polish `GradientCheckpointConfig` * fix: fix pipeline utils tests * fix incorrect sharding without zero (#5545) Co-authored-by: Edenzzzz * [shardformer] Sequence Parallelism Optimization (#5533) * sequence parallel optimization * validate sequence parallel in llama (code to be polished) * shardformer api writing * integrate sequence parallel in ShardFormer * fix pp bugs and sp bugs for LlaMa model * integrating ring-based sequence parallelism into ShardFormer * [sequence parallelism]: Add fused megatron function * integrating ring-based sequence parallelism into ShardFormer --------- Co-authored-by: linsj20 * fix bugs when useing sp and flashattention together * fix operation function name * support flash attention for ulysses-style sp * clarify sp process group * fix compatibility bugs in moe plugin * fix fused linear bugs * fix linear layer test * support gpt model all-to-all sp * modify shard data dimension (meant to be dim=-1) * support megtron-style sp and distributed attn for llama model * [shardformer] add megatron sp to llama * support llama7B 128k with distributed attention * [shardformer] robustness enhancement * add block attn * sp mode 1: keep input as a complete sequence * fix sp compatability * finish sp mode 3 support for gpt * using all_to_all_single when batch size is 1 * support mode 2 sp in gpt2 (#5) * [shardformer] add megatron sp to llama * support llama7B 128k with distributed attention * [shardformer] robustness enhancement * add block attn * sp mode 1: keep input as a complete sequence * fix sp compatability * refactor ring implementation * support mode 2 sp in gpt2 * polish code * enable distributed attn mask when using sp mode 2 and 3 in llama * automatically enable flash attn when using sp mode 2 and 3 in llama * inplace attn mask * add zero2 support for sequence parallel * polish code * fix bugs * fix gemini checkpoint io * loose tensor checking atol and rtol * add comment * fix llama layernorm grad * fix zero grad * fix zero grad * fix conflict * update split and gather auto grad func * sequence parallel: inside text split (#6) * polish code (part 1) * polish code (part 2) * polish code (part 2.5) * polish code (part 3) * sequence parallel: inside text split * miscellaneous minor fixes * polish code * fix ulysses style ZeRO * sequence parallel: inside text split * miscellaneous minor fixes * disaggregate sp group and dp group for sp * fix llama and gpt sp * polish code * move ulysses grad sync to ddp (#9) * remove zero_stage and unbind the grad sync for alltoall sp * add 2d group creation test * move ulysses grad sync to ddp * add 2d group creation test * remove useless code * change shard config not to enable sp when enable_all_optimizations * add sp warnings for several model * remove useless code --------- Co-authored-by: linsj20 * [hotfix] quick fixes to make legacy tutorials runnable (#5559) Co-authored-by: Edenzzzz * [fix] fix typo s/muiti-node /multi-node etc. (#5448) * [hotfix] fix typo s/get_defualt_parser /get_default_parser (#5548) * [devops] remove post commit ci (#5566) * [devops] remove post commit ci * [misc] run pre-commit on all files * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --------- Co-authored-by: binmakeswell Co-authored-by: Yuanheng Zhao <54058983+yuanheng-zhao@users.noreply.github.com> Co-authored-by: Wenhao Chen Co-authored-by: Hongxin Liu Co-authored-by: Rocky Duan Co-authored-by: Edenzzzz Co-authored-by: Edenzzzz Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: github-actions Co-authored-by: Insu Jang Co-authored-by: YeAnbang <44796419+YeAnbang@users.noreply.github.com> Co-authored-by: Tong Li Co-authored-by: Zhongkai Zhao Co-authored-by: linsj20 Co-authored-by: digger yu Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * [shardformer]enable padding vocabulary size. (#5489) * padding vocab_size when using pipeline parallellism padding vocab_size when using pipeline parallellism fix fix * fix * fix fix fix * fix gather output * fix * fix * fix fix resize embedding fix resize embedding * fix resize embedding fix * revert * revert * revert * padding vocab * padding vocabe * fix * fix * fxi * test ci * fix fix fix fix * fix fix * fix * fix * Update hybrid_parallel_plugin.py fix fix fix * fix fix * fix fix * fix * resolve super init resolve super init resolve super init resolve super init * resolve comments * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * vocab checkpointio * padding vocab_size when using pipeline parallellism padding vocab_size when using pipeline parallellism fix fix * fix fix fix * fix * fix fix resize embedding fix resize embedding * fix resize embedding fix * revert * revert * padding vocab * fix * fix fix * fix fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix ci * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * cherry-pick * revert moe modify * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix fix fix fix fix fix fix fix * resolve comments resolve comments resolve comments resolve comments resolve comments * ptensor ptensor resolve comments fix fix fix fix fix resolve comments resolve comments resolve comments resolve comments resolve comments --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Hongxin Liu * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix rebase * fix rebase --------- Co-authored-by: Hongxin Liu Co-authored-by: digger yu Co-authored-by: binmakeswell Co-authored-by: Yuanheng Zhao <54058983+yuanheng-zhao@users.noreply.github.com> Co-authored-by: Wenhao Chen Co-authored-by: Rocky Duan Co-authored-by: Edenzzzz Co-authored-by: Edenzzzz Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: github-actions Co-authored-by: Insu Jang Co-authored-by: YeAnbang <44796419+YeAnbang@users.noreply.github.com> Co-authored-by: Tong Li Co-authored-by: Zhongkai Zhao Co-authored-by: linsj20 Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- colossalai/booster/plugin/gemini_plugin.py | 6 +- .../booster/plugin/hybrid_parallel_plugin.py | 11 +- .../hybrid_parallel_checkpoint_io.py | 29 ++- colossalai/checkpoint_io/utils.py | 9 + colossalai/shardformer/layer/__init__.py | 7 +- colossalai/shardformer/layer/embedding.py | 111 +++++++-- colossalai/shardformer/layer/linear.py | 220 +++++++++++++++++- colossalai/shardformer/layer/loss.py | 32 ++- .../shardformer/layer/parallel_module.py | 192 ++++++++++++++- colossalai/shardformer/modeling/gpt2.py | 13 +- colossalai/shardformer/modeling/llama.py | 11 +- .../shardformer/policies/base_policy.py | 9 + colossalai/shardformer/policies/bert.py | 52 +++-- colossalai/shardformer/policies/blip2.py | 72 ++++-- colossalai/shardformer/policies/bloom.py | 47 ++-- colossalai/shardformer/policies/chatglm2.py | 40 ++-- colossalai/shardformer/policies/falcon.py | 49 ++-- colossalai/shardformer/policies/gpt2.py | 71 ++++-- colossalai/shardformer/policies/gptj.py | 55 +++-- colossalai/shardformer/policies/llama.py | 53 +++-- colossalai/shardformer/policies/mistral.py | 57 +++-- colossalai/shardformer/policies/opt.py | 64 +++-- colossalai/shardformer/policies/t5.py | 101 ++++++-- colossalai/shardformer/policies/whisper.py | 45 ++-- colossalai/shardformer/shard/shard_config.py | 3 +- .../tensor/d_tensor/layout_converter.py | 17 +- colossalai/tensor/padded_tensor/__init__.py | 3 + colossalai/tensor/padded_tensor/api.py | 128 ++++++++++ colossalai/testing/comparison.py | 2 +- colossalai/zero/gemini/gemini_ddp.py | 22 ++ colossalai/zero/gemini/gemini_optimizer.py | 45 +++- ...st_hybrid_parallel_plugin_checkpoint_io.py | 3 +- .../test_vocab_parallel_embedding_1d.py | 2 +- tests/test_shardformer/test_model/_utils.py | 11 +- .../test_model/test_shard_t5.py | 3 +- tests/test_tensor/test_padded_tensor.py | 46 ++++ 36 files changed, 1352 insertions(+), 289 deletions(-) create mode 100644 colossalai/tensor/padded_tensor/__init__.py create mode 100644 colossalai/tensor/padded_tensor/api.py create mode 100644 tests/test_tensor/test_padded_tensor.py diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py index 6c5033773..442ac4a8d 100644 --- a/colossalai/booster/plugin/gemini_plugin.py +++ b/colossalai/booster/plugin/gemini_plugin.py @@ -44,10 +44,10 @@ ZERO_AXIS, DP_AXIS, TP_AXIS = 0, 1, 2 def get_param_info(optim: Optimizer): # Get a backup of necessary information of parameters for future use, which includes: # 1. A mapping from integer param_id to param32 shape. - if optim is None: return {} param_info = {"id2shape": {}} + start_index = 0 for group in optim.param_groups: for param_id, param in enumerate(group["params"], start_index): @@ -527,7 +527,7 @@ class GeminiPlugin(DPPluginBase): dataloader: Optional[DataLoader] = None, lr_scheduler: Optional[LRScheduler] = None, ) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]: - optimizer_params_info = get_param_info(optimizer) + params_info = get_param_info(optimizer) if not isinstance(model, ModelWrapper): # convert model to sync bn # FIXME(ver217): gemini does not support sync bn @@ -558,7 +558,7 @@ class GeminiPlugin(DPPluginBase): **self.zero_optim_config, **self.optim_kwargs, tp_group=self.tp_group, - optimizer_params_info=optimizer_params_info, + params_info=params_info, verbose=self.verbose, ) diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index 29cec7cfd..8d12eb806 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -213,12 +213,7 @@ def get_param_info(optim: Optimizer): if optim is None: return {} - param_info = { - "param_groups": [], - "param2id": {}, - "id2param": {}, - "param2shape": {}, - } + param_info = {"param_groups": [], "param2id": {}, "id2param": {}, "param2shape": {}} start_index = 0 for group in optim.param_groups: packed_group = {k: v for k, v in group.items() if k != "params"} @@ -947,6 +942,8 @@ class HybridParallelPlugin(PipelinePluginBase): num_model_chunks (int, optional): The number of model chunks for interleaved pipeline parallelism. Defaults to 1. gradient_checkpoint_config (GradientCheckpointConfig, optional): Configuration for gradient checkpointing. Defaults to None. enable_metadata_cache (bool, optional): Whether to enable metadata cache for pipeline parallelism. Defaults to True. + make_vocab_size_divisible_by (int, optional): it's used when padding the vocabulary size, to make it choose an faster kenel. Default to 64. + """ def __init__( @@ -989,6 +986,7 @@ class HybridParallelPlugin(PipelinePluginBase): num_model_chunks: int = 1, gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None, enable_metadata_cache: bool = True, + make_vocab_size_divisible_by: int = 64, ) -> None: super().__init__() assert ( @@ -1095,6 +1093,7 @@ class HybridParallelPlugin(PipelinePluginBase): sequence_parallelism_mode=sequence_parallelism_mode, enable_sequence_overlap=enable_sequence_overlap, parallel_output=parallel_output, + make_vocab_size_divisible_by=make_vocab_size_divisible_by, gradient_checkpoint_config=gradient_checkpoint_config, ) self.amp_config = dict( diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py index 808227249..7946d9b9c 100644 --- a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py +++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py @@ -14,6 +14,12 @@ from torch.optim.lr_scheduler import _LRScheduler as LRScheduler from colossalai.cluster import DistCoordinator from colossalai.interface import ModelWrapper, OptimizerWrapper +from colossalai.tensor.padded_tensor import ( + init_as_padded_tensor, + is_padded_tensor, + to_padded_tensor, + to_unpadded_tensor, +) from colossalai.utils import get_current_device from .general_checkpoint_io import GeneralCheckpointIO @@ -32,6 +38,7 @@ from .utils import ( save_param_groups, save_state_dict, save_state_dict_shards, + search_padding_dim, search_tp_partition_dim, sharded_optimizer_loading_epilogue, ) @@ -89,6 +96,8 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO): if param is None: continue # Gather tensor pieces when using tensor parallel. + if is_padded_tensor(param): + param = to_unpadded_tensor(param) param_ = gather_distributed_param(param, keep_vars=False) block, block_size = state_dict_sharder.append_param(prefix + name, param_) if block is not None: @@ -231,7 +240,6 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO): # When pipeline is used, each stage produces its own shard files and index files. # Index files belonging to each stage are saved under a temporary folder ./tmp_index_files/ # After all the state_dicts have been saved, the master rank integrates all the index files into one final index file and deletes the tmp folder. - final_index_file_path = copy.deepcopy(save_index_file) tmp_index_file_folder = os.path.join(checkpoint, "tmp_index_files") Path(tmp_index_file_folder).mkdir(parents=True, exist_ok=True) @@ -251,6 +259,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO): use_safetensors=use_safetensors, use_pp_format=True, ) + if control_saving: assert ( self.dp_rank == 0 and self.tp_rank == 0 @@ -867,6 +876,11 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO): dist.all_gather(gather_tensor, v, group=tp_group) v = torch.cat(gather_tensor, dim=partition_dim) + padding_dim = search_padding_dim(v.shape, original_shape) + if padding_dim is not None: + v = init_as_padded_tensor(v, v.shape[padding_dim], original_shape[padding_dim], padding_dim) + v = to_unpadded_tensor(v) + state_[k] = v.detach().clone().to(device) return state_ @@ -899,6 +913,19 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO): if isinstance(v, torch.Tensor) and k != "step": # Shard state along tensor parallel group. partition_dim = search_tp_partition_dim(current_shape, original_shape, self.tp_size) + global_shape = current_shape + if partition_dim is not None: + # pad embedding params + global_shape = ( + *current_shape[:partition_dim], + current_shape[partition_dim] * self.tp_size, + *current_shape[partition_dim + 1 :], + ) + + padding_dim = search_padding_dim(global_shape, original_shape) + if padding_dim is not None: + v = to_padded_tensor(v, global_shape[padding_dim], padding_dim) + if partition_dim is not None: slice_size = current_shape[partition_dim] v = v.split(slice_size, dim=partition_dim)[self.tp_rank] diff --git a/colossalai/checkpoint_io/utils.py b/colossalai/checkpoint_io/utils.py index 2a1d4de9b..6197be9d1 100644 --- a/colossalai/checkpoint_io/utils.py +++ b/colossalai/checkpoint_io/utils.py @@ -120,6 +120,15 @@ def search_tp_partition_dim(current_shape: torch.Size, original_shape: torch.Siz return partition_dim +def search_padding_dim(global_shape: torch.Size, original_shape: torch.Size) -> Optional[int]: + padding_dim = None + for dim, length in enumerate(global_shape): + if length > original_shape[dim]: + padding_dim = dim + break + return padding_dim + + # ====================================== # Helper classes and functions for saving shard file # ====================================== diff --git a/colossalai/shardformer/layer/__init__.py b/colossalai/shardformer/layer/__init__.py index 7b8aa5380..f17fad1b6 100644 --- a/colossalai/shardformer/layer/__init__.py +++ b/colossalai/shardformer/layer/__init__.py @@ -1,8 +1,8 @@ from ._operation import all_to_all_comm from .attn import AttnMaskType, ColoAttention from .dropout import DropoutForParallelInput, DropoutForReplicatedInput -from .embedding import Embedding1D, VocabParallelEmbedding1D -from .linear import Linear1D_Col, Linear1D_Row +from .embedding import Embedding1D, PaddingEmbedding, VocabParallelEmbedding1D +from .linear import Linear1D_Col, Linear1D_Row, PaddingLMHead, VocabParallelLMHead1D from .loss import cross_entropy_1d from .normalization import FusedLayerNorm, FusedRMSNorm, LayerNorm, RMSNorm from .parallel_module import ParallelModule @@ -25,6 +25,9 @@ __all__ = [ "FusedRMSNorm", "FusedLinear1D_Col", "ParallelModule", + "PaddingEmbedding", + "PaddingLMHead", + "VocabParallelLMHead1D", "AttnMaskType", "ColoAttention", "all_to_all_comm", diff --git a/colossalai/shardformer/layer/embedding.py b/colossalai/shardformer/layer/embedding.py index d081b2040..cb7eceae4 100644 --- a/colossalai/shardformer/layer/embedding.py +++ b/colossalai/shardformer/layer/embedding.py @@ -21,10 +21,10 @@ from colossalai.tensor.d_tensor.api import ( ) from ._operation import gather_forward_split_backward, reduce_forward -from .parallel_module import ParallelModule +from .parallel_module import PaddingParallelModule, ParallelModule from .utils import create_randomizer_with_offset -__all__ = ["Embedding1D", "VocabParallelEmbedding1D"] +__all__ = ["Embedding1D", "VocabParallelEmbedding1D", "PaddingEmbedding"] class Embedding1D(ParallelModule): @@ -161,7 +161,80 @@ class Embedding1D(ParallelModule): return output_parallel -class VocabParallelEmbedding1D(ParallelModule): +class PaddingEmbedding(PaddingParallelModule): + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + padding_idx: int = None, + dtype: torch.dtype = None, + device: torch.device = None, + weight: Optional[nn.Parameter] = None, + make_vocab_size_divisible_by: int = 64, + *args, + **kwargs, + ): + self.num_embeddings = num_embeddings + self.embedding_dim = embedding_dim + self.embed_args = args + self.embed_kwargs = kwargs + self.padding_idx = padding_idx + if num_embeddings % make_vocab_size_divisible_by != 0: + self.num_embeddings = ( + num_embeddings + make_vocab_size_divisible_by - (num_embeddings % make_vocab_size_divisible_by) + ) + # create weight and bias + if weight is None: + factory_kwargs = {"device": device, "dtype": dtype} + weight = nn.Parameter(torch.empty((num_embeddings, self.embedding_dim), **factory_kwargs)) + else: + weight.data = weight.data.to(device=device, dtype=dtype) + + super().__init__(self.num_embeddings, num_embeddings, weight) + + if weight is None: + self.reset_parameters() + + def reset_parameters(self) -> None: + init.normal_(self.weight) + self._fill_padding_idx_with_zero() + + def _fill_padding_idx_with_zero(self) -> None: + if self.padding_idx is not None: + with torch.no_grad(): + self.weight[self.padding_idx].fill_(0) + + def forward(self, input: Tensor) -> Tensor: + return F.embedding(input, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs) + + @staticmethod + def from_native_module( + module: nn.Embedding, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs + ) -> PaddingParallelModule: + r""" + Convert a native pytorch embedding module to a parallel module. + """ + LazyInitContext.materialize(module) + # get the origin attributes + num_embeddings = module.num_embeddings + embedding_dim = module.embedding_dim + padding_idx = module.padding_idx + device = module.weight.device + # create the parallel module + padding_embedding = PaddingEmbedding( + num_embeddings=num_embeddings, + embedding_dim=embedding_dim, + padding_idx=padding_idx, + device=device, + weight=module.weight, + *args, + **kwargs, + ) + + return padding_embedding + + +class VocabParallelEmbedding1D(PaddingParallelModule): r"""Embedding parallelized in the vocabulary dimension. Args: @@ -201,10 +274,10 @@ class VocabParallelEmbedding1D(ParallelModule): process_group: ProcessGroup = None, weight: Optional[nn.Parameter] = None, weight_initializer: Callable = init.normal_(), + make_vocab_size_divisible_by: int = 64, *args, **kwargs, ): - super().__init__() self.num_embeddings = num_embeddings self.embedding_dim = embedding_dim self.embed_args = args @@ -214,8 +287,23 @@ class VocabParallelEmbedding1D(ParallelModule): tensor_parallel_size = dist.get_world_size(group=process_group) tensor_parallel_rank = dist.get_rank(group=process_group) - self.num_embeddings_per_partition = divide(num_embeddings, tensor_parallel_size) - self.num_embeddings = self.num_embeddings_per_partition + # generate weight and bias + if weight is None: + factory_kwargs = {"device": device, "dtype": dtype} + weight = nn.Parameter(torch.empty((num_embeddings, self.embedding_dim), **factory_kwargs)) + else: + weight.data = weight.data.to(device=device, dtype=dtype) + + # calculate new padding size + multiple = make_vocab_size_divisible_by * tensor_parallel_size + if num_embeddings % multiple != 0: + self.num_embeddings = num_embeddings + multiple - (num_embeddings % multiple) + + # resize vocabulary size + super().__init__(self.num_embeddings, num_embeddings, weight) + + # deal with tensor parallelism + self.num_embeddings_per_partition = divide(self.num_embeddings, tensor_parallel_size) self.vocab_start_index = tensor_parallel_rank * self.num_embeddings_per_partition self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition @@ -226,13 +314,6 @@ class VocabParallelEmbedding1D(ParallelModule): seed = torch.random.initial_seed() self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group) - # parameter - if weight is None: - factory_kwargs = {"device": device, "dtype": dtype} - self.weight = nn.Parameter(torch.empty((num_embeddings, self.embedding_dim), **factory_kwargs)) - else: - weight.data = weight.data.to(device=device, dtype=dtype) - self.weight = weight if not is_distributed_tensor(self.weight): sharded_weight = shard_rowwise(self.weight.data, process_group) sharded_tensor_to_existing_param(sharded_weight, self.weight) @@ -243,7 +324,7 @@ class VocabParallelEmbedding1D(ParallelModule): @staticmethod def from_native_module( module: nn.Embedding, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs - ) -> ParallelModule: + ) -> PaddingParallelModule: r""" Convert a native pytorch embedding module to a parallel module. """ @@ -303,11 +384,9 @@ class VocabParallelEmbedding1D(ParallelModule): # Mask the input. masked_input = input_.clone() - self.vocab_start_index masked_input[input_mask] = 0 - output_parallel = F.embedding( masked_input, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs ) - # Mask the output embedding. embedding_output = output_parallel.clone() embedding_output[input_mask, :] = 0.0 diff --git a/colossalai/shardformer/layer/linear.py b/colossalai/shardformer/layer/linear.py index 7c8619ad8..37c754241 100644 --- a/colossalai/shardformer/layer/linear.py +++ b/colossalai/shardformer/layer/linear.py @@ -32,7 +32,7 @@ from ._operation import ( reducescatter_forward_gather_backward, split_forward_gather_backward, ) -from .parallel_module import ParallelModule +from .parallel_module import PaddingParallelModule, ParallelModule from .utils import create_randomizer_with_offset __all__ = ["Linear1D_Col", "Linear1D_Row"] @@ -84,8 +84,9 @@ class Linear1D_Col(ParallelModule): bias_: Optional[Parameter] = None, weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + **kwargs, ): - super().__init__() + super().__init__(weight=weight, bias_=bias_, **kwargs) # Keep input parameters self.in_features = in_features @@ -118,6 +119,7 @@ class Linear1D_Col(ParallelModule): else: weight.data = weight.data.to(device=device, dtype=dtype) self.weight = weight + if not is_distributed_tensor(self.weight): sharded_weight = shard_rowwise(self.weight.data, self.process_group) sharded_tensor_to_existing_param(sharded_weight, self.weight) @@ -140,7 +142,7 @@ class Linear1D_Col(ParallelModule): @staticmethod def from_native_module( - module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs + module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], **kwargs ) -> ParallelModule: r""" Convert a native PyTorch linear layer to a parallelized linear layer. @@ -173,7 +175,6 @@ class Linear1D_Col(ParallelModule): process_group=process_group, weight=module.weight, bias_=module.bias, - *args, **kwargs, ) @@ -322,7 +323,7 @@ class Linear1D_Row(ParallelModule): @staticmethod def from_native_module( - module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs + module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], **kwargs ) -> ParallelModule: r""" Convert a native PyTorch linear layer to a parallelized linear layer. @@ -356,7 +357,6 @@ class Linear1D_Row(ParallelModule): process_group=process_group, weight=module.weight, bias_=module.bias, - *args, **kwargs, ) @@ -439,3 +439,211 @@ class Linear1D_Row(ParallelModule): return output else: return output, self.bias + + +class PaddingLMHead(PaddingParallelModule): + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + dtype: torch.dtype = None, + device: torch.device = None, + weight: Optional[Parameter] = None, + bias_: Optional[Parameter] = None, + make_vocab_size_divisible_by: int = 64, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + ): + # Keep input parameters + self.in_features = in_features + self.out_features = out_features + + if out_features % make_vocab_size_divisible_by != 0: + self.out_features = ( + out_features + make_vocab_size_divisible_by - (out_features % make_vocab_size_divisible_by) + ) + if weight is None: + factory_kwargs = {"device": device, "dtype": dtype} + weight = Parameter(torch.empty(out_features, self.in_features, **factory_kwargs)) + else: + weight.data = weight.data.to(device=device, dtype=dtype) + + if bias: + if bias_ is None: + self.bias = Parameter(torch.empty(out_features, **factory_kwargs)) + else: + bias_.data = bias_.data.to(device=device, dtype=dtype) + else: + bias_ = None + + # resize embeddings + super().__init__(self.out_features, out_features, weight, bias_) + + if weight is None: + self.reset_parameters(weight_initializer, bias_initializer) + + def reset_parameters(self, weight_initializer, bias_initializer) -> None: + fan_in, fan_out = self.in_features, self.out_features + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + if self.bias is not None: + bias_initializer(self.bias, fan_in=fan_in) + + @staticmethod + def from_native_module( + module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], **kwargs + ) -> PaddingParallelModule: + r""" + Convert a native PyTorch linear layer to a parallelized linear layer. + """ + LazyInitContext.materialize(module) + # get the attributes + in_features = module.in_features + out_features = module.out_features + bias = module.bias is not None + device = module.weight.device + # ensure only one process group is passed + + lm_head_linear = PaddingLMHead( + in_features=in_features, + out_features=out_features, + bias=bias, + device=device, + weight=module.weight, + bias_=module.bias, + **kwargs, + ) + + return lm_head_linear + + def forward(self, input: Tensor) -> Tensor: + output = F.linear(input, self.weight, self.bias) + output = output[..., : self.old_num_embeddings] + return output + + +class VocabParallelLMHead1D(Linear1D_Col, PaddingParallelModule): + r"""Linear layer with column parallelism. + + The linear layer is defined as :math:`Y = XA + b`. A is parallelized along + its second dimension as :math:`A = [A_1, ..., A_p]`. + + Args: + in_features (int): size of each input sample. + out_features (int): size of each output sample. + bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. + dtype (`torch.dtype`): The dtype of parameters, defaults to None. + device (`torch.device`): The device of parameters, defaults to None. + process_group (`torch.distributed.ProcessGroup`): The process group to be used for weight sharding and communication, defaults to None. + gather_output (bool, optional): If true, call all-gather on output and make Y available + to all GPUs, otherwise, every GPU will have its output + which is :math:`Y_i = XA_i`, defaults to False + seq_parallel (`bool`): If set to ``True``, it will use sequence parallel, defaults to False. + overlap (`bool`): If set to ``True``, it will overlap input all-gather with gradient computation during backward, defaults to False. + skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer, + which is preserved for kernel fusion, defaults to False + weight_initializer (`typing.Callable`): + The initializer of weight, defaults to kaiming uniform initializer. + bias_initializer (`typing.Callable`): + The initializer of bias, defaults to xavier uniform initializer. + + More details about ``initializer`` please refer to + `init `_. + """ + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + dtype: torch.dtype = None, + device: torch.device = None, + process_group: ProcessGroup = None, + weight: Optional[Parameter] = None, + bias_: Optional[Parameter] = None, + make_vocab_size_divisible_by: int = 64, + **kwargs, + ): + # create weight and bias + if weight is None: + factory_kwargs = {"device": device, "dtype": dtype} + weight = Parameter(torch.empty(out_features, self.in_features, **factory_kwargs)) + if bias: + if bias_ is None: + bias_ = Parameter(torch.empty(out_features, **factory_kwargs)) + else: + bias_ = None + + # calculate new vocab size + self.tensor_parallel_size = dist.get_world_size(group=process_group) + new_out_features = out_features + multiple = make_vocab_size_divisible_by * self.tensor_parallel_size + if out_features % multiple != 0: + new_out_features = out_features + multiple - (out_features % multiple) + + super().__init__( + in_features=in_features, + out_features=new_out_features, + bias=bias, + device=device, + process_group=process_group, + weight=weight, + bias_=bias_, + **kwargs, + new_num_embeddings=new_out_features, + old_num_embeddings=out_features, + ) + # get the length of valid embeddings + tp_rank = dist.get_rank(process_group) + partition_size = self.new_num_embeddings // dist.get_world_size(process_group) + if self.old_num_embeddings >= (tp_rank + 1) * partition_size: + self.num_valid_embeddings_local = partition_size + elif self.old_num_embeddings >= tp_rank * partition_size: + self.num_valid_embeddings_local = self.old_num_embeddings - tp_rank * partition_size + else: + self.num_valid_embeddings_local = 0 + + @staticmethod + def from_native_module( + module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], **kwargs + ) -> PaddingParallelModule: + r""" + Convert a native PyTorch linear layer to a parallelized linear layer. + """ + LazyInitContext.materialize(module) + # get the attributes + in_features = module.in_features + out_features = module.out_features + bias = module.bias is not None + device = module.weight.device + + lm_head_linear = VocabParallelLMHead1D( + in_features=in_features, + out_features=out_features, + bias=bias, + device=device, + process_group=process_group, + weight=module.weight, + bias_=module.bias, + **kwargs, + ) + + return lm_head_linear + + def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]: + # get forward output + if self.skip_bias_add: + output, bias = super().forward(input_) + else: + output = super().forward(input_) + + # delete the padding of output + if self.gather_output: + output = output[..., : self.old_num_embeddings] + else: + output = output[..., : self.num_valid_embeddings_local] + + # return + if self.skip_bias_add: + return output, bias + return output diff --git a/colossalai/shardformer/layer/loss.py b/colossalai/shardformer/layer/loss.py index c4cf3fb85..6d99efc19 100644 --- a/colossalai/shardformer/layer/loss.py +++ b/colossalai/shardformer/layer/loss.py @@ -15,7 +15,14 @@ class DistCrossEntropy(Function): """ @staticmethod - def forward(ctx, vocab_logits: torch.Tensor, target: torch.Tensor, ignore_index: int, process_group: ProcessGroup): + def forward( + ctx, + vocab_logits: torch.Tensor, + target: torch.Tensor, + ignore_index: int, + process_group: ProcessGroup, + vocab_size: int, + ): r""" Calculate the cross entropy loss before gather, the origin loss function is as follows: loss = -log(exp(x[class])/sum(exp(x[i])) @@ -41,15 +48,21 @@ class DistCrossEntropy(Function): vocab_logits = vocab_logits - logits_max.unsqueeze(dim=-1) # mask the target in the local device - partition_vocab_size = vocab_logits.size()[-1] rank = dist.get_rank(group=process_group) world_size = dist.get_world_size(group=process_group) - global_vocab_size = partition_vocab_size * world_size + if vocab_size == None: + partition_vocab_size = vocab_logits.size()[-1] + global_vocab_size = partition_vocab_size * world_size + else: + global_vocab_size = vocab_size + partition_vocab_size = global_vocab_size // world_size # [down, up) => false, other device and -100 => true delta = (global_vocab_size + world_size - 1) // world_size down_threshold = rank * delta up_threshold = down_threshold + delta + if up_threshold > global_vocab_size: + up_threshold = global_vocab_size mask = (target < down_threshold) | (target >= up_threshold) masked_target = target.clone() - down_threshold masked_target[mask] = 0 @@ -57,7 +70,8 @@ class DistCrossEntropy(Function): # reshape the logits and target # reshape the vocab_logits to [bath_size * seq_len, vocab_size] # reshape the labels to [bath_size * seq_len] - logits_2d = vocab_logits.view(-1, partition_vocab_size) + self_vocab_size = vocab_logits.size()[-1] + logits_2d = vocab_logits.view(-1, self_vocab_size) masked_target_1d = masked_target.view(-1) # extract the x[class] and set the x[other device] to zero @@ -104,10 +118,14 @@ class DistCrossEntropy(Function): grad_logits_2d[torch.arange(0, grad_logits_2d.shape[0]), masked_target_1d] -= update grad_logits.mul_(grad_output.unsqueeze(dim=-1)) - return grad_logits, None, None, None + return grad_logits, None, None, None, None def cross_entropy_1d( - vocab_logits: torch.Tensor, labels: torch.Tensor, ignore_index: int = -100, process_group: ProcessGroup = None + vocab_logits: torch.Tensor, + labels: torch.Tensor, + ignore_index: int = -100, + process_group: ProcessGroup = None, + vocab_size: int = None, ) -> torch.Tensor: - return DistCrossEntropy.apply(vocab_logits, labels, ignore_index, process_group) + return DistCrossEntropy.apply(vocab_logits, labels, ignore_index, process_group, vocab_size) diff --git a/colossalai/shardformer/layer/parallel_module.py b/colossalai/shardformer/layer/parallel_module.py index 6c0d83cc7..11ef73538 100644 --- a/colossalai/shardformer/layer/parallel_module.py +++ b/colossalai/shardformer/layer/parallel_module.py @@ -3,7 +3,7 @@ import itertools from abc import ABC, abstractmethod -from typing import List, Union +from typing import List, Optional, Union import torch import torch.nn as nn @@ -20,11 +20,15 @@ from colossalai.tensor.d_tensor import ( is_distributed_tensor, sharded_tensor_to_param, ) +from colossalai.tensor.padded_tensor import is_padded_tensor, to_padded_tensor, to_unpadded_tensor __all__ = ["ParallelModule"] class ParallelModule(nn.Module, ABC): + def __init__(self, **kwargs): + super().__init__() + @abstractmethod def from_native_module( module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]] = None @@ -54,7 +58,7 @@ class ParallelModule(nn.Module, ABC): """ for name, param in self._parameters.items(): if param is not None: - destination[prefix + name] = gather_distributed_param(param, keep_vars=keep_vars) + destination[prefix + name] = gather_distributed_param(param, keep_vars=keep_vars).data for name, buf in self._buffers.items(): if buf is not None and name not in self._non_persistent_buffers_set: @@ -171,3 +175,187 @@ class ParallelModule(nn.Module, ABC): input_name = input_name.split(".", 1)[0] # get the name of param/buffer/child if input_name not in self._modules and input_name not in local_state: unexpected_keys.append(key) + + +class PaddingParallelModule(ParallelModule): + def __init__( + self, + new_num_embeddings: int, + old_num_embeddings: int, + weight: Optional[nn.Parameter], + bias_: Optional[nn.Parameter] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.new_num_embeddings = new_num_embeddings + self.old_num_embeddings = old_num_embeddings + self.weight = weight + self.bias = bias_ + + if not (is_distributed_tensor(self.weight) or self.weight.shape[0] == self.new_num_embeddings): + self.resize_embedding_weight() + + if self.bias is not None and not ( + is_distributed_tensor(self.bias) or self.bias.shape[0] == self.new_num_embeddings + ): + self.resize_embedding_bias() + + @abstractmethod + def from_native_module( + module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]] = None + ) -> "PaddingParallelModule": + """ + Convert a native PyTorch module to a parallelized module. + + Args: + module (nn.Module): the module to be converted. + process_group (ProcessGroup or list[ProcessGroup]): the process group(s) to be used for communication. + If this is a list, the process group at the ith index of the list will correspond to the process group + in the ith axis of the device mesh. Defaults to None, which means the global process group. + """ + raise NotImplementedError + + def _save_to_state_dict(self, destination, prefix, keep_vars): + r"""Saves module state to `destination` dictionary, containing a state + of the module, but not its descendants. This is called on every + submodule in :meth:`~torch.nn.Module.state_dict`. + + In rare cases, subclasses can achieve class-specific behavior by + overriding this method with custom logic. + + Args: + destination (dict): a dict where state will be stored + prefix (str): the prefix for parameters and buffers used in this + module + """ + for name, param in self._parameters.items(): + if param is not None: + param = gather_distributed_param(param, keep_vars=keep_vars) + if is_padded_tensor(param): + param = to_unpadded_tensor(param) + destination[prefix + name] = param.data + + for name, buf in self._buffers.items(): + if buf is not None and name not in self._non_persistent_buffers_set: + destination[prefix + name] = buf if keep_vars else buf.detach() + extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX + if getattr(self.__class__, "get_extra_state", Module.get_extra_state) is not Module.get_extra_state: + destination[extra_state_key] = self.get_extra_state() + + def _load_from_state_dict( + self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ): + r"""Copies parameters and buffers from :attr:`state_dict` into only + this module, but not its descendants. This is called on every submodule + in :meth:`~torch.nn.Module.load_state_dict`. Metadata saved for this + module in input :attr:`state_dict` is provided as :attr:`local_metadata`. + For state dicts without metadata, :attr:`local_metadata` is empty. + Subclasses can achieve class-specific backward compatible loading using + the version number at `local_metadata.get("version", None)`. + + .. note:: + :attr:`state_dict` is not the same object as the input + :attr:`state_dict` to :meth:`~torch.nn.Module.load_state_dict`. So + it can be modified. + + Args: + state_dict (dict): a dict containing parameters and + persistent buffers. + prefix (str): the prefix for parameters and buffers used in this + module + local_metadata (dict): a dict containing the metadata for this module. + See + strict (bool): whether to strictly enforce that the keys in + :attr:`state_dict` with :attr:`prefix` match the names of + parameters and buffers in this module + missing_keys (list of str): if ``strict=True``, add missing keys to + this list + unexpected_keys (list of str): if ``strict=True``, add unexpected + keys to this list + error_msgs (list of str): error messages should be added to this + list, and will be reported together in + :meth:`~torch.nn.Module.load_state_dict` + """ + for hook in self._load_state_dict_pre_hooks.values(): + hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) + + persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set} + local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items()) + local_state = {k: v for k, v in local_name_params if v is not None} + + for name, param in local_state.items(): + key = prefix + name + + if key in state_dict: + input_param = state_dict[key] + if not torch.overrides.is_tensor_like(input_param): + error_msgs.append( + 'While copying the parameter named "{}", ' + "expected torch.Tensor or Tensor-like object from checkpoint but " + "received {}".format(key, type(input_param)) + ) + continue + + if is_padded_tensor(param): + input_param = to_padded_tensor(input_param, param._current_length, param._padding_dim) + + if is_distributed_tensor(param): + # shard the input param + device_mesh = get_device_mesh(param) + sharding_spec = get_sharding_spec(param) + sharded_tensor = distribute_tensor(input_param, device_mesh, sharding_spec) + input_param = sharded_tensor_to_param(sharded_tensor) + elif is_customized_distributed_tensor(param): + input_param = distribute_tensor_with_customization(input_param, param.shard_fn, param.gather_fn) + + # This is used to avoid copying uninitialized parameters into + # non-lazy modules, since they dont have the hook to do the checks + # in such case, it will error when accessing the .shape attribute. + is_param_lazy = torch.nn.parameter.is_lazy(param) + # Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+ + if not is_param_lazy and len(param.shape) == 0 and len(input_param.shape) == 1: + input_param = input_param[0] + + if not is_param_lazy and input_param.shape != param.shape: + # local shape should match the one in checkpoint + error_msgs.append( + "size mismatch for {}: copying a param with shape {} from checkpoint, " + "the shape in current model is {}.".format(key, input_param.shape, param.shape) + ) + continue + + try: + with torch.no_grad(): + param.copy_(input_param) + except Exception as ex: + error_msgs.append( + 'While copying the parameter named "{}", ' + "whose dimensions in the model are {} and " + "whose dimensions in the checkpoint are {}, " + "an exception occurred : {}.".format(key, param.size(), input_param.size(), ex.args) + ) + elif strict: + missing_keys.append(key) + + extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX + if getattr(self.__class__, "set_extra_state", Module.set_extra_state) is not Module.set_extra_state: + if extra_state_key in state_dict: + self.set_extra_state(state_dict[extra_state_key]) + elif strict: + missing_keys.append(extra_state_key) + elif strict and (extra_state_key in state_dict): + unexpected_keys.append(extra_state_key) + + if strict: + for key in state_dict.keys(): + if key.startswith(prefix) and key != extra_state_key: + input_name = key[len(prefix) :] + input_name = input_name.split(".", 1)[0] # get the name of param/buffer/child + if input_name not in self._modules and input_name not in local_state: + unexpected_keys.append(key) + + def resize_embedding_weight(self): + self.weight = to_padded_tensor(self.weight, self.new_num_embeddings, 0) + + def resize_embedding_bias(self): + self.bias = to_padded_tensor(self.bias, self.new_num_embeddings, 0) diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py index 1306c8aa6..26088569a 100644 --- a/colossalai/shardformer/modeling/gpt2.py +++ b/colossalai/shardformer/modeling/gpt2.py @@ -26,7 +26,6 @@ from colossalai.shardformer.layer._operation import gather_forward_split_backwar from colossalai.shardformer.shard import ShardConfig from ..layer import cross_entropy_1d -from ..layer._operation import gather_forward_split_backward logger = logging.get_logger(__name__) @@ -397,13 +396,11 @@ class GPT2PipelineForwards: shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group, + vocab_size=self.lm_head.out_features, ) else: loss = loss_fct(shift_logits, shift_labels) - if not shard_config.parallel_output: - lm_logits = gather_forward_split_backward(lm_logits, -1, shard_config.tensor_parallel_process_group) - if not return_dict: output = (lm_logits,) + outputs[1:] return ((loss,) + output) if loss is not None else output @@ -1301,12 +1298,12 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig): shift_logits = shift_logits.view(-1, shift_logits.size(-1)) shift_labels = shift_labels.view(-1) loss = cross_entropy_1d( - shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group + shift_logits, + shift_labels, + process_group=shard_config.tensor_parallel_process_group, + vocab_size=self.lm_head.out_features, ) - if not shard_config.parallel_output: - lm_logits = gather_forward_split_backward(lm_logits, -1, shard_config.tensor_parallel_process_group) - if not return_dict: output = (lm_logits,) + transformer_outputs[1:] return ((loss,) + output) if loss is not None else output diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index 0f1b4ad0a..c3b5426c2 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -316,7 +316,10 @@ class LlamaPipelineForwards: new_vocab_size = logits.shape[-1] shift_logits = shift_logits.view(-1, new_vocab_size) loss = cross_entropy_1d( - shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group + shift_logits, + shift_labels, + process_group=shard_config.tensor_parallel_process_group, + vocab_size=self.lm_head.out_features, ) else: shift_logits = shift_logits.view(-1, self.config.vocab_size) @@ -735,11 +738,13 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig): shift_labels = shift_labels.view(-1) # Enable model parallelism shift_labels = shift_labels.to(shift_logits.device) - new_vocab_size = logits.shape[-1] shift_logits = shift_logits.view(-1, new_vocab_size) loss = cross_entropy_1d( - shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group + shift_logits, + shift_labels, + process_group=shard_config.tensor_parallel_process_group, + vocab_size=self.lm_head.out_features, ) if not return_dict: diff --git a/colossalai/shardformer/policies/base_policy.py b/colossalai/shardformer/policies/base_policy.py index d67ab0a3c..e976672bb 100644 --- a/colossalai/shardformer/policies/base_policy.py +++ b/colossalai/shardformer/policies/base_policy.py @@ -195,3 +195,12 @@ class Policy(ABC): List[Dict[int, Tensor]]: List of parameters that should be shared across stages. E.g. [{0: module.model.embed_tokens.weight, 3: module.lm_head.weight}] """ return [] + + def tie_weight_check(self): + input_embedding = self.model.get_input_embeddings() + output_embedding = self.model.get_output_embeddings() + return ( + input_embedding is not None + and output_embedding is not None + and id(input_embedding.weight) == id(output_embedding.weight) + ) diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py index 0a61d8cff..d43fc893a 100644 --- a/colossalai/shardformer/policies/bert.py +++ b/colossalai/shardformer/policies/bert.py @@ -37,17 +37,7 @@ class BertPolicy(Policy): pass def preprocess(self): - # reshape the embedding layer - r""" - Reshape the Embedding layer to make the embedding dimension divisible by world_size - """ - # TODO: - if self.shard_config.enable_tensor_parallelism: - vocab_size = self.model.config.vocab_size - world_size = self.shard_config.tensor_parallel_size - if vocab_size % world_size != 0: - new_vocab_size = vocab_size + world_size - vocab_size % world_size - self.model.resize_token_embeddings(new_vocab_size) + self.tie_weight = self.tie_weight_check() return self.model def module_policy(self): @@ -62,6 +52,13 @@ class BertPolicy(Policy): policy = {} + embedding_cls = None + if self.shard_config.enable_tensor_parallelism: + embedding_cls = col_nn.VocabParallelEmbedding1D + else: + if self.tie_weight: + embedding_cls = col_nn.PaddingEmbedding + if self.shard_config.enable_fused_normalization: norm_cls = col_nn.FusedLayerNorm else: @@ -150,10 +147,6 @@ class BertPolicy(Policy): policy[BertEmbeddings] = ModulePolicyDescription( sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="word_embeddings", - target_module=col_nn.VocabParallelEmbedding1D, - ), SubModuleReplacementDescription( suffix="dropout", target_module=col_nn.DropoutForReplicatedInput, @@ -168,6 +161,18 @@ class BertPolicy(Policy): target_key=BertModel, ) + if embedding_cls is not None: + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="word_embeddings", + target_module=embedding_cls, + ) + ], + policy=policy, + target_key=BertEmbeddings, + ) + # optimization configuration # Handle bert layer self.append_or_create_submodule_replacement( @@ -237,8 +242,21 @@ class BertPolicy(Policy): self.append_or_create_submodule_replacement( description=SubModuleReplacementDescription( suffix="decoder", - target_module=col_nn.Linear1D_Col, - kwargs={"gather_output": True}, + target_module=col_nn.VocabParallelLMHead1D, + kwargs={ + "gather_output": True, + "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by, + }, + ), + policy=base_policy, + target_key=BertLMPredictionHead, + ) + else: + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription( + suffix="decoder", + target_module=col_nn.PaddingLMHead, + kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}, ), policy=base_policy, target_key=BertLMPredictionHead, diff --git a/colossalai/shardformer/policies/blip2.py b/colossalai/shardformer/policies/blip2.py index 9be2a1e78..b845e9336 100644 --- a/colossalai/shardformer/policies/blip2.py +++ b/colossalai/shardformer/policies/blip2.py @@ -17,16 +17,7 @@ class BlipPolicy(Policy): pass def preprocess(self): - # reshape the embedding layer - r""" - Reshape the Embedding layer to make the embedding dimension divisible by world_size - """ - # TODO: - vocab_size = self.model.config.qformer_config.vocab_size - world_size = self.shard_config.tensor_parallel_size - if vocab_size % world_size != 0: - new_vocab_size = vocab_size + world_size - vocab_size % world_size - self.model.resize_token_embeddings(new_vocab_size) + self.tie_weight = self.tie_weight_check() return self.model def module_policy(self): @@ -43,6 +34,13 @@ class BlipPolicy(Policy): policy = {} + embedding_cls = None + if self.shard_config.enable_tensor_parallelism: + embedding_cls = col_nn.VocabParallelEmbedding1D + else: + if self.tie_weight: + embedding_cls = col_nn.PaddingEmbedding + if self.shard_config.enable_fused_normalization: norm_cls = col_nn.FusedLayerNorm else: @@ -202,22 +200,48 @@ class BlipPolicy(Policy): ], ) - policy[OPTForCausalLM] = ModulePolicyDescription( - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="model.decoder.embed_tokens", - target_module=col_nn.VocabParallelEmbedding1D, - ), - SubModuleReplacementDescription( - suffix="lm_head", - target_module=col_nn.Linear1D_Col, - kwargs={"gather_output": True}, - ), - ] - ) - policy[Blip2Attention] = ModulePolicyDescription(method_replacement={"forward": forward_fn()}) + if embedding_cls is not None: + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="model.decoder.embed_tokens", + target_module=embedding_cls, + kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}, + ), + ], + policy=policy, + target_key=OPTForCausalLM, + ) + + if self.shard_config.enable_tensor_parallelism: + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="lm_head", + target_module=col_nn.VocabParallelLMHead1D, + kwargs={ + "gather_output": True, + "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by, + }, + ), + ], + policy=policy, + target_key=OPTForCausalLM, + ) + else: + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="lm_head", + target_module=col_nn.PaddingLMHead, + kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}, + ), + ], + policy=policy, + target_key=OPTForCausalLM, + ) # optimization configuration # Handle Blip2EncoderLayer layer self.append_or_create_submodule_replacement( diff --git a/colossalai/shardformer/policies/bloom.py b/colossalai/shardformer/policies/bloom.py index 2becadc3f..953592abc 100644 --- a/colossalai/shardformer/policies/bloom.py +++ b/colossalai/shardformer/policies/bloom.py @@ -35,16 +35,7 @@ class BloomPolicy(Policy): pass def preprocess(self): - # reshape the embedding layer - r""" - Reshape the Embedding layer to make the embedding dimension divisible by world_size - """ - if self.shard_config.enable_tensor_parallelism: - vocab_size = self.model.config.vocab_size - world_size = self.shard_config.tensor_parallel_size - if vocab_size % world_size != 0: - new_vocab_size = vocab_size + world_size - vocab_size % world_size - self.model.resize_token_embeddings(new_vocab_size) + self.tie_weight = self.tie_weight_check() return self.model def module_policy(self): @@ -52,6 +43,13 @@ class BloomPolicy(Policy): policy = {} + embedding_cls = None + if self.shard_config.enable_tensor_parallelism: + embedding_cls = col_nn.VocabParallelEmbedding1D + else: + if self.tie_weight: + embedding_cls = col_nn.PaddingEmbedding + if self.shard_config.enable_fused_normalization: norm_cls = col_nn.FusedLayerNorm else: @@ -112,12 +110,19 @@ class BloomPolicy(Policy): method_replacement={ "build_alibi_tensor": build_bloom_alibi_tensor_fn(self.shard_config.tensor_parallel_process_group) }, - sub_module_replacement=[ + ) + + if embedding_cls is not None: + self.append_or_create_submodule_replacement( + description=[ SubModuleReplacementDescription( suffix="word_embeddings", - target_module=col_nn.VocabParallelEmbedding1D, - ) + target_module=embedding_cls, + kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}, + ), ], + policy=policy, + target_key=BloomModel, ) # optimization configuration @@ -282,7 +287,21 @@ class BloomForCausalLMPolicy(BloomPolicy): if self.shard_config.enable_tensor_parallelism: self.append_or_create_submodule_replacement( description=SubModuleReplacementDescription( - suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True) + suffix="lm_head", + target_module=col_nn.VocabParallelLMHead1D, + kwargs=dict( + gather_output=True, make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by + ), + ), + policy=policy, + target_key=BloomForCausalLM, + ) + else: + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription( + suffix="lm_head", + target_module=col_nn.PaddingLMHead, + kwargs=dict(make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by), ), policy=policy, target_key=BloomForCausalLM, diff --git a/colossalai/shardformer/policies/chatglm2.py b/colossalai/shardformer/policies/chatglm2.py index dabc14bff..f205835e7 100644 --- a/colossalai/shardformer/policies/chatglm2.py +++ b/colossalai/shardformer/policies/chatglm2.py @@ -25,20 +25,12 @@ class ChatGLMPolicy(Policy): pass def preprocess(self): - # Resize embedding - if self.shard_config.enable_tensor_parallelism: - vocab_size = self.model.config.padded_vocab_size - world_size = self.shard_config.tensor_parallel_size - - if vocab_size % world_size != 0: - new_vocab_size = vocab_size + world_size - vocab_size % world_size - self.model.resize_token_embeddings(new_vocab_size) - if self.pipeline_stage_manager is not None: # the batch_size_dim is bounded to Model bsz_dim = 1 setattr(self.model, "batch_size_dim", bsz_dim) + self.tie_weight = self.tie_weight_check() return self.model def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: @@ -46,6 +38,13 @@ class ChatGLMPolicy(Policy): policy = {} + embedding_cls = None + if self.shard_config.enable_tensor_parallelism: + embedding_cls = col_nn.VocabParallelEmbedding1D + else: + if self.tie_weight: + embedding_cls = col_nn.PaddingEmbedding + if self.shard_config.enable_fused_normalization: if self.model.config.rmsnorm: norm_cls = col_nn.FusedRMSNorm @@ -68,16 +67,6 @@ class ChatGLMPolicy(Policy): sp_partial_derived = sp_mode == "split_gather" if self.shard_config.enable_tensor_parallelism: - policy[ChatGLMModel] = ModulePolicyDescription( - attribute_replacement={}, - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="embedding.word_embeddings", - target_module=col_nn.VocabParallelEmbedding1D, - ) - ], - ) - policy[GLMBlock] = ModulePolicyDescription( attribute_replacement={ "self_attention.num_attention_heads_per_partition": self.model.config.num_attention_heads @@ -114,6 +103,19 @@ class ChatGLMPolicy(Policy): ), ], ) + + if embedding_cls is not None: + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="embedding.word_embeddings", + target_module=embedding_cls, + kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}, + ), + ], + policy=policy, + target_key=ChatGLMModel, + ) # optimization configuration self.append_or_create_submodule_replacement( description=[ diff --git a/colossalai/shardformer/policies/falcon.py b/colossalai/shardformer/policies/falcon.py index fe61c406f..a2f110a41 100644 --- a/colossalai/shardformer/policies/falcon.py +++ b/colossalai/shardformer/policies/falcon.py @@ -32,16 +32,7 @@ class FalconPolicy(Policy): pass def preprocess(self): - # reshape the embedding layer - r""" - Reshape the Embedding layer to make the embedding dimension divisible by world_size - """ - if self.shard_config.enable_tensor_parallelism: - vocab_size = self.model.config.vocab_size - world_size = self.shard_config.tensor_parallel_size - if vocab_size % world_size != 0: - new_vocab_size = vocab_size + world_size - vocab_size % world_size - self.model.resize_token_embeddings(new_vocab_size) + self.tie_weight = self.tie_weight_check() return self.model def module_policy(self): @@ -58,6 +49,14 @@ class FalconPolicy(Policy): warnings.warn("Falcon doesn't support sequence parallelism now, will ignore the sequence parallelism flag.") policy = {} + + embedding_cls = None + if self.shard_config.enable_tensor_parallelism: + embedding_cls = col_nn.VocabParallelEmbedding1D + else: + if self.tie_weight: + embedding_cls = col_nn.PaddingEmbedding + if self.shard_config.enable_tensor_parallelism: attn_attribute_replacement = { "self_attention.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, @@ -98,12 +97,19 @@ class FalconPolicy(Policy): method_replacement={ "build_alibi_tensor": build_falcon_alibi_tensor_fn(self.shard_config.tensor_parallel_process_group) }, - sub_module_replacement=[ + ) + + if embedding_cls is not None: + self.append_or_create_submodule_replacement( + description=[ SubModuleReplacementDescription( suffix="word_embeddings", - target_module=col_nn.VocabParallelEmbedding1D, - ) + target_module=embedding_cls, + kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}, + ), ], + policy=policy, + target_key=FalconModel, ) # optimization configuration @@ -232,11 +238,26 @@ class FalconForCausalLMPolicy(FalconPolicy): if self.shard_config.enable_tensor_parallelism: self.append_or_create_submodule_replacement( description=SubModuleReplacementDescription( - suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True) + suffix="lm_head", + target_module=col_nn.VocabParallelLMHead1D, + kwargs=dict( + gather_output=True, make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by + ), ), policy=policy, target_key=FalconForCausalLM, ) + else: + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription( + suffix="lm_head", + target_module=col_nn.PaddingLMHead, + kwargs=dict(make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by), + ), + policy=policy, + target_key=FalconForCausalLM, + ) + if self.pipeline_stage_manager: self.set_pipeline_forward( model_cls=FalconForCausalLM, diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py index 380a432dc..98db7b948 100644 --- a/colossalai/shardformer/policies/gpt2.py +++ b/colossalai/shardformer/policies/gpt2.py @@ -34,12 +34,7 @@ class GPT2Policy(Policy): r""" Reshape the Embedding layer to make the embedding dimension divisible by world_size """ - if self.shard_config.enable_tensor_parallelism: - vocab_size = self.model.config.vocab_size - world_size = self.shard_config.tensor_parallel_size - if vocab_size % world_size != 0: - new_vocab_size = vocab_size + world_size - vocab_size % world_size - self.model.resize_token_embeddings(new_vocab_size) + self.tie_weight = self.tie_weight_check() return self.model def module_policy(self): @@ -47,6 +42,13 @@ class GPT2Policy(Policy): policy = {} + embedding_cls = None + if self.shard_config.enable_tensor_parallelism: + embedding_cls = col_nn.VocabParallelEmbedding1D + else: + if self.tie_weight: + embedding_cls = col_nn.PaddingEmbedding + if self.shard_config.enable_fused_normalization: norm_cls = col_nn.FusedLayerNorm else: @@ -73,10 +75,6 @@ class GPT2Policy(Policy): if self.shard_config.enable_tensor_parallelism: policy[GPT2Model] = ModulePolicyDescription( sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="wte", - target_module=col_nn.VocabParallelEmbedding1D, - ), SubModuleReplacementDescription( suffix="drop", target_module=col_nn.DropoutForParallelInput, @@ -137,6 +135,17 @@ class GPT2Policy(Policy): ), ], ) + if embedding_cls is not None: + # padding vocabulary size when using pp to make it divisible by shard_config.make_vocab_size_divisible_by + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription( + suffix="wte", + target_module=embedding_cls, + kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}, + ), + policy=policy, + target_key=GPT2Model, + ) # optimization configuration self.append_or_create_submodule_replacement( @@ -298,8 +307,11 @@ class GPT2LMHeadModelPolicy(GPT2Policy): sub_module_replacement=[ SubModuleReplacementDescription( suffix="lm_head", - target_module=col_nn.Linear1D_Col, - kwargs={"gather_output": not self.shard_config.parallel_output}, + target_module=col_nn.VocabParallelLMHead1D, + kwargs={ + "gather_output": False, + "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by, + }, ) ], ) @@ -308,7 +320,19 @@ class GPT2LMHeadModelPolicy(GPT2Policy): addon_module[GPT2LMHeadModel].method_replacement = { "forward": get_lm_forward_with_dist_cross_entropy(self.shard_config) } - module_policy.update(addon_module) + else: + addon_module = { + GPT2LMHeadModel: ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="lm_head", + target_module=col_nn.PaddingLMHead, + kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}, + ) + ] + ) + } + module_policy.update(addon_module) if self.pipeline_stage_manager is not None: self.set_pipeline_forward( @@ -353,13 +377,28 @@ class GPT2DoubleHeadsModelPolicy(GPT2Policy): sub_module_replacement=[ SubModuleReplacementDescription( suffix="lm_head", - target_module=col_nn.Linear1D_Col, - kwargs={"gather_output": True}, + target_module=col_nn.VocabParallelLMHead1D, + kwargs={ + "gather_output": True, + "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by, + }, ) ] ) } - module_policy.update(addon_module) + else: + addon_module = { + GPT2DoubleHeadsModel: ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="lm_head", + target_module=col_nn.PaddingLMHead, + kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}, + ) + ] + ) + } + module_policy.update(addon_module) if self.pipeline_stage_manager is not None: self.set_pipeline_forward( diff --git a/colossalai/shardformer/policies/gptj.py b/colossalai/shardformer/policies/gptj.py index eab4c214a..4b69137a6 100644 --- a/colossalai/shardformer/policies/gptj.py +++ b/colossalai/shardformer/policies/gptj.py @@ -29,22 +29,21 @@ class GPTJPolicy(Policy): pass def preprocess(self): - # reshape the embedding layer - r""" - Reshape the Embedding layer to make the embedding dimension divisible by world_size - """ - if self.shard_config.enable_tensor_parallelism: - vocab_size = self.model.config.vocab_size - world_size = self.shard_config.tensor_parallel_size - if vocab_size % world_size != 0: - new_vocab_size = vocab_size + world_size - vocab_size % world_size - self.model.resize_token_embeddings(new_vocab_size) + self.tie_weight = self.tie_weight_check() return self.model def module_policy(self): from transformers.models.gptj.modeling_gptj import GPTJAttention, GPTJBlock, GPTJModel policy = {} + + embedding_cls = None + if self.shard_config.enable_tensor_parallelism: + embedding_cls = col_nn.VocabParallelEmbedding1D + else: + if self.tie_weight: + embedding_cls = col_nn.PaddingEmbedding + if self.shard_config.enable_sequence_parallelism: self.shard_config.enable_sequence_parallelism = False warnings.warn("GPTJ doesn't support sequence parallelism now, will ignore the sequence parallelism flag.") @@ -54,10 +53,6 @@ class GPTJPolicy(Policy): if self.shard_config.enable_tensor_parallelism: policy[GPTJModel] = ModulePolicyDescription( sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="wte", - target_module=col_nn.VocabParallelEmbedding1D, - ), SubModuleReplacementDescription( suffix="drop", target_module=col_nn.DropoutForParallelInput, @@ -126,6 +121,17 @@ class GPTJPolicy(Policy): ], ) + if embedding_cls is not None: + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription( + suffix="wte", + target_module=embedding_cls, + kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}, + ), + policy=policy, + target_key=GPTJModel, + ) + # optimization configuration if self.shard_config.enable_fused_normalization: self.append_or_create_submodule_replacement( @@ -255,13 +261,28 @@ class GPTJForCausalLMPolicy(GPTJPolicy): sub_module_replacement=[ SubModuleReplacementDescription( suffix="lm_head", - target_module=col_nn.Linear1D_Col, - kwargs={"gather_output": True}, + target_module=col_nn.VocabParallelLMHead1D, + kwargs={ + "gather_output": True, + "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by, + }, ) ] ) } - policy.update(addon_module) + else: + addon_module = { + GPTJForCausalLM: ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="lm_head", + target_module=col_nn.PaddingLMHead, + kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}, + ) + ] + ) + } + policy.update(addon_module) if self.pipeline_stage_manager is not None: self.set_pipeline_forward( diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index bb4551b2c..ff686a179 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -6,7 +6,16 @@ import torch.nn as nn from torch import Tensor from torch.nn import Module -from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col, Linear1D_Row, RMSNorm, VocabParallelEmbedding1D +from colossalai.shardformer.layer import ( + FusedRMSNorm, + Linear1D_Col, + Linear1D_Row, + PaddingEmbedding, + PaddingLMHead, + RMSNorm, + VocabParallelEmbedding1D, + VocabParallelLMHead1D, +) from ..modeling.llama import ( LlamaPipelineForwards, @@ -26,15 +35,7 @@ class LlamaPolicy(Policy): pass def preprocess(self): - if self.shard_config.enable_tensor_parallelism: - # Resize embedding - vocab_size = self.model.config.vocab_size - world_size = self.shard_config.tensor_parallel_size - - if vocab_size % world_size != 0: - new_vocab_size = vocab_size + world_size - vocab_size % world_size - self.model.resize_token_embeddings(new_vocab_size) - + self.tie_weight = self.tie_weight_check() return self.model def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: @@ -42,6 +43,13 @@ class LlamaPolicy(Policy): policy = {} + embedding_cls = None + if self.shard_config.enable_tensor_parallelism: + embedding_cls = VocabParallelEmbedding1D + else: + if self.tie_weight: + embedding_cls = PaddingEmbedding + if self.shard_config.enable_fused_normalization: norm_cls = FusedRMSNorm else: @@ -167,10 +175,12 @@ class LlamaPolicy(Policy): ], ) + if embedding_cls is not None: self.append_or_create_submodule_replacement( description=SubModuleReplacementDescription( suffix="embed_tokens", - target_module=VocabParallelEmbedding1D, + target_module=embedding_cls, + kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}, ), policy=policy, target_key=LlamaModel, @@ -327,8 +337,11 @@ class LlamaForCausalLMPolicy(LlamaPolicy): sub_module_replacement=[ SubModuleReplacementDescription( suffix="lm_head", - target_module=Linear1D_Col, - kwargs={"gather_output": not self.shard_config.parallel_output}, + target_module=VocabParallelLMHead1D, + kwargs={ + "gather_output": not self.shard_config.parallel_output, + "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by, + }, ) ], ) @@ -337,7 +350,19 @@ class LlamaForCausalLMPolicy(LlamaPolicy): new_item[LlamaForCausalLM].method_replacement = { "forward": get_lm_forward_with_dist_cross_entropy(self.shard_config) } - policy.update(new_item) + else: + new_item = { + LlamaForCausalLM: ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="lm_head", + target_module=PaddingLMHead, + kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}, + ) + ], + ) + } + policy.update(new_item) if self.pipeline_stage_manager: # set None as default diff --git a/colossalai/shardformer/policies/mistral.py b/colossalai/shardformer/policies/mistral.py index c0b8b3375..b225fd2a9 100644 --- a/colossalai/shardformer/policies/mistral.py +++ b/colossalai/shardformer/policies/mistral.py @@ -3,7 +3,15 @@ from typing import Dict, Union import torch.nn as nn -from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D +from colossalai.shardformer.layer import ( + FusedRMSNorm, + Linear1D_Col, + Linear1D_Row, + PaddingEmbedding, + PaddingLMHead, + VocabParallelEmbedding1D, + VocabParallelLMHead1D, +) from ..modeling.mistral import get_mistral_flash_attention_forward from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription @@ -16,15 +24,7 @@ class MistralPolicy(Policy): pass def preprocess(self): - if self.shard_config.enable_tensor_parallelism: - # Resize embedding - vocab_size = self.model.config.vocab_size - world_size = self.shard_config.tensor_parallel_size - - if vocab_size % world_size != 0: - new_vocab_size = vocab_size + world_size - vocab_size % world_size - self.model.resize_token_embeddings(new_vocab_size) - + self.tie_weight = self.tie_weight_check() return self.model def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: @@ -32,6 +32,13 @@ class MistralPolicy(Policy): policy = {} + embedding_cls = None + if self.shard_config.enable_tensor_parallelism: + embedding_cls = VocabParallelEmbedding1D + else: + if self.tie_weight: + embedding_cls = PaddingEmbedding + if self.shard_config.enable_sequence_parallelism: self.shard_config.enable_sequence_parallelism = False warnings.warn( @@ -80,10 +87,12 @@ class MistralPolicy(Policy): ], ) + if embedding_cls is not None: self.append_or_create_submodule_replacement( description=SubModuleReplacementDescription( suffix="embed_tokens", - target_module=VocabParallelEmbedding1D, + target_module=embedding_cls, + kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}, ), policy=policy, target_key=MistralModel, @@ -146,6 +155,8 @@ class MistralForCausalLMPolicy(MistralPolicy): from transformers import MistralForCausalLM policy = super().module_policy() + if self.pipeline_stage_manager: + warnings.warn("Mistral doesn't support pipeline parallelism now.") if self.shard_config.enable_tensor_parallelism: # add a new item for casual lm @@ -153,16 +164,30 @@ class MistralForCausalLMPolicy(MistralPolicy): MistralForCausalLM: ModulePolicyDescription( sub_module_replacement=[ SubModuleReplacementDescription( - suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True) + suffix="lm_head", + target_module=VocabParallelLMHead1D, + kwargs=dict( + gather_output=True, + make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by, + ), + ) + ] + ) + } + else: + new_item = { + MistralForCausalLM: ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="lm_head", + target_module=PaddingLMHead, + kwargs=dict(make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by), ) ] ) } - if self.pipeline_stage_manager: - warnings.warn("Mistral doesn't support pipeline parallelism now.") - - policy.update(new_item) + policy.update(new_item) return policy diff --git a/colossalai/shardformer/policies/opt.py b/colossalai/shardformer/policies/opt.py index 98e584be8..ac78ff6a7 100644 --- a/colossalai/shardformer/policies/opt.py +++ b/colossalai/shardformer/policies/opt.py @@ -5,7 +5,16 @@ from typing import Callable, Dict, List import torch.nn as nn from torch import Tensor, nn -from colossalai.shardformer.layer import FusedLayerNorm, LayerNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D +from colossalai.shardformer.layer import ( + FusedLayerNorm, + LayerNorm, + Linear1D_Col, + Linear1D_Row, + PaddingEmbedding, + PaddingLMHead, + VocabParallelEmbedding1D, + VocabParallelLMHead1D, +) from .._utils import getattr_ from ..modeling.jit import get_jit_fused_dropout_add_func @@ -41,16 +50,7 @@ class OPTPolicy(Policy): pass def preprocess(self): - # reshape the embedding layer - r""" - Reshape the Embedding layer to make the embedding dimension divisible by world_size - """ - if self.shard_config.enable_tensor_parallelism: - vocab_size = self.model.config.vocab_size - world_size = self.shard_config.tensor_parallel_size - if vocab_size % world_size != 0: - new_vocab_size = vocab_size + world_size - vocab_size % world_size - self.model.resize_token_embeddings(new_vocab_size) + self.tie_weight = self.tie_weight_check() return self.model def module_policy(self): @@ -58,6 +58,13 @@ class OPTPolicy(Policy): policy = {} + embedding_cls = None + if self.shard_config.enable_tensor_parallelism: + embedding_cls = VocabParallelEmbedding1D + else: + if self.tie_weight: + embedding_cls = PaddingEmbedding + if self.shard_config.enable_fused_normalization: norm_cls = FusedLayerNorm else: @@ -68,14 +75,6 @@ class OPTPolicy(Policy): warnings.warn("OPT doesn't support sequence parallelism now, will ignore the sequence parallelism flag.") if self.shard_config.enable_tensor_parallelism: - policy[OPTDecoder] = ModulePolicyDescription( - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="embed_tokens", - target_module=VocabParallelEmbedding1D, - ) - ] - ) policy[OPTDecoderLayer] = ModulePolicyDescription( sub_module_replacement=[ SubModuleReplacementDescription( @@ -114,6 +113,17 @@ class OPTPolicy(Policy): ], ) + if embedding_cls is not None: + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription( + suffix="embed_tokens", + target_module=embedding_cls, + kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}, + ), + policy=policy, + target_key=OPTDecoder, + ) + # optimization configuration self.append_or_create_submodule_replacement( description=SubModuleReplacementDescription( @@ -253,8 +263,20 @@ class OPTForCausalLMPolicy(OPTPolicy): self.append_or_create_submodule_replacement( description=SubModuleReplacementDescription( suffix="lm_head", - target_module=Linear1D_Col, - kwargs=dict(gather_output=True), + target_module=VocabParallelLMHead1D, + kwargs=dict( + gather_output=True, make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by + ), + ), + policy=policy, + target_key=OPTForCausalLM, + ) + else: + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription( + suffix="lm_head", + target_module=PaddingLMHead, + kwargs=dict(make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by), ), policy=policy, target_key=OPTForCausalLM, diff --git a/colossalai/shardformer/policies/t5.py b/colossalai/shardformer/policies/t5.py index 0c8ec15fa..3c7e92b47 100644 --- a/colossalai/shardformer/policies/t5.py +++ b/colossalai/shardformer/policies/t5.py @@ -13,8 +13,11 @@ from colossalai.shardformer.layer import ( FusedRMSNorm, Linear1D_Col, Linear1D_Row, + PaddingEmbedding, + PaddingLMHead, RMSNorm, VocabParallelEmbedding1D, + VocabParallelLMHead1D, ) from colossalai.shardformer.policies.base_policy import ModulePolicyDescription @@ -36,16 +39,7 @@ class T5BasePolicy(Policy): pass def preprocess(self): - # reshape the embedding layer - r""" - Reshape the Embedding layer to make the embedding dimension divisible by world_size - """ - if self.shard_config.enable_tensor_parallelism: - vocab_size = self.model.config.vocab_size - world_size = self.shard_config.tensor_parallel_size - if vocab_size % world_size != 0: - new_vocab_size = vocab_size + world_size - vocab_size % world_size - self.model.resize_token_embeddings(new_vocab_size) + self.tie_weight = self.tie_weight_check() return self.model def module_policy(self): @@ -61,6 +55,13 @@ class T5BasePolicy(Policy): policy = {} + embedding_cls = None + if self.shard_config.enable_tensor_parallelism: + embedding_cls = VocabParallelEmbedding1D + else: + if self.tie_weight: + embedding_cls = PaddingEmbedding + if self.shard_config.enable_fused_normalization: norm_cls = FusedRMSNorm else: @@ -77,10 +78,6 @@ class T5BasePolicy(Policy): suffix="dropout", target_module=DropoutForParallelInput, ), - SubModuleReplacementDescription( - suffix="embed_tokens", - target_module=VocabParallelEmbedding1D, - ), ] ) policy[T5LayerSelfAttention] = ModulePolicyDescription( @@ -176,6 +173,17 @@ class T5BasePolicy(Policy): ] ) + if embedding_cls is not None: + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription( + suffix="embed_tokens", + target_module=embedding_cls, + kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}, + ), + policy=policy, + target_key=T5Stack, + ) + # optimization configuration self.append_or_create_submodule_replacement( description=SubModuleReplacementDescription( @@ -370,11 +378,19 @@ class T5ModelPolicy(T5BasePolicy): policy = super().module_policy() + embedding_cls = None if self.shard_config.enable_tensor_parallelism: + embedding_cls = VocabParallelEmbedding1D + else: + if self.tie_weight: + embedding_cls = PaddingEmbedding + + if embedding_cls is not None: self.append_or_create_submodule_replacement( description=SubModuleReplacementDescription( suffix="shared", - target_module=VocabParallelEmbedding1D, + target_module=embedding_cls, + kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}, ), policy=policy, target_key=T5Model, @@ -406,17 +422,44 @@ class T5ForConditionalGenerationPolicy(T5BasePolicy): policy = super().module_policy() + embedding_cls = None + if self.shard_config.enable_tensor_parallelism: + embedding_cls = VocabParallelEmbedding1D + else: + if self.tie_weight: + embedding_cls = PaddingEmbedding + + if embedding_cls is not None: + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription( + suffix="shared", + target_module=embedding_cls, + kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}, + ), + policy=policy, + target_key=T5ForConditionalGeneration, + ) + if self.shard_config.enable_tensor_parallelism: self.append_or_create_submodule_replacement( - description=[ - SubModuleReplacementDescription( - suffix="shared", - target_module=VocabParallelEmbedding1D, - ), - SubModuleReplacementDescription( - suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True) - ), - ], + description=SubModuleReplacementDescription( + suffix="lm_head", + target_module=VocabParallelLMHead1D, + kwargs={ + "gather_output": True, + "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by, + }, + ), + policy=policy, + target_key=T5ForConditionalGeneration, + ) + else: + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription( + suffix="lm_head", + target_module=PaddingLMHead, + kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}, + ), policy=policy, target_key=T5ForConditionalGeneration, ) @@ -467,11 +510,19 @@ class T5EncoderPolicy(T5BasePolicy): policy = super().module_policy() + embedding_cls = None if self.shard_config.enable_tensor_parallelism: + embedding_cls = VocabParallelEmbedding1D + else: + if self.tie_weight: + embedding_cls = PaddingEmbedding + + if embedding_cls is not None: self.append_or_create_submodule_replacement( description=SubModuleReplacementDescription( suffix="shared", - target_module=VocabParallelEmbedding1D, + target_module=embedding_cls, + kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}, ), policy=policy, target_key=T5EncoderModel, diff --git a/colossalai/shardformer/policies/whisper.py b/colossalai/shardformer/policies/whisper.py index c63f6d1cc..0b5114fa6 100644 --- a/colossalai/shardformer/policies/whisper.py +++ b/colossalai/shardformer/policies/whisper.py @@ -45,11 +45,7 @@ class WhisperPolicy(Policy): r""" Reshape the Embedding layer to make the embedding dimension divisible by world_size """ - vocab_size = self.model.config.vocab_size - world_size = self.shard_config.tensor_parallel_size - if vocab_size % world_size != 0: - new_vocab_size = vocab_size + world_size - vocab_size % world_size - self.model.resize_token_embeddings(new_vocab_size) + self.tie_weight = self.tie_weight_check() return self.model def module_policy(self): @@ -63,6 +59,13 @@ class WhisperPolicy(Policy): policy = {} + embedding_cls = None + if self.shard_config.enable_tensor_parallelism: + embedding_cls = col_nn.VocabParallelEmbedding1D + else: + if self.tie_weight: + embedding_cls = col_nn.PaddingEmbedding + if self.shard_config.enable_fused_normalization: norm_cls = col_nn.FusedLayerNorm else: @@ -167,13 +170,17 @@ class WhisperPolicy(Policy): ], ) - policy[WhisperDecoder] = ModulePolicyDescription( - sub_module_replacement=[ + if embedding_cls is not None: + self.append_or_create_submodule_replacement( + description=[ SubModuleReplacementDescription( suffix="embed_tokens", - target_module=col_nn.VocabParallelEmbedding1D, + target_module=embedding_cls, + kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}, ), - ] + ], + policy=policy, + target_key=WhisperDecoder, ) # optimization configuration @@ -280,8 +287,21 @@ class WhisperPolicy(Policy): self.append_or_create_submodule_replacement( description=SubModuleReplacementDescription( suffix="proj_out", - target_module=col_nn.Linear1D_Col, - kwargs={"gather_output": True}, + target_module=col_nn.VocabParallelLMHead1D, + kwargs={ + "gather_output": True, + "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by, + }, + ), + policy=base_policy, + target_key=WhisperForConditionalGeneration, + ) + else: + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription( + suffix="proj_out", + target_module=col_nn.PaddingLMHead, + kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}, ), policy=base_policy, target_key=WhisperForConditionalGeneration, @@ -526,9 +546,6 @@ class WhisperForConditionalGenerationPolicy(WhisperPolicy): # WhisperForAudioClassification class WhisperForAudioClassificationPolicy(WhisperPolicy): - def preprocess(self): - return self.model - def module_policy(self): from transformers import WhisperForAudioClassification diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py index 7489873c2..963732543 100644 --- a/colossalai/shardformer/shard/shard_config.py +++ b/colossalai/shardformer/shard/shard_config.py @@ -42,10 +42,9 @@ class ShardConfig: sequence_parallelism_mode: str = None enable_sequence_overlap: bool = False parallel_output: bool = True + make_vocab_size_divisible_by: int = 64 gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None extra_kwargs: Dict[str, Any] = field(default_factory=dict) - # TODO padding vocab - # make_vocab_size_divisible_by: int = 128 # pipeline_parallel_size: int # data_parallel_size: int # tensor_parallel_mode: Literal['1d', '2d', '2.5d', '3d'] diff --git a/colossalai/tensor/d_tensor/layout_converter.py b/colossalai/tensor/d_tensor/layout_converter.py index 667a7b78e..c2cf73181 100644 --- a/colossalai/tensor/d_tensor/layout_converter.py +++ b/colossalai/tensor/d_tensor/layout_converter.py @@ -10,6 +10,7 @@ from colossalai.context.singleton_meta import SingletonMeta from colossalai.tensor.d_tensor.comm_spec import * from colossalai.tensor.d_tensor.layout import Layout from colossalai.tensor.d_tensor.misc import LayoutException +from colossalai.tensor.padded_tensor.api import init_as_padded_tensor, is_padded_tensor from colossalai.tensor.utils import all_gather_simulator, all_to_all_simulator, shard_simulator from .sharding_spec import ShardingSpec @@ -607,8 +608,18 @@ class LayoutConverter(metaclass=SingletonMeta): [3.], [3.]]) """ + _, comm_action_sequence = self.layout_converting(source_layout, target_layout) + + target_tensor = tensor for comm_spec in comm_action_sequence: - tensor = comm_spec.covert_spec_to_action(tensor) - tensor.dist_layout = target_layout - return tensor + target_tensor = comm_spec.covert_spec_to_action(target_tensor) + target_tensor.dist_layout = target_layout + + # restore the padding information + if is_padded_tensor(tensor) and not is_padded_tensor(target_tensor): + target_tensor = init_as_padded_tensor( + target_tensor, tensor._current_length, tensor._origin_length, tensor._padding_dim + ) + + return target_tensor diff --git a/colossalai/tensor/padded_tensor/__init__.py b/colossalai/tensor/padded_tensor/__init__.py new file mode 100644 index 000000000..353ff35f8 --- /dev/null +++ b/colossalai/tensor/padded_tensor/__init__.py @@ -0,0 +1,3 @@ +from .api import init_as_padded_tensor, is_padded_tensor, to_padded_tensor, to_unpadded_tensor + +__all__ = ["is_padded_tensor", "to_padded_tensor", "to_unpadded_tensor", "init_as_padded_tensor"] diff --git a/colossalai/tensor/padded_tensor/api.py b/colossalai/tensor/padded_tensor/api.py new file mode 100644 index 000000000..5b66c016b --- /dev/null +++ b/colossalai/tensor/padded_tensor/api.py @@ -0,0 +1,128 @@ +import torch + + +def _hijack_detach_and_clone(ptensor: torch.Tensor) -> torch.Tensor: + """ + Hijack the detach and clone methods of the tensor to make sure the dist_layout is copied. + + Args: + tensor (torch.Tensor): The tensor to be hijacked. + + Returns: + torch.Tensor: The hijacked tensor. + """ + ptensor._unpad_detach = ptensor.detach + ptensor._unpad_clone = ptensor.clone + + def new_detach(self): + t_ = self._unpad_detach() + t_._padding_dim = self._padding_dim + t_._origin_length = self._origin_length + t_._current_length = self._current_length + return t_ + + def new_clone(self, *args, **kwargs): + t_ = self._unpad_clone(*args, **kwargs) + t_._padding_dim = self._padding_dim + t_._origin_length = self._origin_length + t_._current_length = self._current_length + return t_ + + # bind the new methods to the tensor + ptensor.detach = new_detach.__get__(ptensor) + ptensor.clone = new_clone.__get__(ptensor) + return ptensor + + +def _hijack_back_detach_and_clone(ptensor: torch.Tensor) -> torch.Tensor: + """ + Hijack the detach and clone methods of the tensor to make sure the dist_layout is copied. + + Args: + tensor (torch.Tensor): The tensor to be hijacked. + + Returns: + torch.Tensor: The hijacked tensor. + """ + ptensor.detach = ptensor._unpad_detach + ptensor.clone = ptensor._unpad_clone + + delattr(ptensor, "_unpad_detach") + delattr(ptensor, "_unpad_clone") + + return ptensor + + +def is_padded_tensor(tensor: torch.Tensor) -> bool: + """ + Check whether the given tensor is a padding tensor. + + Args: + tensor (torch.Tensor): The tensor to be checked. + + Returns: + bool: Whether the given tensor is a padding tensor. + """ + return hasattr(tensor, "_padding_dim") + + +def to_padded_tensor( + tensor: torch.Tensor, + current_length: int, + padding_dim: int, +) -> torch.Tensor: + assert ( + padding_dim < tensor.dim() + ), f"Please passing a valid padding_dim. the dimension of the tensor is {tensor.dim()}" + + if is_padded_tensor(tensor): + return tensor + + origin_length = tensor.shape[padding_dim] + padding_num = current_length - origin_length + padding_data = torch.zeros( + *tensor.shape[:padding_dim], + padding_num, + *tensor.shape[padding_dim + 1 :], + device=tensor.device, + dtype=tensor.dtype, + ) + tensor.data = torch.cat((tensor.data, padding_data), dim=padding_dim).contiguous() + + tensor._padding_dim = padding_dim + tensor._origin_length = origin_length + tensor._current_length = current_length + + _hijack_detach_and_clone(tensor) + + return tensor + + +def to_unpadded_tensor(ptensor: torch.Tensor): + if not is_padded_tensor(ptensor): + return ptensor + + unpad_slices = [slice(None)] * ptensor.dim() + unpad_slices[ptensor._padding_dim] = slice(None, ptensor._origin_length) + ptensor.data = ptensor.data[tuple(unpad_slices)] + + delattr(ptensor, "_padding_dim") + delattr(ptensor, "_origin_length") + delattr(ptensor, "_current_length") + + _hijack_back_detach_and_clone(ptensor) + + return ptensor + + +def init_as_padded_tensor(tensor: torch.Tensor, current_length: int, origin_length: int, padding_dim: int): + if is_padded_tensor(tensor): + return tensor + + tensor._padding_dim = padding_dim + tensor._origin_length = origin_length + tensor._current_length = current_length + + _hijack_detach_and_clone(tensor) + + return tensor diff --git a/colossalai/testing/comparison.py b/colossalai/testing/comparison.py index e415b5fc3..bdf7b19f3 100644 --- a/colossalai/testing/comparison.py +++ b/colossalai/testing/comparison.py @@ -23,7 +23,7 @@ def assert_close_loose(a: Tensor, b: Tensor, rtol: float = 1e-3, atol: float = 1 rtol=rtol, atol=atol, msg=f"Tensor not close, shape: {a.shape} vs {b.shape}, \ - dtype: {a.dtype} vs {b.dtype}", + dtype: {a.dtype} vs {b.dtype}", ) diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py index bc6c9d088..c79422171 100644 --- a/colossalai/zero/gemini/gemini_ddp.py +++ b/colossalai/zero/gemini/gemini_ddp.py @@ -27,6 +27,12 @@ from colossalai.tensor.d_tensor import ( is_customized_distributed_tensor, is_distributed_tensor, ) +from colossalai.tensor.padded_tensor import ( + init_as_padded_tensor, + is_padded_tensor, + to_padded_tensor, + to_unpadded_tensor, +) from colossalai.tensor.param_op_hook import ColoParamOpHookManager from colossalai.utils import _cast_float, free_storage, is_ddp_ignored @@ -460,6 +466,11 @@ class GeminiDDP(ModelWrapper): record_tensor, shard_fn=tensor.shard_fn, gather_fn=tensor.gather_fn ) record_tensor = gather_distributed_param(record_tensor, keep_vars=False).cpu() + if is_padded_tensor(tensor): + record_tensor = init_as_padded_tensor( + record_tensor, tensor._current_length, tensor._origin_length, tensor._padding_dim + ) + record_tensor = to_unpadded_tensor(record_tensor) assert tensor not in chunk_to_save_data chunk_to_save_data[tensor] = record_tensor @@ -520,6 +531,8 @@ class GeminiDDP(ModelWrapper): # deal with ddp ignored parameters destination[prefix + name] = param if keep_vars else param.detach() else: + if is_padded_tensor(p_mapping[param]): + p_mapping[param] = to_unpadded_tensor(p_mapping[param]) destination[prefix + name] = p_mapping[param] del p_mapping del param_to_save_data @@ -627,6 +640,7 @@ class GeminiDDP(ModelWrapper): list, and will be reported together in :meth:`~torch.nn.Module.load_state_dict` """ + for hook in self._load_state_dict_pre_hooks.values(): hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) @@ -647,6 +661,14 @@ class GeminiDDP(ModelWrapper): if state_key in state_dict: input_param = state_dict[state_key] + global_shape = dest_tensor.shape + if source_device_mesh is not None and source_sharding_spec is not None: + global_shape = get_global_shape(dest_tensor) + + if is_padded_tensor(dest_tensor): + padding_dim = dest_tensor._padding_dim + input_param = to_padded_tensor(input_param, global_shape[padding_dim], padding_dim) + if source_device_mesh is not None and source_sharding_spec is not None: input_param = distribute_tensor(input_param, source_device_mesh, source_sharding_spec) elif shard_fn is not None and gather_fn is not None: diff --git a/colossalai/zero/gemini/gemini_optimizer.py b/colossalai/zero/gemini/gemini_optimizer.py index 18367af59..ae02fe297 100644 --- a/colossalai/zero/gemini/gemini_optimizer.py +++ b/colossalai/zero/gemini/gemini_optimizer.py @@ -21,12 +21,19 @@ from colossalai.tensor.d_tensor import ( distribute_tensor, distribute_tensor_with_customization, get_device_mesh, + get_global_shape, get_sharding_spec, init_as_dtensor, init_tensor_as_customization_distributed, is_customized_distributed_tensor, is_distributed_tensor, ) +from colossalai.tensor.padded_tensor import ( + init_as_padded_tensor, + is_padded_tensor, + to_padded_tensor, + to_unpadded_tensor, +) from colossalai.utils import disposable, is_ddp_ignored from .chunk import Chunk, ChunkManager @@ -106,7 +113,7 @@ class GeminiOptimizer(OptimizerWrapper): max_norm: float = 0.0, norm_type: float = 2.0, tp_group: ProcessGroup = None, - optimizer_params_info=None, + params_info=None, verbose: bool = False, **defaults: Any, ): @@ -124,7 +131,7 @@ class GeminiOptimizer(OptimizerWrapper): self.clipping_flag = max_norm > 0.0 self.max_norm = max_norm self.tp_group = tp_group - self.optimizer_params_info = optimizer_params_info + self.params_info = params_info self.tp_size = dist.get_world_size(tp_group) if tp_group is not None else 1 self.tp_rank = dist.get_rank(tp_group) if tp_group is not None else 0 self.verbose = verbose @@ -459,7 +466,7 @@ class GeminiOptimizer(OptimizerWrapper): is_customized_distributed = is_customized_distributed_tensor(param) shard_spec = get_sharding_spec(param) if is_dtensor else None device_mesh = get_device_mesh(param) if is_dtensor else None - global_shape = self.optimizer_params_info["id2shape"][param_id] + global_shape = self.params_info["id2shape"][param_id] # If the chunk is kept gathered, # the parameters are treated the same as that of those in strict DDP during training. @@ -477,6 +484,7 @@ class GeminiOptimizer(OptimizerWrapper): else: state_tensor = states[state_name].detach().clone().to(torch.float32).cpu() if is_dtensor: + global_shape = get_global_shape(param) state_tensor = torch.reshape(state_tensor, param.shape).to(param.device) state_tensor = init_as_dtensor( state_tensor, @@ -490,8 +498,13 @@ class GeminiOptimizer(OptimizerWrapper): state_tensor, shard_fn=param.shard_fn, gather_fn=param.gather_fn ) state_tensor = gather_distributed_param(state_tensor, keep_vars=False).cpu() - - collected_states[state_name] = state_tensor.reshape(global_shape) + state_tensor = state_tensor.reshape(global_shape) + if is_padded_tensor(param): + state_tensor = init_as_padded_tensor( + state_tensor, param._current_length, param._origin_length, param._padding_dim + ) + state_tensor = to_unpadded_tensor(state_tensor) + collected_states[state_name] = state_tensor return collected_states # Check whether the param with given id is managed by current process. @@ -535,6 +548,7 @@ class GeminiOptimizer(OptimizerWrapper): if state_tensor.numel() == param.numel(): collected_states[state_name] = torch.reshape(state_tensor, param.shape) if is_dtensor: + global_shape = get_global_shape(param) state_tensor = state_tensor.to(param.device) state_tensor = init_as_dtensor( state_tensor, sharding_spec=shard_spec, device_mesh=device_mesh, global_shape=global_shape @@ -545,6 +559,11 @@ class GeminiOptimizer(OptimizerWrapper): state_tensor, shard_fn=param.shard_fn, gather_fn=param.gather_fn ) state_tensor = gather_distributed_param(state_tensor, keep_vars=False).cpu() + if is_padded_tensor(param): + state_tensor = init_as_padded_tensor( + state_tensor, param._current_length, param._origin_length, param._padding_dim + ) + state_tensor = to_unpadded_tensor(state_tensor) return collected_states @@ -698,7 +717,7 @@ class GeminiOptimizer(OptimizerWrapper): Load saved optimizer states into parameter with given id. """ - def cast(param, state_range, value, key=None): + def cast(param, state_range, value, global_shape, origin_shape, key=None): """ Make a copy of the needed segment of value and cast it to device of param. """ @@ -714,7 +733,14 @@ class GeminiOptimizer(OptimizerWrapper): ) if is_dtensor: - value = torch.reshape(value, global_shape) + global_shape = get_global_shape(real_param) + + if is_padded_tensor(real_param): + value = torch.reshape(value, origin_shape) + padding_dim = real_param._padding_dim + value = to_padded_tensor(value, global_shape[padding_dim], padding_dim) + + if is_dtensor: value = distribute_tensor(value, sharding_spec=shard_spec, device_mesh=device_mesh) elif is_customized_distributed: value = torch.reshape(value, global_shape) @@ -737,10 +763,11 @@ class GeminiOptimizer(OptimizerWrapper): is_customized_distributed = is_customized_distributed_tensor(real_param) shard_spec = get_sharding_spec(real_param) if is_dtensor else None device_mesh = get_device_mesh(real_param) if is_dtensor else None - global_shape = self.optimizer_params_info["id2shape"][param_id] + global_shape = self.params_info["id2shape"][param_id] + origin_shape = global_shape for k, v in saved_states.items(): - updated_states[k] = cast(fake_param, state_range, v, k) + updated_states[k] = cast(fake_param, state_range, v, global_shape, origin_shape, k) del v # clean loaded states self.optim.state[fake_param].update(updated_states) diff --git a/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py b/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py index d8a625b98..4753ab637 100644 --- a/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py @@ -81,8 +81,7 @@ def exam_state_dict(shard: bool, model_name: str, size_per_shard: int, test_conf optimizer.backward(loss) optimizer.step() - for group in optimizer.param_groups: - group["lr"] = 0.1 + optimizer.zero_grad() with shared_tempdir() as tempdir: model_ckpt_path = f"{tempdir}/model" optimizer_ckpt_path = f"{tempdir}/optimizer" diff --git a/tests/test_shardformer/test_layer/test_vocab_parallel_embedding_1d.py b/tests/test_shardformer/test_layer/test_vocab_parallel_embedding_1d.py index b23a44f2d..91cc1a987 100644 --- a/tests/test_shardformer/test_layer/test_vocab_parallel_embedding_1d.py +++ b/tests/test_shardformer/test_layer/test_vocab_parallel_embedding_1d.py @@ -21,7 +21,7 @@ def check_vocab_embedding_1d(lazy_init: bool): dist_embedding_1d = VocabParallelEmbedding1D.from_native_module(embedding_copy, process_group=None) assert dist_embedding_1d.weight.shape == torch.Size([64, 32]) - assert dist_embedding_1d.num_embeddings == 64 + assert dist_embedding_1d.num_embeddings == 128 assert dist_embedding_1d.embedding_dim == 32 assert embedding_copy.weight is dist_embedding_1d.weight diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py index d5fc2c30f..a77ba39a1 100644 --- a/tests/test_shardformer/test_model/_utils.py +++ b/tests/test_shardformer/test_model/_utils.py @@ -14,12 +14,14 @@ from torch.testing import assert_close from colossalai.booster import Booster from colossalai.booster.plugin import HybridParallelPlugin from colossalai.booster.plugin.hybrid_parallel_plugin import HybridParallelModule +from colossalai.checkpoint_io.utils import gather_distributed_param from colossalai.lazy import LazyInitContext from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer import ShardConfig, ShardFormer from colossalai.shardformer._utils import getattr_ from colossalai.shardformer.policies.auto_policy import Policy from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor +from colossalai.tensor.padded_tensor.api import is_padded_tensor, to_unpadded_tensor def build_model( @@ -247,11 +249,10 @@ def check_weight( continue if is_distributed_tensor(sharded_weight) or is_customized_distributed_tensor(sharded_weight): - sharded_weight_list = [ - torch.zeros_like(sharded_weight).to("cuda") for _ in range(dist.get_world_size(tp_group)) - ] - dist.all_gather(sharded_weight_list, sharded_weight, tp_group) - sharded_weight = torch.cat(sharded_weight_list, dim=dim) + sharded_weight = gather_distributed_param(sharded_weight, keep_vars=False) + + if is_padded_tensor(sharded_weight): + sharded_weight = to_unpadded_tensor(sharded_weight) if verbose and dist.get_rank() == 0: print(f"'{suffix}' weight: {org_weight}, {sharded_weight}") diff --git a/tests/test_shardformer/test_model/test_shard_t5.py b/tests/test_shardformer/test_model/test_shard_t5.py index 9b22d54d7..a6fe2dd39 100644 --- a/tests/test_shardformer/test_model/test_shard_t5.py +++ b/tests/test_shardformer/test_model/test_shard_t5.py @@ -73,7 +73,8 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, # check weights if test_config["precision"] == "fp32": - atol, rtol = 5e-4, 1e-3 + # TODO he precision in weight checking is too significant. + atol, rtol = 1e-3, 1e-3 else: atol, rtol = 5e-3, 5e-3 if stage_manager is None or stage_manager.is_first_stage(): diff --git a/tests/test_tensor/test_padded_tensor.py b/tests/test_tensor/test_padded_tensor.py new file mode 100644 index 000000000..31a267c15 --- /dev/null +++ b/tests/test_tensor/test_padded_tensor.py @@ -0,0 +1,46 @@ +import torch + +from colossalai.device.device_mesh import DeviceMesh +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.tensor.d_tensor import ShardingSpec, distribute_tensor, is_distributed_tensor, to_global +from colossalai.tensor.padded_tensor import is_padded_tensor, to_padded_tensor, to_unpadded_tensor +from colossalai.testing import rerun_if_address_is_in_use, spawn + + +def check_padded_tensor(rank, world_size, port): + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + original_tensor = torch.rand(32, 64).to("cuda") + + device_mesh = DeviceMesh(torch.Tensor([0, 1, 2, 3]), (2, 2), init_process_group=True) + target_sharding_spec = ShardingSpec(dim_size=original_tensor.dim(), dim_partition_dict={0: [0]}) + d_tensor = distribute_tensor(original_tensor, device_mesh, target_sharding_spec) + + padded_tensor = to_padded_tensor(d_tensor, current_length=64, padding_dim=0) + assert padded_tensor.dist_layout == d_tensor.dist_layout + + tensor_copy = padded_tensor.clone() + assert is_padded_tensor(tensor_copy) + assert is_distributed_tensor(tensor_copy) + + tensor_detached = padded_tensor.detach() + assert is_padded_tensor(tensor_detached) + assert is_distributed_tensor(tensor_detached) + + unpadded_tensor = to_unpadded_tensor(padded_tensor) + assert unpadded_tensor.shape == d_tensor.shape + assert is_distributed_tensor(unpadded_tensor) + + global_tensor = to_global(unpadded_tensor) + assert global_tensor.shape == original_tensor.shape + + +@rerun_if_address_is_in_use() +def test_padded_tensor(): + world_size = 4 + spawn(check_padded_tensor, world_size) + + +if __name__ == "__main__": + test_padded_tensor() From d83c633ca63c4eef49f3473aa998515fa5ca573f Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Thu, 18 Apr 2024 18:15:50 +0800 Subject: [PATCH 04/28] [hotfix] Fix examples no pad token & auto parallel codegen bug; (#5606) * fix no pad token bug * fixed some auto parallel codegen bug, but might not run on torch 2.1 --------- Co-authored-by: Edenzzzz --- colossalai/_analyzer/fx/codegen.py | 2 +- colossalai/auto_parallel/offload/base_offload_module.py | 2 +- colossalai/auto_parallel/offload/region.py | 3 ++- colossalai/autochunk/autochunk_codegen.py | 2 +- colossalai/fx/codegen/activation_checkpoint_codegen.py | 2 +- examples/language/gpt/hybridparallelism/data.py | 2 ++ 6 files changed, 8 insertions(+), 5 deletions(-) diff --git a/colossalai/_analyzer/fx/codegen.py b/colossalai/_analyzer/fx/codegen.py index cd244b22c..68a27d919 100644 --- a/colossalai/_analyzer/fx/codegen.py +++ b/colossalai/_analyzer/fx/codegen.py @@ -246,7 +246,7 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func, @compatibility(is_backward_compatible=True) class ActivationCheckpointCodeGen(CodeGen): - def _gen_python_code(self, nodes, root_module: str, namespace: _Namespace) -> PythonCode: + def _gen_python_code(self, nodes, root_module: str, namespace: _Namespace, verbose=None) -> PythonCode: free_vars: List[str] = [] body: List[str] = [] globals_: Dict[str, Any] = {} diff --git a/colossalai/auto_parallel/offload/base_offload_module.py b/colossalai/auto_parallel/offload/base_offload_module.py index f5e8e31f5..60de7743a 100644 --- a/colossalai/auto_parallel/offload/base_offload_module.py +++ b/colossalai/auto_parallel/offload/base_offload_module.py @@ -5,7 +5,7 @@ import torch import torch.nn as nn from colossalai.utils import _cast_float -from colossalai.zero.legacy.gemini.tensor_utils import free_storage +from colossalai.utils.common import free_storage from .region_manager import RegionManager from .util import GlobalRuntimeInfo diff --git a/colossalai/auto_parallel/offload/region.py b/colossalai/auto_parallel/offload/region.py index ea92c714c..a9f6f4c18 100644 --- a/colossalai/auto_parallel/offload/region.py +++ b/colossalai/auto_parallel/offload/region.py @@ -3,7 +3,8 @@ from typing import Dict, List, Tuple import torch from torch.fx import Node -from colossalai.zero.legacy.gemini.tensor_utils import alloc_storage, free_storage +from colossalai.utils.common import free_storage +from colossalai.zero.gemini.chunk.chunk import alloc_storage class Region: diff --git a/colossalai/autochunk/autochunk_codegen.py b/colossalai/autochunk/autochunk_codegen.py index 9571fa2c1..07dbf8a79 100644 --- a/colossalai/autochunk/autochunk_codegen.py +++ b/colossalai/autochunk/autochunk_codegen.py @@ -372,7 +372,7 @@ if AUTOCHUNK_AVAILABLE: if print_progress: get_logger().info("AutoChunk start codegen") - def _gen_python_code(self, nodes, root_module: str, namespace: _Namespace) -> PythonCode: + def _gen_python_code(self, nodes, root_module: str, namespace: _Namespace, verbose=None) -> PythonCode: free_vars: List[str] = [] body: List[str] = [] globals_: Dict[str, Any] = {} diff --git a/colossalai/fx/codegen/activation_checkpoint_codegen.py b/colossalai/fx/codegen/activation_checkpoint_codegen.py index dfb5754d7..28451bdd1 100644 --- a/colossalai/fx/codegen/activation_checkpoint_codegen.py +++ b/colossalai/fx/codegen/activation_checkpoint_codegen.py @@ -625,7 +625,7 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func, if CODEGEN_AVAILABLE: class ActivationCheckpointCodeGen(CodeGen): - def _gen_python_code(self, nodes, root_module: str, namespace: _Namespace) -> PythonCode: + def _gen_python_code(self, nodes, root_module: str, namespace: _Namespace, verbose=None) -> PythonCode: free_vars: List[str] = [] body: List[str] = [] globals_: Dict[str, Any] = {} diff --git a/examples/language/gpt/hybridparallelism/data.py b/examples/language/gpt/hybridparallelism/data.py index ef51f938d..e5dc882bc 100644 --- a/examples/language/gpt/hybridparallelism/data.py +++ b/examples/language/gpt/hybridparallelism/data.py @@ -62,6 +62,8 @@ class GLUEDataBuilder: self.text_fields = self.task_text_field_map[task_name] self.num_labels = self.glue_task_num_labels[task_name] self.tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=True) + if not getattr(self.tokenizer, "pad_token", None): + self.tokenizer.pad_token = self.tokenizer._eos_token self.setup() def setup(self): From e094933da1d0a574eda105ab6ec0f171d8ddaebb Mon Sep 17 00:00:00 2001 From: Hongxin Liu Date: Mon, 22 Apr 2024 11:25:39 +0800 Subject: [PATCH 05/28] [shardformer] fix pipeline grad ckpt (#5620) * [shardformer] fix pipeline grad ckpt --- colossalai/shardformer/shard/grad_ckpt_config.py | 6 ++++++ colossalai/shardformer/shard/shard_config.py | 13 ++++++++++++- 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/colossalai/shardformer/shard/grad_ckpt_config.py b/colossalai/shardformer/shard/grad_ckpt_config.py index 9c6c2b54e..9fc857d19 100644 --- a/colossalai/shardformer/shard/grad_ckpt_config.py +++ b/colossalai/shardformer/shard/grad_ckpt_config.py @@ -22,6 +22,7 @@ class PipelineGradientCheckpointConfig(GradientCheckpointConfig): 2. Customize # ckpt layers assigned to each stage. This takes precedence over `gradient_checkpointing_ratio`. """ + """ Args: gradient_checkpointing_ratio (Optional[float]): The ratio of gradient checkpointing. It can only be used in pipeline parallelism. Defaults to None. @@ -49,6 +50,7 @@ class PipelineGradientCheckpointConfig(GradientCheckpointConfig): num_stages: Optional[int] = None num_model_chunks: Optional[int] = None num_model_layers: Optional[int] = None + num_layers_per_stage: Optional[List[int]] = None num_ckpt_layers_per_stage: Optional[List[int]] = None def __post_init__(self): @@ -70,6 +72,10 @@ class PipelineGradientCheckpointConfig(GradientCheckpointConfig): def _enable_gradient_checkpointing_ratio(self) -> bool: return self.gradient_checkpointing_ratio is not None + @property + def _customize_num_layers_per_stage(self) -> bool: + return self.num_layers_per_stage is not None and self.num_model_layers is not None + @property def _enable_customized_ckpt_layers_per_stage(self) -> bool: return self.num_ckpt_layers_per_stage is not None diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py index 963732543..597dd9c26 100644 --- a/colossalai/shardformer/shard/shard_config.py +++ b/colossalai/shardformer/shard/shard_config.py @@ -7,7 +7,7 @@ from torch.distributed import ProcessGroup from colossalai.pipeline.stage_manager import PipelineStageManager -from .grad_ckpt_config import GradientCheckpointConfig +from .grad_ckpt_config import GradientCheckpointConfig, PipelineGradientCheckpointConfig __all__ = ["ShardConfig"] SUPPORT_SP_MODE = ["split_gather", "ring", "all_to_all"] @@ -30,6 +30,7 @@ class ShardConfig: gradient_checkpoint_config (Optional[GradientCheckpointConfig]): The gradient checkpoint config. Defaults to None. enable_all_optimization (bool): Whether to turn on all optimization tools including 'fused normalization', 'flash attention', 'JIT fused operators', 'sequence parallelism' and 'sequence overlap'. Defaults to False. """ + tensor_parallel_process_group: Optional[ProcessGroup] = None sequence_parallel_process_group: Optional[ProcessGroup] = None pipeline_stage_manager: Optional[PipelineStageManager] = None @@ -104,6 +105,16 @@ class ShardConfig: else: self._sequence_parallel_size = dist.get_world_size(self.sequence_parallel_process_group) + if ( + self.pipeline_stage_manager is not None + and isinstance(self.gradient_checkpoint_config, PipelineGradientCheckpointConfig) + and self.gradient_checkpoint_config._customize_num_layers_per_stage + ): + self.pipeline_stage_manager.set_distribution_config( + self.gradient_checkpoint_config.num_model_layers, + self.gradient_checkpoint_config.num_layers_per_stage, + ) + def _turn_on_all_optimization(self): """ Turn on all optimization. From 862fbaaa626f091c963ae41476607e1c5cec759c Mon Sep 17 00:00:00 2001 From: Tong Li Date: Tue, 23 Apr 2024 13:54:05 +0800 Subject: [PATCH 06/28] [Feature] Support LLaMA-3 CPT and ST (#5619) * support LLaMA-3 * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Run pre-commit --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- applications/Colossal-LLaMA-2/version.txt | 1 - .../README.md | 30 +++++----- .../colossal_llama}/__init__.py | 0 .../colossal_llama}/dataset/__init__.py | 0 .../colossal_llama}/dataset/conversation.py | 14 ++++- .../colossal_llama}/dataset/loader.py | 0 .../dataset/spliced_and_tokenized_dataset.py | 3 +- .../colossal_llama}/model/init_model.py | 0 .../tokenizer/init_tokenizer.py | 0 .../colossal_llama}/utils/__init__.py | 0 .../colossal_llama}/utils/ckpt_io.py | 0 .../utils/flash_attention_patch.py | 0 .../colossal_llama}/utils/froze.py | 0 .../colossal_llama}/utils/neftune_patch.py | 0 .../utils/stream_chat_patch.py | 0 .../docs/example_13b.md | 0 .../docs/example_7b.md | 0 .../hostfile.example | 0 .../inference_example.py | 2 +- .../prepare_pretrain_dataset.py | 41 +++++--------- .../prepare_sft_dataset.py | 55 +++++++++---------- .../requirements.txt | 9 +-- .../stream_chat_example.py | 2 +- .../train.example.sh | 0 .../train.py | 16 +++--- .../train_sft.example.sh | 0 applications/Colossal-LLaMA/version.txt | 1 + applications/README.md | 2 +- 28 files changed, 89 insertions(+), 87 deletions(-) delete mode 100644 applications/Colossal-LLaMA-2/version.txt rename applications/{Colossal-LLaMA-2 => Colossal-LLaMA}/README.md (97%) rename applications/{Colossal-LLaMA-2/colossal_llama2 => Colossal-LLaMA/colossal_llama}/__init__.py (100%) rename applications/{Colossal-LLaMA-2/colossal_llama2 => Colossal-LLaMA/colossal_llama}/dataset/__init__.py (100%) rename applications/{Colossal-LLaMA-2/colossal_llama2 => Colossal-LLaMA/colossal_llama}/dataset/conversation.py (86%) rename applications/{Colossal-LLaMA-2/colossal_llama2 => Colossal-LLaMA/colossal_llama}/dataset/loader.py (100%) rename applications/{Colossal-LLaMA-2/colossal_llama2 => Colossal-LLaMA/colossal_llama}/dataset/spliced_and_tokenized_dataset.py (99%) rename applications/{Colossal-LLaMA-2/colossal_llama2 => Colossal-LLaMA/colossal_llama}/model/init_model.py (100%) rename applications/{Colossal-LLaMA-2/colossal_llama2 => Colossal-LLaMA/colossal_llama}/tokenizer/init_tokenizer.py (100%) rename applications/{Colossal-LLaMA-2/colossal_llama2 => Colossal-LLaMA/colossal_llama}/utils/__init__.py (100%) rename applications/{Colossal-LLaMA-2/colossal_llama2 => Colossal-LLaMA/colossal_llama}/utils/ckpt_io.py (100%) rename applications/{Colossal-LLaMA-2/colossal_llama2 => Colossal-LLaMA/colossal_llama}/utils/flash_attention_patch.py (100%) rename applications/{Colossal-LLaMA-2/colossal_llama2 => Colossal-LLaMA/colossal_llama}/utils/froze.py (100%) rename applications/{Colossal-LLaMA-2/colossal_llama2 => Colossal-LLaMA/colossal_llama}/utils/neftune_patch.py (100%) rename applications/{Colossal-LLaMA-2/colossal_llama2 => Colossal-LLaMA/colossal_llama}/utils/stream_chat_patch.py (100%) rename applications/{Colossal-LLaMA-2 => Colossal-LLaMA}/docs/example_13b.md (100%) rename applications/{Colossal-LLaMA-2 => Colossal-LLaMA}/docs/example_7b.md (100%) rename applications/{Colossal-LLaMA-2 => Colossal-LLaMA}/hostfile.example (100%) rename applications/{Colossal-LLaMA-2 => Colossal-LLaMA}/inference_example.py (97%) rename applications/{Colossal-LLaMA-2 => Colossal-LLaMA}/prepare_pretrain_dataset.py (80%) rename applications/{Colossal-LLaMA-2 => Colossal-LLaMA}/prepare_sft_dataset.py (74%) rename applications/{Colossal-LLaMA-2 => Colossal-LLaMA}/requirements.txt (65%) rename applications/{Colossal-LLaMA-2 => Colossal-LLaMA}/stream_chat_example.py (97%) rename applications/{Colossal-LLaMA-2 => Colossal-LLaMA}/train.example.sh (100%) rename applications/{Colossal-LLaMA-2 => Colossal-LLaMA}/train.py (96%) rename applications/{Colossal-LLaMA-2 => Colossal-LLaMA}/train_sft.example.sh (100%) create mode 100644 applications/Colossal-LLaMA/version.txt diff --git a/applications/Colossal-LLaMA-2/version.txt b/applications/Colossal-LLaMA-2/version.txt deleted file mode 100644 index 8acdd82b7..000000000 --- a/applications/Colossal-LLaMA-2/version.txt +++ /dev/null @@ -1 +0,0 @@ -0.0.1 diff --git a/applications/Colossal-LLaMA-2/README.md b/applications/Colossal-LLaMA/README.md similarity index 97% rename from applications/Colossal-LLaMA-2/README.md rename to applications/Colossal-LLaMA/README.md index 1377e1fac..93ba58ac5 100644 --- a/applications/Colossal-LLaMA-2/README.md +++ b/applications/Colossal-LLaMA/README.md @@ -1,6 +1,6 @@

- +Colossal-LLaMA

@@ -47,6 +47,7 @@ - [Citations](#citations) ## News +* [2024/4] Support continual pre-training and supervised fine-tuning of LLaMA-3. * [2024/01] [Construct Refined 13B Private Model With Just $5000 USD, Upgraded Colossal-AI Llama-2 Open Source](https://hpc-ai.com/blog/colossal-llama-2-13b). [[code]](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Colossal-LLaMA-2) [[blog]](https://hpc-ai.com/blog/colossal-llama-2-13b) @@ -289,7 +290,7 @@ Here is details about CLI arguments: #### 1. Install required packages ``` -cd Colossal-LLaMA-2 +cd Colossal-LLaMA pip install -r requirements.txt ``` #### 2. Install `xentropy`, `layer_norm` and `rotary` @@ -314,7 +315,7 @@ Initialize new tokenizer with additional Chinese tokens. Additional Chinese toke Command to initialize new tokenizer: ```bash export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION='python' -python colossal_llama2/tokenizer/init_tokenizer.py \ +python colossal_llama/tokenizer/init_tokenizer.py \ --source_tokenizer_dir "" \ --target_tokenizer_dir "" \ --expand_tokens_file ".jsonl" @@ -328,7 +329,7 @@ Here is details about CLI arguments: Initialize the new model checkpoint by calculating the mean values from the original model checkpoint. Command to initialize new model checkpoint: ```bash -python colossal_llama2/model/init_model.py \ +python colossal_llama/model/init_model.py \ --source_model_and_tokenizer_path "" \ --target_tokenizer_path "" \ --target_model_path "" @@ -362,18 +363,17 @@ Command to convert jsonl dataset to arrow format: python prepare_pretrain_dataset.py \ --data_input_dirs ",," \ --tokenizer_dir "" \ - --data_cache_dir "jsonl_to_arrow_cache" \ - --data_jsonl_output_dir "spliced_tokenized_output_jsonl" \ - --data_arrow_output_dir "spliced_tokenized_output_arrow" \ + --data_output_dirs "spliced tokenized output" \ --max_length 4096 \ --num_spliced_dataset_bins 10 ``` Here is details about CLI arguments: * Source data directory: `data_input_dirs`. Each `` can have multiple file in `jsonl` format. * Tokenizer directory: `tokenizer_dir`. Path to the tokenizer in Hugging Face format. -* Data cache directory: `data_cache_dir`. Directory to store Hugging Face data cache. Default case will create `cache` folder locally. -* Output directory for jsonl format: `data_jsonl_output_dir`. Output directory to store converted dataset in jsonl format. -* Output directory for arrow format: `data_arrow_output_dir`. Output directory to store converted dataset in arrow format, which can be used for training directly. +* Data output directory: `data_output_dirs`. Directory to store preprocessed output, including three sub-directories: + * `cache`: Directory to store Hugging Face data cache. + * `jsonl`: Output directory to store converted dataset in jsonl format. + * `arrow`: Output directory to store converted dataset in arrow format, which can be used for training directly. * Max length: `max_length`. Max length of spliced samples. Default value is 4096. * Number of bins for each category: `num_spliced_dataset_bins`. Number of bins for each category, used for bucket-based training. @@ -392,13 +392,15 @@ Command to convert jsonl dataset to arrow format is similar to the command in [3 python prepare_sft_dataset.py.py \ --data_input_dirs ",," \ --tokenizer_dir "" \ - --data_cache_dir "jsonl_to_arrow_cache" \ - --data_jsonl_output_dir "spliced_tokenized_output_jsonl" \ - --data_arrow_output_dir "spliced_tokenized_output_arrow" \ + --data_output_dirs "spliced tokenized output" \ --max_length 4096 \ - --num_spliced_dataset_bins 10 + --num_spliced_dataset_bins 10 \ + --llama_version 3 ``` +Additional CLI arguments: +* LLaMA verison: `llama_version`. Specify the LLaMA version. + #### 4. Command Line Arguments for Training ##### 4.1 Arguments for Pretraining diff --git a/applications/Colossal-LLaMA-2/colossal_llama2/__init__.py b/applications/Colossal-LLaMA/colossal_llama/__init__.py similarity index 100% rename from applications/Colossal-LLaMA-2/colossal_llama2/__init__.py rename to applications/Colossal-LLaMA/colossal_llama/__init__.py diff --git a/applications/Colossal-LLaMA-2/colossal_llama2/dataset/__init__.py b/applications/Colossal-LLaMA/colossal_llama/dataset/__init__.py similarity index 100% rename from applications/Colossal-LLaMA-2/colossal_llama2/dataset/__init__.py rename to applications/Colossal-LLaMA/colossal_llama/dataset/__init__.py diff --git a/applications/Colossal-LLaMA-2/colossal_llama2/dataset/conversation.py b/applications/Colossal-LLaMA/colossal_llama/dataset/conversation.py similarity index 86% rename from applications/Colossal-LLaMA-2/colossal_llama2/dataset/conversation.py rename to applications/Colossal-LLaMA/colossal_llama/dataset/conversation.py index be27ff7bc..8ec9c848b 100644 --- a/applications/Colossal-LLaMA-2/colossal_llama2/dataset/conversation.py +++ b/applications/Colossal-LLaMA/colossal_llama/dataset/conversation.py @@ -83,7 +83,7 @@ class Conversation: } -conv = Conversation( +LLaMA2_Conv = Conversation( system="A chat between a curious human and an artificial intelligence assistant. " "The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n", roles=("Human", "Assistant"), @@ -93,4 +93,14 @@ conv = Conversation( seps=["", ""], ) -default_conversation = conv +LLaMA3_Conv = Conversation( + system="A chat between a curious human and an artificial intelligence assistant. " + "The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n", + roles=("Human", "Assistant"), + messages=[], + offset=0, + sep_style=SeparatorStyle.ADD_BOS_EOS_TOKEN, + seps=["<|begin_of_text|>", "<|end_of_text|>"], +) + +default_conversation = LLaMA3_Conv diff --git a/applications/Colossal-LLaMA-2/colossal_llama2/dataset/loader.py b/applications/Colossal-LLaMA/colossal_llama/dataset/loader.py similarity index 100% rename from applications/Colossal-LLaMA-2/colossal_llama2/dataset/loader.py rename to applications/Colossal-LLaMA/colossal_llama/dataset/loader.py diff --git a/applications/Colossal-LLaMA-2/colossal_llama2/dataset/spliced_and_tokenized_dataset.py b/applications/Colossal-LLaMA/colossal_llama/dataset/spliced_and_tokenized_dataset.py similarity index 99% rename from applications/Colossal-LLaMA-2/colossal_llama2/dataset/spliced_and_tokenized_dataset.py rename to applications/Colossal-LLaMA/colossal_llama/dataset/spliced_and_tokenized_dataset.py index 8314941ba..30122d283 100644 --- a/applications/Colossal-LLaMA-2/colossal_llama2/dataset/spliced_and_tokenized_dataset.py +++ b/applications/Colossal-LLaMA/colossal_llama/dataset/spliced_and_tokenized_dataset.py @@ -12,6 +12,7 @@ from typing import Any, Callable, Dict, Iterable, List, Tuple, Union from datasets import dataset_dict from torch.utils.data import ConcatDataset, Dataset, IterableDataset +from transformers import AutoTokenizer from transformers.models.llama.tokenization_llama import LlamaTokenizer from transformers.tokenization_utils import PreTrainedTokenizer @@ -71,7 +72,7 @@ def supervised_tokenize_pretrain( def supervised_tokenize_sft( data_point: Dict[str, str], - tokenizer: LlamaTokenizer, + tokenizer: AutoTokenizer, conversation_template: Conversation = default_conversation, ignore_index: int = None, max_length: int = 4096, diff --git a/applications/Colossal-LLaMA-2/colossal_llama2/model/init_model.py b/applications/Colossal-LLaMA/colossal_llama/model/init_model.py similarity index 100% rename from applications/Colossal-LLaMA-2/colossal_llama2/model/init_model.py rename to applications/Colossal-LLaMA/colossal_llama/model/init_model.py diff --git a/applications/Colossal-LLaMA-2/colossal_llama2/tokenizer/init_tokenizer.py b/applications/Colossal-LLaMA/colossal_llama/tokenizer/init_tokenizer.py similarity index 100% rename from applications/Colossal-LLaMA-2/colossal_llama2/tokenizer/init_tokenizer.py rename to applications/Colossal-LLaMA/colossal_llama/tokenizer/init_tokenizer.py diff --git a/applications/Colossal-LLaMA-2/colossal_llama2/utils/__init__.py b/applications/Colossal-LLaMA/colossal_llama/utils/__init__.py similarity index 100% rename from applications/Colossal-LLaMA-2/colossal_llama2/utils/__init__.py rename to applications/Colossal-LLaMA/colossal_llama/utils/__init__.py diff --git a/applications/Colossal-LLaMA-2/colossal_llama2/utils/ckpt_io.py b/applications/Colossal-LLaMA/colossal_llama/utils/ckpt_io.py similarity index 100% rename from applications/Colossal-LLaMA-2/colossal_llama2/utils/ckpt_io.py rename to applications/Colossal-LLaMA/colossal_llama/utils/ckpt_io.py diff --git a/applications/Colossal-LLaMA-2/colossal_llama2/utils/flash_attention_patch.py b/applications/Colossal-LLaMA/colossal_llama/utils/flash_attention_patch.py similarity index 100% rename from applications/Colossal-LLaMA-2/colossal_llama2/utils/flash_attention_patch.py rename to applications/Colossal-LLaMA/colossal_llama/utils/flash_attention_patch.py diff --git a/applications/Colossal-LLaMA-2/colossal_llama2/utils/froze.py b/applications/Colossal-LLaMA/colossal_llama/utils/froze.py similarity index 100% rename from applications/Colossal-LLaMA-2/colossal_llama2/utils/froze.py rename to applications/Colossal-LLaMA/colossal_llama/utils/froze.py diff --git a/applications/Colossal-LLaMA-2/colossal_llama2/utils/neftune_patch.py b/applications/Colossal-LLaMA/colossal_llama/utils/neftune_patch.py similarity index 100% rename from applications/Colossal-LLaMA-2/colossal_llama2/utils/neftune_patch.py rename to applications/Colossal-LLaMA/colossal_llama/utils/neftune_patch.py diff --git a/applications/Colossal-LLaMA-2/colossal_llama2/utils/stream_chat_patch.py b/applications/Colossal-LLaMA/colossal_llama/utils/stream_chat_patch.py similarity index 100% rename from applications/Colossal-LLaMA-2/colossal_llama2/utils/stream_chat_patch.py rename to applications/Colossal-LLaMA/colossal_llama/utils/stream_chat_patch.py diff --git a/applications/Colossal-LLaMA-2/docs/example_13b.md b/applications/Colossal-LLaMA/docs/example_13b.md similarity index 100% rename from applications/Colossal-LLaMA-2/docs/example_13b.md rename to applications/Colossal-LLaMA/docs/example_13b.md diff --git a/applications/Colossal-LLaMA-2/docs/example_7b.md b/applications/Colossal-LLaMA/docs/example_7b.md similarity index 100% rename from applications/Colossal-LLaMA-2/docs/example_7b.md rename to applications/Colossal-LLaMA/docs/example_7b.md diff --git a/applications/Colossal-LLaMA-2/hostfile.example b/applications/Colossal-LLaMA/hostfile.example similarity index 100% rename from applications/Colossal-LLaMA-2/hostfile.example rename to applications/Colossal-LLaMA/hostfile.example diff --git a/applications/Colossal-LLaMA-2/inference_example.py b/applications/Colossal-LLaMA/inference_example.py similarity index 97% rename from applications/Colossal-LLaMA-2/inference_example.py rename to applications/Colossal-LLaMA/inference_example.py index 8d301616d..0369d9c0a 100644 --- a/applications/Colossal-LLaMA-2/inference_example.py +++ b/applications/Colossal-LLaMA/inference_example.py @@ -1,7 +1,7 @@ import argparse import torch -from colossal_llama2.dataset.conversation import default_conversation +from colossal_llama.dataset.conversation import default_conversation from transformers import AutoModelForCausalLM, AutoTokenizer from colossalai.logging import get_dist_logger diff --git a/applications/Colossal-LLaMA-2/prepare_pretrain_dataset.py b/applications/Colossal-LLaMA/prepare_pretrain_dataset.py similarity index 80% rename from applications/Colossal-LLaMA-2/prepare_pretrain_dataset.py rename to applications/Colossal-LLaMA/prepare_pretrain_dataset.py index cb578b5f6..9642159aa 100644 --- a/applications/Colossal-LLaMA-2/prepare_pretrain_dataset.py +++ b/applications/Colossal-LLaMA/prepare_pretrain_dataset.py @@ -11,12 +11,12 @@ import os import time from multiprocessing import cpu_count -from colossal_llama2.dataset.spliced_and_tokenized_dataset import ( +from colossal_llama.dataset.spliced_and_tokenized_dataset import ( ClosedToConstantLengthSplicedDataset, supervised_tokenize_pretrain, ) from datasets import dataset_dict, load_dataset -from transformers.models.llama.tokenization_llama import LlamaTokenizer +from transformers import AutoTokenizer from colossalai.logging import get_dist_logger @@ -35,35 +35,24 @@ def main(): parser.add_argument( "--tokenizer_dir", type=str, required=True, default=None, help="A directory containing the tokenizer" ) - parser.add_argument("--data_cache_dir", type=str, default="cache", help="Data cache directory") - parser.add_argument( - "--data_jsonl_output_dir", - type=str, - default="jsonl_output", - help="Output directory of spliced dataset with jsonl format", - ) - parser.add_argument( - "--data_arrow_output_dir", - type=str, - default="arrow_output", - help="Output directory of spliced dataset with arrow format", - ) - parser.add_argument("--max_length", type=int, default=4096, help="Max length of each spliced tokenized sequence") + parser.add_argument("--data_output_dirs", type=str, default="data_output_dirs", help="Data output directory") + parser.add_argument("--max_length", type=int, default=8192, help="Max length of each spliced tokenized sequence") parser.add_argument("--num_spliced_dataset_bins", type=int, default=10, help="Number of spliced dataset bins") args = parser.parse_args() if args.num_spliced_dataset_bins >= 100000: raise ValueError("Too many spliced divisions, must be smaller than 100000") - assert not os.path.exists(args.data_cache_dir), f"Find existed data cache dir {args.data_cache_dir}" - assert not os.path.exists( - args.data_jsonl_output_dir - ), f"Find existed jsonl data output dir {args.data_jsonl_output_dir}" - assert not os.path.exists( - args.data_arrow_output_dir - ), f"Find existed arrow data output dir {args.data_arrow_output_dir}" - os.makedirs(args.data_jsonl_output_dir) - os.makedirs(args.data_arrow_output_dir) + args.data_cache_dir = os.path.join(args.data_output_dirs, "cache") + args.data_jsonl_output_dir = os.path.join(args.data_output_dirs, "jsonl") + args.data_arrow_output_dir = os.path.join(args.data_output_dirs, "arrow") + + if not os.path.exists(args.data_cache_dir): + os.makedirs(args.data_cache_dir) + if not os.path.exists(args.data_jsonl_output_dir): + os.makedirs(args.data_jsonl_output_dir) + if not os.path.exists(args.data_arrow_output_dir): + os.makedirs(args.data_arrow_output_dir) # Prepare to all input datasets input_data_paths = [] @@ -86,7 +75,7 @@ def main(): train_splits.append(f"train[{start}%:{end}%]") # Prepare to the tokenizer. - tokenizer = LlamaTokenizer.from_pretrained(args.tokenizer_dir) + tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_dir) tokenizer.add_bos_token = False tokenizer.add_eos_token = False if tokenizer.pad_token is None: diff --git a/applications/Colossal-LLaMA-2/prepare_sft_dataset.py b/applications/Colossal-LLaMA/prepare_sft_dataset.py similarity index 74% rename from applications/Colossal-LLaMA-2/prepare_sft_dataset.py rename to applications/Colossal-LLaMA/prepare_sft_dataset.py index 6d19cbd72..be5f9bcca 100644 --- a/applications/Colossal-LLaMA-2/prepare_sft_dataset.py +++ b/applications/Colossal-LLaMA/prepare_sft_dataset.py @@ -10,10 +10,10 @@ import math import os from multiprocessing import cpu_count -from colossal_llama2.dataset.conversation import default_conversation -from colossal_llama2.dataset.spliced_and_tokenized_dataset import supervised_tokenize_sft +from colossal_llama.dataset.conversation import default_conversation +from colossal_llama.dataset.spliced_and_tokenized_dataset import supervised_tokenize_sft from datasets import dataset_dict, load_dataset -from transformers.models.llama.tokenization_llama import LlamaTokenizer +from transformers import AddedToken, AutoTokenizer from colossalai.logging import get_dist_logger @@ -32,35 +32,25 @@ def main(): parser.add_argument( "--tokenizer_dir", type=str, required=True, default=None, help="A directory containing the tokenizer" ) - parser.add_argument("--data_cache_dir", type=str, default="cache", help="Data cache directory") - parser.add_argument( - "--data_jsonl_output_dir", - type=str, - default="jsonl_output", - help="Output directory of spliced dataset with jsonl format", - ) - parser.add_argument( - "--data_arrow_output_dir", - type=str, - default="arrow_output", - help="Output directory of spliced dataset with arrow format", - ) - parser.add_argument("--max_length", type=int, default=4096, help="Max length of each spliced tokenized sequence") + parser.add_argument("--data_output_dirs", type=str, default="data_output_dirs", help="Data output directory") + parser.add_argument("--max_length", type=int, default=8192, help="Max length of each spliced tokenized sequence") parser.add_argument("--num_spliced_dataset_bins", type=int, default=10, help="Number of spliced dataset bins") + parser.add_argument("--llama_version", type=int, default=3, help="LLaMA version") args = parser.parse_args() if args.num_spliced_dataset_bins >= 100000: raise ValueError("Too many spliced divisions, must be smaller than 100000") - assert not os.path.exists(args.data_cache_dir), f"Find existed data cache dir {args.data_cache_dir}" - assert not os.path.exists( - args.data_jsonl_output_dir - ), f"Find existed jsonl data output dir {args.data_jsonl_output_dir}" - assert not os.path.exists( - args.data_arrow_output_dir - ), f"Find existed arrow data output dir {args.data_arrow_output_dir}" - os.makedirs(args.data_jsonl_output_dir) - os.makedirs(args.data_arrow_output_dir) + args.data_cache_dir = os.path.join(args.data_output_dirs, "cache") + args.data_jsonl_output_dir = os.path.join(args.data_output_dirs, "jsonl") + args.data_arrow_output_dir = os.path.join(args.data_output_dirs, "arrow") + + if not os.path.exists(args.data_cache_dir): + os.makedirs(args.data_cache_dir) + if not os.path.exists(args.data_jsonl_output_dir): + os.makedirs(args.data_jsonl_output_dir) + if not os.path.exists(args.data_arrow_output_dir): + os.makedirs(args.data_arrow_output_dir) # Prepare to all input datasets input_data_paths = [] @@ -83,11 +73,20 @@ def main(): train_splits.append(f"train[{start}%:{end}%]") # Prepare to the tokenizer. - tokenizer = LlamaTokenizer.from_pretrained(args.tokenizer_dir) + tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_dir) + + # Fix split issue: https://github.com/huggingface/transformers/issues/23833 + if args.llama_version == 2: + tokenizer.add_tokens(AddedToken("", normalized=False, special=True), special_tokens=True) + tokenizer.add_bos_token = False tokenizer.add_eos_token = False if tokenizer.pad_token is None: - tokenizer.pad_token = tokenizer.unk_token + if tokenizer.unk_token is not None: + tokenizer.pad_token = tokenizer.unk_token + else: + tokenizer.pad_token = tokenizer.eos_token + tokenizer.unk_token = tokenizer.eos_token list_dataset = load_dataset( path="json", diff --git a/applications/Colossal-LLaMA-2/requirements.txt b/applications/Colossal-LLaMA/requirements.txt similarity index 65% rename from applications/Colossal-LLaMA-2/requirements.txt rename to applications/Colossal-LLaMA/requirements.txt index 5cdb8e7f3..809a942ac 100644 --- a/applications/Colossal-LLaMA-2/requirements.txt +++ b/applications/Colossal-LLaMA/requirements.txt @@ -1,9 +1,10 @@ -torch<2.0.0, >=1.12.1 -packaging==23.1 -colossalai==0.3.5 +torch==2.1.2 +huggingface-hub +packaging==24.0 +colossalai==0.3.6 autoflake==2.2.1 black==23.9.1 -transformers==4.33.3 +transformers==4.34.1 tensorboard==2.14.0 six==1.16.0 datasets diff --git a/applications/Colossal-LLaMA-2/stream_chat_example.py b/applications/Colossal-LLaMA/stream_chat_example.py similarity index 97% rename from applications/Colossal-LLaMA-2/stream_chat_example.py rename to applications/Colossal-LLaMA/stream_chat_example.py index 4c0d1fe2a..9a353b473 100644 --- a/applications/Colossal-LLaMA-2/stream_chat_example.py +++ b/applications/Colossal-LLaMA/stream_chat_example.py @@ -1,6 +1,6 @@ import argparse -from colossal_llama2.utils.stream_chat_patch import streaming_chat +from colossal_llama.utils.stream_chat_patch import streaming_chat from transformers import AutoModelForCausalLM, AutoTokenizer SYSTEM = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions." diff --git a/applications/Colossal-LLaMA-2/train.example.sh b/applications/Colossal-LLaMA/train.example.sh similarity index 100% rename from applications/Colossal-LLaMA-2/train.example.sh rename to applications/Colossal-LLaMA/train.example.sh diff --git a/applications/Colossal-LLaMA-2/train.py b/applications/Colossal-LLaMA/train.py similarity index 96% rename from applications/Colossal-LLaMA-2/train.py rename to applications/Colossal-LLaMA/train.py index d97da61e4..dcd7be9f4 100644 --- a/applications/Colossal-LLaMA-2/train.py +++ b/applications/Colossal-LLaMA/train.py @@ -12,18 +12,18 @@ from contextlib import nullcontext import torch import torch.distributed as dist -from colossal_llama2.dataset.loader import ( +from colossal_llama.dataset.loader import ( DataCollatorForSupervisedDataset, StatefulDistributedSampler, load_tokenized_dataset, ) -from colossal_llama2.utils.ckpt_io import load_checkpoint, save_checkpoint -from colossal_llama2.utils.flash_attention_patch import replace_with_flash_attention -from colossal_llama2.utils.froze import freeze_non_embeds_parameters -from colossal_llama2.utils.neftune_patch import activate_neftune, deactivate_neftune +from colossal_llama.utils.ckpt_io import load_checkpoint, save_checkpoint +from colossal_llama.utils.flash_attention_patch import replace_with_flash_attention +from colossal_llama.utils.froze import freeze_non_embeds_parameters +from colossal_llama.utils.neftune_patch import activate_neftune, deactivate_neftune from torch.utils.tensorboard import SummaryWriter from tqdm import tqdm -from transformers import LlamaForCausalLM, LlamaTokenizer +from transformers import AutoTokenizer, LlamaForCausalLM import colossalai from colossalai.accelerator import get_accelerator @@ -89,7 +89,7 @@ def main() -> None: parser.add_argument("--accumulation_steps", type=int, default=1, help="Number of accumulation steps") parser.add_argument("--micro_batch_size", type=int, default=2, help="Batch size of each process") parser.add_argument("--lr", type=float, default=3e-4, help="Learning rate") - parser.add_argument("--max_length", type=int, default=4096, help="Model max length") + parser.add_argument("--max_length", type=int, default=8192, help="Model max length") parser.add_argument( "--mixed_precision", type=str, @@ -196,7 +196,7 @@ def main() -> None: # ====================================================== # Initialize Tokenizer, Dataset, Collator and Dataloader # ====================================================== - tokenizer = LlamaTokenizer.from_pretrained(args.pretrained) + tokenizer = AutoTokenizer.from_pretrained(args.pretrained) if args.pad_token == "eos": tokenizer.pad_token = tokenizer.eos_token elif args.pad_token == "unk": diff --git a/applications/Colossal-LLaMA-2/train_sft.example.sh b/applications/Colossal-LLaMA/train_sft.example.sh similarity index 100% rename from applications/Colossal-LLaMA-2/train_sft.example.sh rename to applications/Colossal-LLaMA/train_sft.example.sh diff --git a/applications/Colossal-LLaMA/version.txt b/applications/Colossal-LLaMA/version.txt new file mode 100644 index 000000000..3eefcb9dd --- /dev/null +++ b/applications/Colossal-LLaMA/version.txt @@ -0,0 +1 @@ +1.0.0 diff --git a/applications/README.md b/applications/README.md index 120767d5c..e7c23c7e9 100644 --- a/applications/README.md +++ b/applications/README.md @@ -5,7 +5,7 @@ This directory contains the applications that are powered by Colossal-AI. The list of applications include: - [X] [Open-Sora](https://github.com/hpcaitech/Open-Sora): Revealing Complete Model Parameters, Training Details, and Everything for Sora-like Video Generation Models -- [X] [Colossal-LLaMA-2](./Colossal-LLaMA-2/): Continual Pre-training of LLaMA-2. +- [X] [Colossal-LLaMA](./Colossal-LLaMA/): Continual Pre-training and Supervisied Fine-tuning of LLaMA2 / LLaMA3. - [X] [ColossalEval](./ColossalEval): Evaluation Pipeline for LLMs. - [X] [ColossalChat](./Chat/README.md): Replication of ChatGPT with RLHF. - [X] [FastFold](https://github.com/hpcaitech/FastFold): Optimizing AlphaFold (Biomedicine) Training and Inference on GPU Clusters. From 4de4e318185513cab089d2fc28ee5798a5099fe0 Mon Sep 17 00:00:00 2001 From: Hongxin Liu Date: Tue, 23 Apr 2024 14:12:20 +0800 Subject: [PATCH 07/28] [exampe] update llama example (#5626) * [plugin] support dp inside for hybriad parallel * [example] update llama benchmark * [example] update llama benchmark * [example] update llama readme * [example] update llama readme --- colossalai/booster/plugin/gemini_plugin.py | 1 + .../booster/plugin/hybrid_parallel_plugin.py | 28 +- examples/language/llama2/README.md | 117 +------ examples/language/llama2/attn.py | 1 - examples/language/llama2/benchmark.py | 62 +++- examples/language/llama2/finetune.py | 313 ----------------- examples/language/llama2/pretrain.py | 328 ------------------ examples/language/llama2/requirements.txt | 5 +- 8 files changed, 72 insertions(+), 783 deletions(-) delete mode 120000 examples/language/llama2/attn.py delete mode 100644 examples/language/llama2/finetune.py delete mode 100644 examples/language/llama2/pretrain.py diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py index 442ac4a8d..a67ca18a3 100644 --- a/colossalai/booster/plugin/gemini_plugin.py +++ b/colossalai/booster/plugin/gemini_plugin.py @@ -424,6 +424,7 @@ class GeminiPlugin(DPPluginBase): ) self.extra_dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS) if self.extra_dp_size > 1 else None self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS) if self.tp_size > 1 else None + self.dp_size = self.zero_size * self.extra_dp_size self.shard_config = ShardConfig( tensor_parallel_process_group=self.tp_group, diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index 8d12eb806..95fb2def1 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -34,7 +34,6 @@ from colossalai.zero.low_level import LowLevelZeroOptimizer from .pp_plugin_base import PipelinePluginBase -DP_AXIS, PP_AXIS, TP_AXIS, SP_AXIS = 0, 1, 2, 3 SUPPORT_SP_MODE = ["split_gather", "ring", "all_to_all"] PRECISION_TORCH_TYPE = {"fp16": torch.float16, "fp32": torch.float32, "bf16": torch.bfloat16} @@ -987,6 +986,7 @@ class HybridParallelPlugin(PipelinePluginBase): gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None, enable_metadata_cache: bool = True, make_vocab_size_divisible_by: int = 64, + dp_outside: bool = True, ) -> None: super().__init__() assert ( @@ -1034,7 +1034,12 @@ class HybridParallelPlugin(PipelinePluginBase): self.enable_flash_attention = enable_flash_attention self.enable_jit_fused = enable_jit_fused self.enable_sequence_parallelism = enable_sequence_parallelism - self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size, self.sp_size) + if dp_outside: + self.dp_axis, self.pp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3 + self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size, self.sp_size) + else: + self.pp_axis, self.dp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3 + self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.tp_size, self.sp_size) self.stage_manager = None self.schedule = None self.custom_policy = custom_policy @@ -1048,7 +1053,7 @@ class HybridParallelPlugin(PipelinePluginBase): assert self.zero_stage <= 1, "zero stage must be 0 or 1 when using pipeline parallelism" self.stage_manager = PipelineStageManager( self.pg_mesh, - pipeline_axis=PP_AXIS, + pipeline_axis=self.pp_axis, enable_interleave=pp_style == "interleaved", num_model_chunks=num_model_chunks, ) @@ -1072,13 +1077,13 @@ class HybridParallelPlugin(PipelinePluginBase): else: raise NotImplementedError() - self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS) - self.dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS) - self.pp_group = self.pg_mesh.get_group_along_axis(PP_AXIS) + self.tp_group = self.pg_mesh.get_group_along_axis(self.tp_axis) + self.dp_group = self.pg_mesh.get_group_along_axis(self.dp_axis) + self.pp_group = self.pg_mesh.get_group_along_axis(self.pp_axis) if self.enable_sequence_parallelism and self.sequence_parallelism_mode in ["split_gather", "ring"]: - self.sp_group = self.pg_mesh.get_group_along_axis(TP_AXIS) + self.sp_group = self.pg_mesh.get_group_along_axis(self.tp_axis) else: - self.sp_group = self.pg_mesh.get_group_along_axis(SP_AXIS) + self.sp_group = self.pg_mesh.get_group_along_axis(self.sp_axis) self.shard_config = ShardConfig( tensor_parallel_process_group=self.tp_group, @@ -1169,7 +1174,7 @@ class HybridParallelPlugin(PipelinePluginBase): and self.sequence_parallelism_mode == "all_to_all" ) if self.enable_sequence_parallelism and self.sequence_parallelism_mode == "all_to_all": - dp_group = self.pg_mesh.create_group_along_axis([DP_AXIS, SP_AXIS]) + dp_group = self.pg_mesh.create_group_along_axis([self.dp_axis, self.sp_axis]) else: dp_group = self.dp_group model = HybridParallelModule( @@ -1317,7 +1322,10 @@ class HybridParallelPlugin(PipelinePluginBase): _kwargs = kwargs.copy() distributed_sampler_cls = distributed_sampler_cls or DistributedSampler sampler = distributed_sampler_cls( - dataset, num_replicas=self.pg_mesh.size(DP_AXIS), rank=self.pg_mesh.coordinate(DP_AXIS), shuffle=shuffle + dataset, + num_replicas=self.pg_mesh.size(self.dp_axis), + rank=self.pg_mesh.coordinate(self.dp_axis), + shuffle=shuffle, ) # Deterministic dataloader diff --git a/examples/language/llama2/README.md b/examples/language/llama2/README.md index 068f15cbb..11b2ee511 100644 --- a/examples/language/llama2/README.md +++ b/examples/language/llama2/README.md @@ -1,4 +1,4 @@ -# Pretraining LLaMA-1/2: best practices for building LLaMA-1/2-like base models +# Pretraining LLaMA-1/2/3: best practices for building LLaMA-1/2/3-like base models ### LLaMA2

@@ -16,38 +16,10 @@ - 65-billion-parameter large model pretraining accelerated by 38% [[blog]](https://www.hpc-ai.tech/blog/large-model-pretraining) -## Dataset - -Different from the original LLaMA, we use [RedPajama](https://www.together.xyz/blog/redpajama) dataset, which is a reproduction of the LLaMA training dataset containing over 1.2 trillion tokens. The full dataset is ~5TB unzipped on disk and ~3TB to download compressed. - -A smaller, more consumable random sample can be downloaded through [Hugging Face](https://huggingface.co/datasets/togethercomputer/RedPajama-Data-1T). If you just want to try out the pretraining script, you can use a 1B-token sample subset of RedPajama, which is available at [Hugging Face](https://huggingface.co/datasets/togethercomputer/RedPajama-Data-1T-Sample). - -RedPajama-Data-1T consists of seven data slices: - -| | RedPajama | LLaMA | -|---------------|--------------|---------------| -| CommonCrawl | 878 billion | 852 billion | -| C4 | 175 billion | 190 billion | -| Github | 59 billion | 100 billion | -| Books | 26 billion | 25 billion | -| ArXiv | 28 billion | 33 billion | -| Wikipedia | 24 billion | 25 billion | -| StackExchange | 20 billion | 27 billion | -| Total | 1.2 trillion | 1.25 trillion | - -## Training - -We follow the hyperparameter settings from the original LLaMA paper. We use AdamW with $beta1=0.9$ and $beta2=0.95$. We use a cosine learning rate schedule, such that the final learning rate is equal to 10% of the maximal learning rate. We use a weight decay of 0.1 and gradient clipping of 1.0. We use 2,000 warmup steps. - -| params | learning rate | batch size | -|--------|---------------|------------| -| 6.7B | 3.0e-4 | 4M | -| 13.0B | 3.0e-4 | 4M | -| 32.5B | 1.5e-4 | 4M | -| 65.2B | 1.5e-4 | 4M | - ## Usage +> ⚠ This example only has benchmarking script. For training/finetuning, please refer to the [applications/Colossal-LLaMA](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Colossal-LLaMA). + ### 1. Installation Please install the latest ColossalAI from source. @@ -62,52 +34,6 @@ Then install other dependencies. pip install -r requirements.txt ``` -Additionally, we recommend you to use torch 1.13.1. We've tested our code on torch 1.13.1 and found it's compatible with our code and flash attention. - -### 2. Download the dataset - -The dataset can be automatically downloaded by using `huggingface/datasets`. You can specify the dataset path by `-d` or `--dataset`. The default dataset is `togethercomputer/RedPajama-Data-1T-Sample`. - -### 3. Command line arguments - -Yon can use colossalai run to launch multi-nodes training: -```bash -colossalai run --nproc_per_node YOUR_GPU_PER_NODE --hostfile YOUR_HOST_FILE \ -pretrain.py --OTHER_CONFIGURATIONS -``` - -Here is a sample hostfile: - -```text -hostname1 -hostname2 -hostname3 -hostname4 -``` - -Make sure master node can access all nodes (including itself) by ssh without password. - -Here is details about CLI arguments: - -- Model configuration: `-c`, `--config`. `7b`, `13b`, `30b` and `65b` are supported for LLaMA-1, `7b`, `13b`, and `70b` are supported for LLaMA-2. -- Booster plugin: `-p`, `--plugin`. `gemini`, `gemini_auto`, `zero2`, `hybrid_parallel` and `zero2_cpu` are supported. For more details, please refer to [Booster plugins](https://colossalai.org/docs/basics/booster_plugins). -- Dataset path: `-d`, `--dataset`. The default dataset is `togethercomputer/RedPajama-Data-1T-Sample`. It support any dataset from `datasets` with the same data format as RedPajama. -- Number of epochs: `-e`, `--num_epochs`. The default value is 1. -- Local batch size: `-b`, `--batch_size`. Batch size per GPU. The default value is 2. -- Learning rate: `--lr`. The default value is 3e-4. -- Weight decay: `-w`, `--weight_decay`. The default value is 0.1. -- Warmup steps: `-s`, `--warmup_steps`. The default value is 2000. -- Gradient checkpointing: `-g`, `--gradient_checkpoint`. The default value is `False`. This saves memory at the cost of speed. You'd better enable this option when training with a large batch size. -- Max length: `-l`, `--max_length`. The default value is 4096. -- Mixed precision: `-x`, `--mixed_precision`. The default value is "fp16". "fp16" and "bf16" are supported. -- Save interval: `-i`, `--save_interval`. The interval (steps) of saving checkpoints. The default value is 1000. -- Checkpoint directory: `-o`, `--save_dir`. The directory path to save checkpoints. The default value is `checkpoint`. -- Checkpoint to load: `-f`, `--load`. The checkpoint path to load. The default value is `None`. -- Gradient clipping: `--gradient_clipping`. The default value is 1.0. -- Tensorboard log directory: `-t`, `--tensorboard_dir`. The directory path to save tensorboard logs. The default value is `tb_logs`. -- Flash attention: `-a`, `--flash_attention`. If you want to use flash attention, you must install `flash-attn`. The default value is `False`. This is helpful to accelerate training while saving memory. We recommend you always use flash attention. - - ### 4. Shell Script Examples For your convenience, we provide some shell scripts to run benchmark with various configurations. @@ -193,40 +119,3 @@ If you run the above command successfully, you will get the following results: year={2023} } ``` - - -# Fine-tune Llama2 - -We also provide a example to fine-tune llama2 in `finetune.py`, - -Make sure master node can access all nodes (including itself) by ssh without password. - -Here is details about CLI arguments: - -- Pretrained checkpoint path: `--model_path`, the path of your model checkpoint, it can be your local directory or a Hugging Face tag. -- Booster plugin: `-p`, `--plugin`. `gemini`, `gemini_auto`, `zero2`, `hybrid_parallel` and `zero2_cpu` are supported. For more details, please refer to [Booster plugins](https://colossalai.org/docs/basics/booster_plugins). -- Dataset path: `-d`, `--dataset`. The default dataset is `yizhongw/self_instruct`. It support any dataset from `datasets` with the same data format as `yizhongw/self_instruct`. -- task name: `--task_name`, the task to fine-tune, it's also related to the target of loading dataset, The default value is `super_natural_instructions`. -- Number of epochs: `-e`, `--num_epochs`. The default value is 1. -- Local batch size: `-b`, `--batch_size`. Batch size per GPU. The default value is 2. -- Learning rate: `--lr`. The default value is 3e-4. -- Weight decay: `-w`, `--weight_decay`. The default value is 0.1. -- Gradient checkpointing: `-g`, `--gradient_checkpoint`. The default value is `False`. This saves memory at the cost of speed. You'd better enable this option when training with a large batch size. -- Max length: `-l`, `--max_length`. The default value is 4096. -- Mixed precision: `-x`, `--mixed_precision`. The default value is "fp16". "fp16" and "bf16" are supported. -- Save interval: `-i`, `--save_interval`. The interval (steps) of saving checkpoints. The default value is 1000. -- Checkpoint directory: `-o`, `--save_dir`. The directory path to save checkpoints. The default value is `checkpoint`. -- Checkpoint to load: `-f`, `--load`. The checkpoint path to load. The default value is `None`. -- Gradient clipping: `--gradient_clipping`. The default value is 1.0. -- Tensorboard log directory: `-t`, `--tensorboard_dir`. The directory path to save tensorboard logs. The default value is `tb_logs`. -- Flash attention: `-a`, `--flash_attention`. If you want to use flash attention, you must install `flash-attn`. The default value is `False`. This is helpful to accelerate training while saving memory. We recommend you always use flash attention. - - -```shell -torchrun --standalone --nproc_per_node 8 finetune.py \ - --plugin "hybrid_parallel" \ - --dataset "yizhongw/self_instruct" \ - --model_path "/path/llama" \ - --task_name "super_natural_instructions" \ - --save_dir "/path/output" -``` diff --git a/examples/language/llama2/attn.py b/examples/language/llama2/attn.py deleted file mode 120000 index 4e95c7bfa..000000000 --- a/examples/language/llama2/attn.py +++ /dev/null @@ -1 +0,0 @@ -../../../applications/Colossal-LLaMA-2/colossal_llama2/utils/flash_attention_patch.py \ No newline at end of file diff --git a/examples/language/llama2/benchmark.py b/examples/language/llama2/benchmark.py index 832465490..ff94891f5 100644 --- a/examples/language/llama2/benchmark.py +++ b/examples/language/llama2/benchmark.py @@ -3,14 +3,13 @@ import resource from contextlib import nullcontext import torch -from attn import replace_with_flash_attention from data_utils import RandomDataset from model_utils import format_numel_str, get_model_numel from performance_evaluator import PerformanceEvaluator from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, MixedPrecision from tqdm import tqdm +from transformers import AutoConfig, AutoModelForCausalLM from transformers.models.llama.configuration_llama import LlamaConfig -from transformers.models.llama.modeling_llama import LlamaForCausalLM import colossalai from colossalai.accelerator import get_accelerator @@ -19,6 +18,7 @@ from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, TorchF from colossalai.cluster import DistCoordinator from colossalai.lazy import LazyInitContext from colossalai.nn.optimizer import HybridAdam +from colossalai.shardformer import PipelineGradientCheckpointConfig from examples.language.data_utils import RandomDataset from examples.language.model_utils import format_numel_str, get_model_numel from examples.language.performance_evaluator import PerformanceEvaluator @@ -78,6 +78,7 @@ def main(): parser.add_argument("--pp", type=int, default=1, help="Pipeline parallel size") parser.add_argument("--mbs", type=int, default=1, help="Micro batch size of pipeline parallel") parser.add_argument("--zero", type=int, default=0, help="Zero Stage when hybrid plugin is enabled") + parser.add_argument("--custom-ckpt", action="store_true", help="Customize checkpoint", default=False) args = parser.parse_args() colossalai.launch_from_torch({}) @@ -86,6 +87,19 @@ def main(): def empty_init(): pass + # ckpt config for LLaMA3-70B on 64 H100 GPUs + ckpt_config = ( + PipelineGradientCheckpointConfig( + num_stages=args.pp, + num_model_chunks=1, + num_model_layers=80, + num_layers_per_stage=[19, 20, 20, 21], + num_ckpt_layers_per_stage=[19, 19, 19, 13], + ) + if args.custom_ckpt + else None + ) + # ============================== # Initialize Booster # ============================== @@ -98,6 +112,8 @@ def main(): offload_param_frac=args.offload_param_frac, tp_size=args.tp, extra_dp_size=args.extra_dp, + enable_fused_normalization=torch.cuda.is_available(), + enable_flash_attention=args.xformers, ) elif args.plugin == "gemini_auto": plugin = GeminiPlugin( @@ -106,26 +122,34 @@ def main(): warmup_non_model_data_ratio=args.warmup_ratio, tp_size=args.tp, extra_dp_size=args.extra_dp, + enable_fused_normalization=torch.cuda.is_available(), + enable_flash_attention=args.xformers, ) elif args.plugin == "fsdp": if use_empty_init: plugin = TorchFSDPPlugin( mixed_precision=MixedPrecision( - param_dtype=torch.float16, reduce_dtype=torch.float16, buffer_dtype=torch.float16 + param_dtype=torch.float16, + reduce_dtype=torch.float16, + buffer_dtype=torch.float16, ), param_init_fn=empty_init(), ) else: plugin = TorchFSDPPlugin( mixed_precision=MixedPrecision( - param_dtype=torch.float16, reduce_dtype=torch.float16, buffer_dtype=torch.float16 + param_dtype=torch.float16, + reduce_dtype=torch.float16, + buffer_dtype=torch.float16, ) ) elif args.plugin == "fsdp_cpu": if use_empty_init: plugin = TorchFSDPPlugin( mixed_precision=MixedPrecision( - param_dtype=torch.float16, reduce_dtype=torch.float16, buffer_dtype=torch.float16 + param_dtype=torch.float16, + reduce_dtype=torch.float16, + buffer_dtype=torch.float16, ), cpu_offload=CPUOffload(offload_params=True), param_init_fn=empty_init(), @@ -133,7 +157,9 @@ def main(): else: plugin = TorchFSDPPlugin( mixed_precision=MixedPrecision( - param_dtype=torch.float16, reduce_dtype=torch.float16, buffer_dtype=torch.float16 + param_dtype=torch.float16, + reduce_dtype=torch.float16, + buffer_dtype=torch.float16, ), cpu_offload=CPUOffload(offload_params=True), ) @@ -141,12 +167,13 @@ def main(): plugin = HybridParallelPlugin( tp_size=args.tp, pp_size=args.pp, - pp_style="interleaved", zero_stage=args.zero, - num_model_chunks=2, enable_fused_normalization=torch.cuda.is_available(), + enable_flash_attention=args.xformers, microbatch_size=args.mbs, precision="bf16", + dp_outside=False, + gradient_checkpoint_config=ckpt_config, ) elif args.plugin == "3d_cpu": plugin = HybridParallelPlugin( @@ -155,6 +182,7 @@ def main(): zero_stage=args.zero, cpu_offload=True, enable_fused_normalization=torch.cuda.is_available(), + enable_flash_attention=args.xformers, microbatch_size=args.mbs, initial_scale=2**8, precision="bf16", @@ -167,9 +195,12 @@ def main(): # ============================== # Initialize Dataset and Dataloader # ============================== - dp_size = plugin.dp_size if isinstance(plugin, HybridParallelPlugin) else coordinator.world_size + dp_size = getattr(plugin, "dp_size", coordinator.world_size) - config = MODEL_CONFIGS[args.config] + if args.config in MODEL_CONFIGS: + config = MODEL_CONFIGS[args.config] + else: + config = AutoConfig.from_pretrained(args.config, trust_remote_code=True) dataset = RandomDataset( num_samples=args.batch_size * args.num_steps * dp_size, max_length=args.max_length, vocab_size=config.vocab_size ) @@ -184,14 +215,17 @@ def main(): else nullcontext() ) + init_kwargs = {} + if config.model_type == "chatglm": + init_kwargs["empty_init"] = False + with init_ctx: - model = LlamaForCausalLM(config) + model = AutoModelForCausalLM.from_config(config, trust_remote_code=True, **init_kwargs) if args.grad_checkpoint: model.gradient_checkpointing_enable() - - if args.xformers: - replace_with_flash_attention(model) + if config.model_type == "chatglm": + model.transformer.encoder.gradient_checkpointing = True model_numel = get_model_numel(model) coordinator.print_on_master(f"Model params: {format_numel_str(model_numel)}") diff --git a/examples/language/llama2/finetune.py b/examples/language/llama2/finetune.py deleted file mode 100644 index 69b4ebe42..000000000 --- a/examples/language/llama2/finetune.py +++ /dev/null @@ -1,313 +0,0 @@ -import argparse -import math -import os -import resource -from contextlib import nullcontext -from functools import partial -from typing import Optional, Tuple - -import torch -import torch.distributed as dist -import torch.nn as nn -from attn import replace_with_flash_attention -from data_utils import load_json, prepare_dataloader, save_json -from datasets import load_dataset -from torch.optim import Optimizer -from torch.optim.lr_scheduler import _LRScheduler -from torch.utils.tensorboard import SummaryWriter -from tqdm import tqdm -from transformers.models.llama.configuration_llama import LlamaConfig -from transformers.models.llama.modeling_llama import LlamaForCausalLM -from transformers.models.llama.tokenization_llama import LlamaTokenizer - -import colossalai -from colossalai.accelerator import get_accelerator -from colossalai.booster import Booster -from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin -from colossalai.cluster import DistCoordinator -from colossalai.lazy import LazyInitContext -from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR -from colossalai.nn.optimizer import HybridAdam - - -def get_model_numel(model: nn.Module) -> int: - return sum(p.numel() for p in model.parameters()) - - -def format_numel_str(numel: int) -> str: - B = 1024**3 - M = 1024**2 - K = 1024 - if numel >= B: - return f"{numel / B:.2f} B" - elif numel >= M: - return f"{numel / M:.2f} M" - elif numel >= K: - return f"{numel / K:.2f} K" - else: - return f"{numel}" - - -def tokenize_batch_for_finetune(batch, tokenizer: Optional[LlamaTokenizer] = None, max_length: int = 2048): - texts = [sample["prompt"] + sample["completion"] for sample in batch] - data = tokenizer(texts, return_tensors="pt", padding="max_length", truncation=True, max_length=max_length) - data = {k: v.cuda() for k, v in data.items()} - data["labels"] = data["input_ids"].clone() - return data - - -def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor: - dist.all_reduce(tensor, op=dist.ReduceOp.SUM) - tensor = tensor.data - tensor.div_(dist.get_world_size()) - return tensor - - -def save( - booster: Booster, - model: nn.Module, - optimizer: Optimizer, - lr_scheduler: _LRScheduler, - epoch: int, - step: int, - batch_size: int, - coordinator: DistCoordinator, - save_dir: str, -): - save_dir = os.path.join(save_dir, f"epoch{epoch}-step{step}") - os.makedirs(os.path.join(save_dir, "model"), exist_ok=True) - - booster.save_model(model, os.path.join(save_dir, "model"), shard=True) - booster.save_optimizer(optimizer, os.path.join(save_dir, "optimizer"), shard=True) - booster.save_lr_scheduler(lr_scheduler, os.path.join(save_dir, "lr_scheduler")) - running_states = { - "epoch": epoch, - "step": step, - "sample_start_index": step * batch_size, - } - if coordinator.is_master(): - save_json(running_states, os.path.join(save_dir, "running_states.json")) - - -def load( - booster: Booster, model: nn.Module, optimizer: Optimizer, lr_scheduler: _LRScheduler, load_dir: str -) -> Tuple[int, int, int]: - booster.load_model(model, os.path.join(load_dir, "model")) - booster.load_optimizer(optimizer, os.path.join(load_dir, "optimizer")) - booster.load_lr_scheduler(lr_scheduler, os.path.join(load_dir, "lr_scheduler")) - running_states = load_json(os.path.join(load_dir, "running_states.json")) - return running_states["epoch"], running_states["step"], running_states["sample_start_index"] - - -def _criterion(outputs, inputs): - return outputs.loss - - -def main(): - # ============================== - # Parse Arguments - # ============================== - parser = argparse.ArgumentParser() - parser.add_argument("--model_path", type=str, help="pretrained checkpoint path, used with mode==finetune") - parser.add_argument( - "-p", - "--plugin", - choices=["gemini", "gemini_auto", "zero2", "zero2_cpu", "hybrid_parallel"], - default="gemini", - help="Choose which plugin to use", - ) - parser.add_argument("-d", "--dataset", type=str, default="yizhongw/self_instruct", help="Data set path") - parser.add_argument("--task_name", type=str, default="super_natural_instructions", help="task to run") - parser.add_argument("-e", "--num_epochs", type=int, default=1, help="Number of epochs") - parser.add_argument("-b", "--batch_size", type=int, default=2, help="Local batch size") - parser.add_argument("--lr", type=float, default=3e-4, help="Learning rate") - parser.add_argument("-w", "--weigth_decay", type=float, default=0.1, help="Weight decay") - parser.add_argument("-g", "--grad_checkpoint", action="store_true", help="Use gradient checkpointing") - parser.add_argument("-l", "--max_length", type=int, default=4096, help="Max sequence length") - parser.add_argument("-x", "--mixed_precision", default="fp16", choices=["fp16", "bf16"], help="Mixed precision") - parser.add_argument("-i", "--save_interval", type=int, default=1000, help="Save interval") - parser.add_argument("-o", "--save_dir", type=str, default="checkpoint", help="Checkpoint directory") - parser.add_argument("-f", "--load", type=str, default=None, help="Load checkpoint") - parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping") - parser.add_argument("-t", "--tensorboard_dir", type=str, default="tb_logs", help="Tensorboard directory") - parser.add_argument("-a", "--flash_attention", action="store_true", help="Use Flash Attention") - args = parser.parse_args() - - # ============================== - # Initialize Distributed Training - # ============================== - colossalai.launch_from_torch({}) - coordinator = DistCoordinator() - - # ============================== - # Initialize Booster - # ============================== - if args.plugin == "gemini": - plugin = GeminiPlugin(precision=args.mixed_precision, initial_scale=2**16, max_norm=args.grad_clip) - elif args.plugin == "gemini_auto": - plugin = GeminiPlugin( - precision=args.mixed_precision, placement_policy="auto", initial_scale=2**16, max_norm=args.grad_clip - ) - elif args.plugin == "zero2": - plugin = LowLevelZeroPlugin( - stage=2, precision=args.mixed_precision, initial_scale=2**16, max_norm=args.grad_clip - ) - elif args.plugin == "zero2_cpu": - plugin = LowLevelZeroPlugin( - stage=2, precision=args.mixed_precision, initial_scale=2**16, cpu_offload=True, max_norm=args.grad_clip - ) - elif args.plugin == "hybrid_parallel": - # modify the param accordingly, default configuration is for llama2-7b - plugin = HybridParallelPlugin( - tp_size=4, - pp_size=2, - num_microbatches=None, - microbatch_size=1, - enable_jit_fused=False, - zero_stage=0, - precision="fp32", - initial_scale=1, - ) - else: - raise ValueError(f"Unknown plugin {args.plugin}") - - booster = Booster(plugin=plugin) - - use_pipeline = isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1 - is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage() - print_flag = (not use_pipeline and coordinator.is_master()) or (use_pipeline and is_pp_last_stage) - - # ============================== - # Initialize Tensorboard - # ============================== - if print_flag: - os.makedirs(args.tensorboard_dir, exist_ok=True) - writer = SummaryWriter(args.tensorboard_dir) - - # ============================== - # Initialize Model, Optimizer and LR Scheduler - # ============================== - - config = LlamaConfig.from_pretrained(args.model_path) - # use lazy init when using GeminiPlugin - init_ctx = ( - LazyInitContext(default_device=get_accelerator().get_current_device()) - if isinstance(plugin, GeminiPlugin) - else nullcontext() - ) - - with init_ctx: - model = LlamaForCausalLM(config) - - # ============================== - # Initialize Tokenizer, Dataset and Dataloader - # ============================== - tokenizer = LlamaTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer") - # follows fast chat: https://github.com/lm-sys/FastChat/blob/main/fastchat/train/train.py#L257 - tokenizer.pad_token = tokenizer.unk_token - - dataset = load_dataset(args.dataset, args.task_name) - train_ds = dataset["train"] - dataloader = prepare_dataloader( - train_ds, - batch_size=args.batch_size, - shuffle=True, - drop_last=True, - collate_fn=partial(tokenize_batch_for_finetune, tokenizer=tokenizer, max_length=args.max_length), - ) - - if args.grad_checkpoint: - model.gradient_checkpointing_enable() - if args.flash_attention: - replace_with_flash_attention(model) - - model_numel = get_model_numel(model) - coordinator.print_on_master(f"Model params: {format_numel_str(model_numel)}") - - optimizer = HybridAdam(model.parameters(), lr=args.lr, betas=(0.9, 0.95), weight_decay=args.weigth_decay) - total_step = args.num_epochs * len(dataloader) - lr_scheduler = CosineAnnealingWarmupLR( - optimizer, total_steps=total_step, warmup_steps=math.ceil(total_step * 0.03), eta_min=0.1 * args.lr - ) - default_dtype = torch.float16 if args.mixed_precision == "fp16" else torch.bfloat16 - torch.set_default_dtype(default_dtype) - model, optimizer, _, dataloader, lr_scheduler = booster.boost( - model, optimizer, dataloader=dataloader, lr_scheduler=lr_scheduler - ) - torch.set_default_dtype(torch.float) - - booster.load_model(model, args.model_path) - - coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB") - coordinator.print_on_master( - f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss/1024:.2f} MB" - ) - - # load checkpoint if specified - start_epoch = 0 - start_step = 0 - sampler_start_idx = 0 - if args.load is not None: - coordinator.print_on_master("Loading checkpoint") - start_epoch, start_step, sampler_start_idx = load(booster, model, optimizer, lr_scheduler, args.load) - coordinator.print_on_master(f"Loaded checkpoint {args.load} at epoch {start_epoch} step {start_step}") - - num_steps_per_epoch = len(dataloader) - - # if resume training, set the sampler start index to the correct value - dataloader.sampler.set_start_index(sampler_start_idx) - for epoch in range(start_epoch, args.num_epochs): - dataloader.sampler.set_epoch(epoch) - step_nums = num_steps_per_epoch - start_step - dataloader_iter = iter(dataloader) - - with tqdm( - range(step_nums), - desc=f"Epoch {epoch}", - disable=not print_flag, - total=num_steps_per_epoch, - initial=start_step, - ) as pbar: - for step in pbar: - if use_pipeline: - outputs = booster.execute_pipeline(dataloader_iter, model, _criterion, optimizer, return_loss=True) - loss = outputs["loss"] - else: - batch = next(dataloader_iter) - outputs = model(**batch) - loss = outputs[0] - booster.backward(loss, optimizer) - - optimizer.step() - lr_scheduler.step() - optimizer.zero_grad() - - if not use_pipeline: - all_reduce_mean(loss) - if print_flag: - pbar.set_postfix({"loss": loss.item()}) - writer.add_scalar("loss", loss.item(), epoch * num_steps_per_epoch + step) - - if args.save_interval > 0 and (step + 1) % args.save_interval == 0: - coordinator.print_on_master(f"Saving checkpoint") - save( - booster, - model, - optimizer, - lr_scheduler, - epoch, - step + 1, - args.batch_size, - coordinator, - args.save_dir, - ) - coordinator.print_on_master(f"Saved checkpoint at epoch {epoch} step {step + 1}") - # the continue epochs are not resumed, so we need to reset the sampler start index and start step - dataloader.sampler.set_start_index(0) - start_step = 0 - - coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB") - - -if __name__ == "__main__": - main() diff --git a/examples/language/llama2/pretrain.py b/examples/language/llama2/pretrain.py deleted file mode 100644 index 970cd5290..000000000 --- a/examples/language/llama2/pretrain.py +++ /dev/null @@ -1,328 +0,0 @@ -import argparse -import os -import resource -from contextlib import nullcontext -from functools import partial -from typing import Optional, Tuple - -import torch -import torch.distributed as dist -import torch.nn as nn -from attn import replace_with_flash_attention -from data_utils import load_json, prepare_dataloader, save_json -from datasets import load_dataset -from torch.optim import Optimizer -from torch.optim.lr_scheduler import _LRScheduler -from torch.utils.tensorboard import SummaryWriter -from tqdm import tqdm -from transformers.models.llama.configuration_llama import LlamaConfig -from transformers.models.llama.modeling_llama import LlamaForCausalLM -from transformers.models.llama.tokenization_llama import LlamaTokenizer - -import colossalai -from colossalai.accelerator import get_accelerator -from colossalai.booster import Booster -from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin -from colossalai.cluster import DistCoordinator -from colossalai.lazy import LazyInitContext -from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR -from colossalai.nn.optimizer import HybridAdam - -MODEL_CONFIGS = { - "7b": LlamaConfig(max_position_embeddings=4096), - "13b": LlamaConfig( - hidden_size=5120, - intermediate_size=13824, - num_hidden_layers=40, - num_attention_heads=40, - max_position_embeddings=4096, - ), - "70b": LlamaConfig( - hidden_size=8192, - intermediate_size=28672, - num_hidden_layers=80, - num_attention_heads=64, - max_position_embeddings=4096, - num_key_value_heads=8, - ), -} - - -def get_model_numel(model: nn.Module) -> int: - return sum(p.numel() for p in model.parameters()) - - -def format_numel_str(numel: int) -> str: - B = 1024**3 - M = 1024**2 - K = 1024 - if numel >= B: - return f"{numel / B:.2f} B" - elif numel >= M: - return f"{numel / M:.2f} M" - elif numel >= K: - return f"{numel / K:.2f} K" - else: - return f"{numel}" - - -def tokenize_batch_for_pretrain(batch, tokenizer: Optional[LlamaTokenizer] = None, max_length: int = 2048): - texts = [sample["text"] for sample in batch] - data = tokenizer(texts, return_tensors="pt", padding="max_length", truncation=True, max_length=max_length) - data = {k: v.cuda() for k, v in data.items()} - data["labels"] = data["input_ids"].clone() - return data - - -def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor: - dist.all_reduce(tensor, op=dist.ReduceOp.SUM) - tensor = tensor.data - tensor.div_(dist.get_world_size()) - return tensor - - -def save( - booster: Booster, - model: nn.Module, - optimizer: Optimizer, - lr_scheduler: _LRScheduler, - epoch: int, - step: int, - batch_size: int, - coordinator: DistCoordinator, - save_dir: str, -): - save_dir = os.path.join(save_dir, f"epoch{epoch}-step{step}") - os.makedirs(os.path.join(save_dir, "model"), exist_ok=True) - - booster.save_model(model, os.path.join(save_dir, "model"), shard=True) - booster.save_optimizer(optimizer, os.path.join(save_dir, "optimizer"), shard=True) - booster.save_lr_scheduler(lr_scheduler, os.path.join(save_dir, "lr_scheduler")) - running_states = { - "epoch": epoch, - "step": step, - "sample_start_index": step * batch_size, - } - if coordinator.is_master(): - save_json(running_states, os.path.join(save_dir, "running_states.json")) - - -def load( - booster: Booster, model: nn.Module, optimizer: Optimizer, lr_scheduler: _LRScheduler, load_dir: str -) -> Tuple[int, int, int]: - booster.load_model(model, os.path.join(load_dir, "model")) - booster.load_optimizer(optimizer, os.path.join(load_dir, "optimizer")) - booster.load_lr_scheduler(lr_scheduler, os.path.join(load_dir, "lr_scheduler")) - running_states = load_json(os.path.join(load_dir, "running_states.json")) - return running_states["epoch"], running_states["step"], running_states["sample_start_index"] - - -def _criterion(outputs, inputs): - return outputs.loss - - -def main(): - # ============================== - # Parse Arguments - # ============================== - parser = argparse.ArgumentParser() - parser.add_argument("-c", "--config", type=str, default="7b", help="Model configuration") - parser.add_argument( - "-p", - "--plugin", - choices=["gemini", "gemini_auto", "zero2", "zero2_cpu", "hybrid_parallel"], - default="gemini", - help="Choose which plugin to use", - ) - parser.add_argument( - "-d", "--dataset", type=str, default="togethercomputer/RedPajama-Data-1T-Sample", help="Data set path" - ) - parser.add_argument("-e", "--num_epochs", type=int, default=1, help="Number of epochs") - parser.add_argument("-b", "--batch_size", type=int, default=2, help="Local batch size") - parser.add_argument("--lr", type=float, default=3e-4, help="Learning rate") - parser.add_argument("-w", "--weigth_decay", type=float, default=0.1, help="Weight decay") - parser.add_argument("-s", "--warmup_steps", type=int, default=2000, help="Warmup steps") - parser.add_argument("-g", "--grad_checkpoint", action="store_true", help="Use gradient checkpointing") - parser.add_argument("-l", "--max_length", type=int, default=4096, help="Max sequence length") - parser.add_argument("-x", "--mixed_precision", default="fp16", choices=["fp16", "bf16"], help="Mixed precision") - parser.add_argument("-i", "--save_interval", type=int, default=1000, help="Save interval") - parser.add_argument("-o", "--save_dir", type=str, default="checkpoint", help="Checkpoint directory") - parser.add_argument("-f", "--load", type=str, default=None, help="Load checkpoint") - parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping") - parser.add_argument("-t", "--tensorboard_dir", type=str, default="tb_logs", help="Tensorboard directory") - parser.add_argument("-a", "--flash_attention", action="store_true", help="Use Flash Attention") - args = parser.parse_args() - - # ============================== - # Initialize Distributed Training - # ============================== - colossalai.launch_from_torch({}) - coordinator = DistCoordinator() - - # ============================== - # Initialize Booster - # ============================== - if args.plugin == "gemini": - plugin = GeminiPlugin(precision=args.mixed_precision, initial_scale=2**16, max_norm=args.grad_clip) - elif args.plugin == "gemini_auto": - plugin = GeminiPlugin( - precision=args.mixed_precision, placement_policy="auto", initial_scale=2**16, max_norm=args.grad_clip - ) - elif args.plugin == "zero2": - plugin = LowLevelZeroPlugin( - stage=2, precision=args.mixed_precision, initial_scale=2**16, max_norm=args.grad_clip - ) - elif args.plugin == "zero2_cpu": - plugin = LowLevelZeroPlugin( - stage=2, precision=args.mixed_precision, initial_scale=2**16, cpu_offload=True, max_norm=args.grad_clip - ) - elif args.plugin == "hybrid_parallel": - # modify the param accordingly, default configuration is for llama2-7b - plugin = HybridParallelPlugin( - tp_size=4, - pp_size=2, - num_microbatches=None, - microbatch_size=1, - enable_jit_fused=False, - zero_stage=0, - precision=args.mixed_precision, - initial_scale=1, - ) - else: - raise ValueError(f"Unknown plugin {args.plugin}") - - booster = Booster(plugin=plugin) - - use_pipeline = isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1 - is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage() - print_flag = (not use_pipeline and coordinator.is_master()) or (use_pipeline and is_pp_last_stage) - - # ============================== - # Initialize Tensorboard - # ============================== - if print_flag: - os.makedirs(args.tensorboard_dir, exist_ok=True) - writer = SummaryWriter(args.tensorboard_dir) - - # ============================== - # Initialize Tokenizer, Dataset and Dataloader - # ============================== - tokenizer = LlamaTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer") - # follows fast chat: https://github.com/lm-sys/FastChat/blob/main/fastchat/train/train.py#L257 - tokenizer.pad_token = tokenizer.unk_token - - dataset = load_dataset(args.dataset) - train_ds = dataset["train"] - dataloader = prepare_dataloader( - train_ds, - batch_size=args.batch_size, - shuffle=True, - drop_last=True, - collate_fn=partial(tokenize_batch_for_pretrain, tokenizer=tokenizer, max_length=args.max_length), - ) - - # ============================== - # Initialize Model, Optimizer and LR Scheduler - # ============================== - config = MODEL_CONFIGS[args.config] - # use lazy init when using GeminiPlugin - init_ctx = ( - LazyInitContext(default_device=get_accelerator().get_current_device()) - if isinstance(plugin, GeminiPlugin) - else nullcontext() - ) - - with init_ctx: - model = LlamaForCausalLM(config) - - if args.grad_checkpoint: - model.gradient_checkpointing_enable() - if args.flash_attention: - replace_with_flash_attention(model) - - model_numel = get_model_numel(model) - coordinator.print_on_master(f"Model params: {format_numel_str(model_numel)}") - - optimizer = HybridAdam(model.parameters(), lr=args.lr, betas=(0.9, 0.95), weight_decay=args.weigth_decay) - lr_scheduler = CosineAnnealingWarmupLR( - optimizer, total_steps=args.num_epochs * len(dataloader), warmup_steps=args.warmup_steps, eta_min=0.1 * args.lr - ) - default_dtype = torch.float16 if args.mixed_precision == "fp16" else torch.bfloat16 - torch.set_default_dtype(default_dtype) - model, optimizer, _, dataloader, lr_scheduler = booster.boost( - model, optimizer, dataloader=dataloader, lr_scheduler=lr_scheduler - ) - torch.set_default_dtype(torch.float) - - coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB") - coordinator.print_on_master( - f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss/1024:.2f} MB" - ) - - # load checkpoint if specified - start_epoch = 0 - start_step = 0 - sampler_start_idx = 0 - if args.load is not None: - coordinator.print_on_master("Loading checkpoint") - start_epoch, start_step, sampler_start_idx = load(booster, model, optimizer, lr_scheduler, args.load) - coordinator.print_on_master(f"Loaded checkpoint {args.load} at epoch {start_epoch} step {start_step}") - - num_steps_per_epoch = len(dataloader) - - # if resume training, set the sampler start index to the correct value - dataloader.sampler.set_start_index(sampler_start_idx) - for epoch in range(start_epoch, args.num_epochs): - dataloader.sampler.set_epoch(epoch) - dataloader_iter = iter(dataloader) - - with tqdm( - range(start_step, num_steps_per_epoch), - desc=f"Epoch {epoch}", - disable=not print_flag, - total=num_steps_per_epoch, - initial=start_step, - ) as pbar: - for step in pbar: - if use_pipeline: - outputs = booster.execute_pipeline(dataloader_iter, model, _criterion, optimizer, return_loss=True) - loss = outputs["loss"] - else: - batch = next(dataloader_iter) - outputs = model(**batch) - loss = outputs[0] - booster.backward(loss, optimizer) - - optimizer.step() - lr_scheduler.step() - optimizer.zero_grad() - - if not use_pipeline: - all_reduce_mean(loss) - if print_flag: - pbar.set_postfix({"loss": loss.item()}) - writer.add_scalar("loss", loss.item(), epoch * num_steps_per_epoch + step) - - if args.save_interval > 0 and (step + 1) % args.save_interval == 0: - coordinator.print_on_master(f"Saving checkpoint") - save( - booster, - model, - optimizer, - lr_scheduler, - epoch, - step + 1, - args.batch_size, - coordinator, - args.save_dir, - ) - coordinator.print_on_master(f"Saved checkpoint at epoch {epoch} step {step + 1}") - # the continue epochs are not resumed, so we need to reset the sampler start index and start step - dataloader.sampler.set_start_index(0) - start_step = 0 - - coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB") - - -if __name__ == "__main__": - main() diff --git a/examples/language/llama2/requirements.txt b/examples/language/llama2/requirements.txt index 6b475682d..438a4999a 100644 --- a/examples/language/llama2/requirements.txt +++ b/examples/language/llama2/requirements.txt @@ -1,9 +1,8 @@ -colossalai>=0.3.2 +colossalai>=0.3.6 datasets numpy -torch>=1.12.0,<=2.0.0 tqdm transformers -flash-attn>=2.0.0,<=2.0.5 +flash-attn>=2.0.0 SentencePiece==0.1.99 tensorboard==2.14.0 From f4c5aafe2987c0c78c4989c178874dc8c1cc7369 Mon Sep 17 00:00:00 2001 From: binmakeswell Date: Tue, 23 Apr 2024 18:48:07 +0800 Subject: [PATCH 08/28] [example] llama3 (#5631) * release llama3 * [release] llama3 * [release] llama3 * [release] llama3 * [release] llama3 --- README.md | 14 +++++++++++--- docs/README-zh-Hans.md | 10 +++++++++- examples/language/{llama2 => llama}/README.md | 6 ++++++ examples/language/{llama2 => llama}/benchmark.py | 0 .../language/{llama2 => llama}/requirements.txt | 0 .../{llama2 => llama}/scripts/benchmark_70B/3d.sh | 0 .../scripts/benchmark_70B/gemini.sh | 0 .../scripts/benchmark_70B/gemini_auto.sh | 0 .../scripts/benchmark_7B/gemini.sh | 0 .../scripts/benchmark_7B/gemini_auto.sh | 0 examples/language/{llama2 => llama}/test_ci.sh | 0 11 files changed, 26 insertions(+), 4 deletions(-) rename examples/language/{llama2 => llama}/README.md (94%) rename examples/language/{llama2 => llama}/benchmark.py (100%) rename examples/language/{llama2 => llama}/requirements.txt (100%) rename examples/language/{llama2 => llama}/scripts/benchmark_70B/3d.sh (100%) rename examples/language/{llama2 => llama}/scripts/benchmark_70B/gemini.sh (100%) rename examples/language/{llama2 => llama}/scripts/benchmark_70B/gemini_auto.sh (100%) rename examples/language/{llama2 => llama}/scripts/benchmark_7B/gemini.sh (100%) rename examples/language/{llama2 => llama}/scripts/benchmark_7B/gemini_auto.sh (100%) rename examples/language/{llama2 => llama}/test_ci.sh (100%) diff --git a/README.md b/README.md index 26776bdf6..c1e2da0d4 100644 --- a/README.md +++ b/README.md @@ -52,7 +52,7 @@

  • Parallel Training Demo
      -
    • LLaMA 1/2
    • +
    • LLaMA 1/2/3
    • MoE
    • GPT-3
    • GPT-2
    • @@ -270,13 +270,21 @@ Acceleration of [AlphaFold Protein Structure](https://alphafold.ebi.ac.uk/)

      (back to top)

      ## Parallel Training Demo +### LLaMA3 +

      + +

      + +- 70 billion parameter LLaMA3 model training accelerated by 18% +[[code]](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/llama) + ### LLaMA2

      - 70 billion parameter LLaMA2 model training accelerated by 195% -[[code]](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/llama2) +[[code]](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/llama) [[blog]](https://www.hpc-ai.tech/blog/70b-llama2-training) ### LLaMA1 @@ -285,7 +293,7 @@ Acceleration of [AlphaFold Protein Structure](https://alphafold.ebi.ac.uk/)

      - 65-billion-parameter large model pretraining accelerated by 38% -[[code]](https://github.com/hpcaitech/ColossalAI/tree/example/llama/examples/language/llama) +[[code]](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/llama) [[blog]](https://www.hpc-ai.tech/blog/large-model-pretraining) ### MoE diff --git a/docs/README-zh-Hans.md b/docs/README-zh-Hans.md index 6d243a808..7e0ed07fe 100644 --- a/docs/README-zh-Hans.md +++ b/docs/README-zh-Hans.md @@ -51,7 +51,7 @@
    • 并行训练样例展示
        -
      • LLaMA 1/2
      • +
      • LLaMA 1/2/3
      • MoE
      • GPT-3
      • GPT-2
      • @@ -261,6 +261,14 @@ Colossal-AI 为您提供了一系列并行组件。我们的目标是让您的

        (返回顶端)

        ## 并行训练样例展示 +### LLaMA3 +

        + +

        + +- 700亿参数LLaMA3训练加速18% +[[code]](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/llama) + ### LLaMA2

        diff --git a/examples/language/llama2/README.md b/examples/language/llama/README.md similarity index 94% rename from examples/language/llama2/README.md rename to examples/language/llama/README.md index 11b2ee511..fa0c6dc07 100644 --- a/examples/language/llama2/README.md +++ b/examples/language/llama/README.md @@ -1,4 +1,10 @@ # Pretraining LLaMA-1/2/3: best practices for building LLaMA-1/2/3-like base models +### LLaMA3 +

        + +

        + +- 70 billion parameter LLaMA3 model training accelerated by 18% ### LLaMA2

        diff --git a/examples/language/llama2/benchmark.py b/examples/language/llama/benchmark.py similarity index 100% rename from examples/language/llama2/benchmark.py rename to examples/language/llama/benchmark.py diff --git a/examples/language/llama2/requirements.txt b/examples/language/llama/requirements.txt similarity index 100% rename from examples/language/llama2/requirements.txt rename to examples/language/llama/requirements.txt diff --git a/examples/language/llama2/scripts/benchmark_70B/3d.sh b/examples/language/llama/scripts/benchmark_70B/3d.sh similarity index 100% rename from examples/language/llama2/scripts/benchmark_70B/3d.sh rename to examples/language/llama/scripts/benchmark_70B/3d.sh diff --git a/examples/language/llama2/scripts/benchmark_70B/gemini.sh b/examples/language/llama/scripts/benchmark_70B/gemini.sh similarity index 100% rename from examples/language/llama2/scripts/benchmark_70B/gemini.sh rename to examples/language/llama/scripts/benchmark_70B/gemini.sh diff --git a/examples/language/llama2/scripts/benchmark_70B/gemini_auto.sh b/examples/language/llama/scripts/benchmark_70B/gemini_auto.sh similarity index 100% rename from examples/language/llama2/scripts/benchmark_70B/gemini_auto.sh rename to examples/language/llama/scripts/benchmark_70B/gemini_auto.sh diff --git a/examples/language/llama2/scripts/benchmark_7B/gemini.sh b/examples/language/llama/scripts/benchmark_7B/gemini.sh similarity index 100% rename from examples/language/llama2/scripts/benchmark_7B/gemini.sh rename to examples/language/llama/scripts/benchmark_7B/gemini.sh diff --git a/examples/language/llama2/scripts/benchmark_7B/gemini_auto.sh b/examples/language/llama/scripts/benchmark_7B/gemini_auto.sh similarity index 100% rename from examples/language/llama2/scripts/benchmark_7B/gemini_auto.sh rename to examples/language/llama/scripts/benchmark_7B/gemini_auto.sh diff --git a/examples/language/llama2/test_ci.sh b/examples/language/llama/test_ci.sh similarity index 100% rename from examples/language/llama2/test_ci.sh rename to examples/language/llama/test_ci.sh From 0d0a5820331768896a7bd743cf6407586dd534ca Mon Sep 17 00:00:00 2001 From: Wang Binluo <32676639+wangbluo@users.noreply.github.com> Date: Wed, 24 Apr 2024 22:51:50 +0800 Subject: [PATCH 09/28] [shardformer] update transformers (#5583) * flash_attention forward upgrade * llama_model_forward * remove useless comment * update the requirements.txt * add the transformers version requirements * remove the LATEST VERSION try * [shardformer] update bloom model (#5518) * update bloom model * remove the version restriction * [shardformer] update_falcon (#5520) * [shardformer] update mistral model (#5511) * [shardformer] update gpt2 (#5502) * [shardformer] update gptj model (#5503) * [shardformer] update opt (#5522) * [shardformer] update t5 model (#5524) * [shardformer] update whisper model (#5529) * [shardformer] update vit model (#5530) * update vit model * remove the output_hidden_states * [shardformer] fix llama modeling * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [zero] support multiple (partial) backward passes (#5596) * [zero] support multiple (partial) backward passes * [misc] update requirements * [zero] support multiple (partial) backward passes (#5596) * [zero] support multiple (partial) backward passes * [misc] update requirements * fix conflicts * [doc] fix ColossalMoE readme (#5599) * fix readme * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * merge with main * merge with main * llama_model_forward * remove useless comment * remove the LATEST VERSION try * [shardformer] update bloom model (#5518) * update bloom model * remove the version restriction * [shardformer] update mistral model (#5511) * [shardformer] update opt (#5522) * [shardformer] update whisper model (#5529) * [shardformer] fix llama modeling * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [hotfix] Fix examples no pad token & auto parallel codegen bug; (#5606) * fix no pad token bug * fixed some auto parallel codegen bug, but might not run on torch 2.1 --------- Co-authored-by: Edenzzzz * [shardformer] fix pipeline grad ckpt (#5620) * [shardformer] fix pipeline grad ckpt * [shardformer] fix whisper (#5628) * [test] fix llama model test * fix the opt upgrade (#5634) * [shardformer] fix attn replacement (#5636) * [shardformer] update flashattention replacement (#5637) * update transformers update transformers fix fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * [test] fix llama test (#5638) * [gemini] fix buffer cast (#5639) * Fix shardformer upgrade (#5640) * fix llama model * fix the mistral * fix the shardformer model * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * [shardformer]support pipeline parallelism for mistral. (#5642) * [shardformer] fix attn replacement (#5636) * [shardformer] update flashattention replacement (#5637) * update transformers update transformers fix fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * [Feature] Support LLaMA-3 CPT and ST (#5619) * support LLaMA-3 * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Run pre-commit --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * [exampe] update llama example (#5626) * [plugin] support dp inside for hybriad parallel * [example] update llama benchmark * [example] update llama benchmark * [example] update llama readme * [example] update llama readme * [example] llama3 (#5631) * release llama3 * [release] llama3 * [release] llama3 * [release] llama3 * [release] llama3 * [test] fix llama test (#5638) * [gemini] fix buffer cast (#5639) * support pp for mistral * fix * fix fix fix * fix --------- Co-authored-by: Hongxin Liu Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Tong Li Co-authored-by: binmakeswell --------- Co-authored-by: Hongxin Liu Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Camille Zhong <44392324+Camille7777@users.noreply.github.com> Co-authored-by: Edenzzzz Co-authored-by: Edenzzzz Co-authored-by: flybird11111 <1829166702@qq.com> Co-authored-by: Tong Li Co-authored-by: binmakeswell --- colossalai/shardformer/modeling/bloom.py | 38 +- colossalai/shardformer/modeling/falcon.py | 229 +++---- colossalai/shardformer/modeling/gpt2.py | 20 +- colossalai/shardformer/modeling/gptj.py | 24 +- colossalai/shardformer/modeling/llama.py | 148 +++-- colossalai/shardformer/modeling/mistral.py | 614 ++++++++++++++++-- colossalai/shardformer/modeling/opt.py | 58 +- colossalai/shardformer/modeling/t5.py | 25 +- colossalai/shardformer/modeling/vit.py | 18 +- colossalai/shardformer/modeling/whisper.py | 65 +- colossalai/shardformer/policies/bloom.py | 6 - colossalai/shardformer/policies/falcon.py | 22 +- colossalai/shardformer/policies/gpt2.py | 9 +- colossalai/shardformer/policies/gptj.py | 9 +- colossalai/shardformer/policies/llama.py | 23 +- colossalai/shardformer/policies/mistral.py | 178 ++++- colossalai/shardformer/policies/opt.py | 21 +- colossalai/shardformer/policies/sam.py | 34 +- colossalai/shardformer/policies/whisper.py | 23 +- colossalai/zero/gemini/gemini_ddp.py | 1 + requirements/requirements-test.txt | 1 - requirements/requirements.txt | 1 + tests/kit/model_zoo/transformers/llama.py | 1 - tests/kit/model_zoo/transformers/mistral.py | 3 + .../test_model/test_shard_llama.py | 2 +- .../test_model/test_shard_mistral.py | 21 +- .../test_model/test_shard_whisper.py | 2 +- 27 files changed, 1155 insertions(+), 441 deletions(-) diff --git a/colossalai/shardformer/modeling/bloom.py b/colossalai/shardformer/modeling/bloom.py index fe70376e1..c4f326364 100644 --- a/colossalai/shardformer/modeling/bloom.py +++ b/colossalai/shardformer/modeling/bloom.py @@ -6,6 +6,7 @@ import torch.distributed as dist from torch.distributed import ProcessGroup from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from torch.nn import functional as F +from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask from transformers.modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions, @@ -205,12 +206,13 @@ class BloomPipelineForwards: alibi = self.build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype) # causal_mask is constructed every stage and its input is passed through different stages - causal_mask = self._prepare_attn_mask( + causal_mask = _prepare_4d_causal_attention_mask( attention_mask, input_shape=(batch_size, seq_length), + inputs_embeds=hidden_states, past_key_values_length=past_key_values_length, ) - + causal_mask = causal_mask.bool() # split the input tensor along sequence dimension # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size] if shard_config and shard_config.enable_sequence_parallelism: @@ -227,21 +229,15 @@ class BloomPipelineForwards: all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, use_cache=use_cache, output_attentions=output_attentions) - - return custom_forward - - outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), + outputs = self._gradient_checkpointing_func( + block.__call__, hidden_states, alibi, causal_mask, layer_past, head_mask[i], + use_cache, + output_attentions, ) else: outputs = block( @@ -1002,11 +998,13 @@ def get_bloom_sequence_parallel_forward_fn(shard_config: ShardConfig): alibi = self.build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype) - causal_mask = self._prepare_attn_mask( + causal_mask = _prepare_4d_causal_attention_mask( attention_mask, input_shape=(batch_size, seq_length), + inputs_embeds=hidden_states, past_key_values_length=past_key_values_length, ) + causal_mask = causal_mask.bool() # split the input tensor along sequence dimension # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size] hidden_states = split_forward_gather_backward( @@ -1018,21 +1016,15 @@ def get_bloom_sequence_parallel_forward_fn(shard_config: ShardConfig): all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, use_cache=use_cache, output_attentions=output_attentions) - - return custom_forward - - outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), + outputs = self._gradient_checkpointing_func( + block.__call__, hidden_states, alibi, causal_mask, layer_past, head_mask[i], + use_cache, + output_attentions, ) else: outputs = block( diff --git a/colossalai/shardformer/modeling/falcon.py b/colossalai/shardformer/modeling/falcon.py index 4e271dfe0..df3b09c71 100644 --- a/colossalai/shardformer/modeling/falcon.py +++ b/colossalai/shardformer/modeling/falcon.py @@ -1,9 +1,16 @@ +import math +import warnings from typing import List, Optional, Tuple, Union import torch import torch.distributed as dist from torch.distributed import ProcessGroup from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from transformers.modeling_attn_mask_utils import ( + AttentionMaskConverter, + _prepare_4d_causal_attention_mask, + _prepare_4d_causal_attention_mask_for_sdpa, +) from transformers.modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions, @@ -99,11 +106,17 @@ def get_tp_falcon_decoder_layer_forward(): hidden_states: torch.Tensor, alibi: Optional[torch.Tensor], attention_mask: torch.Tensor, + position_ids: Optional[torch.LongTensor] = None, layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, head_mask: Optional[torch.Tensor] = None, use_cache: bool = False, output_attentions: bool = False, + **kwargs, ): + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) residual = hidden_states if self.config.new_decoder_architecture: @@ -117,10 +130,12 @@ def get_tp_falcon_decoder_layer_forward(): attention_layernorm_out, layer_past=layer_past, attention_mask=attention_mask, + position_ids=position_ids, alibi=alibi, head_mask=head_mask, use_cache=use_cache, output_attentions=output_attentions, + **kwargs, ) attention_output = attn_outputs[0] @@ -154,87 +169,6 @@ def get_tp_falcon_decoder_layer_forward(): return forward -def get_falcon_flash_attention_forward(): - try: - from xformers.ops import memory_efficient_attention as me_attention - except: - raise ImportError("Error: xformers module is not installed. Please install it to use flash attention.") - from transformers.models.falcon.modeling_falcon import FalconAttention - - def forward( - self: FalconAttention, - hidden_states: torch.Tensor, - alibi: Optional[torch.Tensor], - attention_mask: torch.Tensor, - layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - head_mask: Optional[torch.Tensor] = None, - use_cache: bool = False, - output_attentions: bool = False, - ): - fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size] - num_kv_heads = self.num_heads if self.new_decoder_architecture else self.num_kv_heads - # 3 x [batch_size, seq_length, num_heads, head_dim] - (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv) - - batch_size, query_length, _, _ = query_layer.shape - - query_layer = query_layer.transpose(1, 2).reshape(batch_size * self.num_heads, query_length, self.head_dim) - key_layer = key_layer.transpose(1, 2).reshape( - batch_size * num_kv_heads, - query_length, - self.head_dim, - ) - value_layer = value_layer.transpose(1, 2).reshape(batch_size * num_kv_heads, query_length, self.head_dim) - - past_kv_length = 0 if layer_past is None else layer_past[0].shape[1] - query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, past_kv_length) - - if layer_past is not None: - past_key, past_value = layer_past - # concatenate along seq_length dimension: - # - key: [batch_size * self.num_heads, kv_length, head_dim] - # - value: [batch_size * self.num_heads, kv_length, head_dim] - key_layer = torch.cat((past_key, key_layer), dim=1) - value_layer = torch.cat((past_value, value_layer), dim=1) - - _, kv_length, _ = key_layer.shape - if use_cache: - present = (key_layer, value_layer) - else: - present = None - - attention_mask_float = (attention_mask * 1.0).masked_fill(attention_mask, float("-1e9")).to(query_layer.dtype) - - query_layer_ = query_layer.reshape(batch_size, self.num_heads, -1, self.head_dim).transpose(1, 2).contiguous() - key_layer_ = key_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim).transpose(1, 2).contiguous() - value_layer_ = value_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim).transpose(1, 2).contiguous() - - if alibi is not None: - attention_mask_float = ( - attention_mask_float + alibi.view(batch_size, self.num_heads, 1, kv_length) * self.beta - ) - - batch_size, src_len = query_layer_.size()[0], query_layer_.size()[1] - tgt_len = key_layer_.size()[1] - attention_mask_float = attention_mask_float.expand(batch_size, self.num_heads, src_len, tgt_len).contiguous() - context_layer = me_attention( - query_layer_, - key_layer_, - value_layer_, - attn_bias=attention_mask_float, - scale=self.inv_norm_factor, - p=self.attention_dropout.p, - ) - batch_size, seq_length, _, _ = context_layer.shape - context_layer = context_layer.reshape(batch_size, seq_length, -1) - - output_tensor = self.dense(context_layer) - - return output_tensor, present - - return forward - - class FalconPipelineForwards: """ This class serves as a micro library for falcon pipeline forwards. @@ -246,6 +180,7 @@ class FalconPipelineForwards: input_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, head_mask: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, @@ -274,17 +209,6 @@ class FalconPipelineForwards: return_dict = return_dict if return_dict is not None else self.config.use_return_dict - if past_key_values is None: - past_key_values = tuple([None] * len(self.h)) - else: - past_key_values = self._convert_to_rw_cache(past_key_values) - - # Prepare head mask if needed - # 1.0 in head_mask indicate we keep the head - # attention_probs has shape batch_size x num_heads x N x N - # head_mask has shape n_layer x batch x num_heads x N x N - head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) - # case: First stage of training if stage_manager.is_first_stage(): if input_ids is not None and inputs_embeds is not None: @@ -295,16 +219,22 @@ class FalconPipelineForwards: batch_size, seq_length, _ = inputs_embeds.shape else: raise ValueError("You have to specify either input_ids or inputs_embeds") - if inputs_embeds is None: inputs_embeds = self.word_embeddings(input_ids) - hidden_states = inputs_embeds - else: input_shape = hidden_states.shape[:-1] batch_size, seq_length = input_shape + if past_key_values is None: + past_key_values = tuple([None] * len(self.h)) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False presents = () if use_cache else None all_self_attentions = () if output_attentions else None all_hidden_states = () if output_hidden_states else None @@ -312,22 +242,80 @@ class FalconPipelineForwards: # Compute alibi tensor: check build_alibi_tensor documentation past_key_values_length = 0 if past_key_values[0] is not None: - past_key_values_length = past_key_values[0][0].shape[1] # 1 because RW-cache, not standard format - if attention_mask is None: - attention_mask = torch.ones((batch_size, seq_length + past_key_values_length), device=hidden_states.device) - else: - attention_mask = attention_mask.to(hidden_states.device) + past_key_values_length = past_key_values[0][0].shape[-2] if self.use_alibi: - alibi = build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype) + mask = ( + torch.ones( + (batch_size, seq_length + past_key_values_length), device=inputs_embeds.device, dtype=torch.long + ) + if attention_mask is None + else attention_mask + ) + alibi = build_alibi_tensor(mask, self.num_heads, dtype=hidden_states.dtype) else: alibi = None + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + ) + position_ids = position_ids.unsqueeze(0) - causal_mask = self._prepare_attn_mask( - attention_mask, - input_shape=(batch_size, seq_length), - past_key_values_length=past_key_values_length, - ) + if self._use_flash_attention_2: + # 2d mask is passed through the layers + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + elif self._use_sdpa and not output_attentions: + # output_attentions=True can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + if alibi is None: + attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask, + (batch_size, seq_length), + inputs_embeds, + past_key_values_length, + ) + elif head_mask is None: + alibi = alibi.reshape(batch_size, -1, *alibi.shape[1:]) + + attention_mask_2d = attention_mask + # We don't call _prepare_4d_causal_attention_mask_for_sdpa as we need to mask alibi using the 4D attention_mask untouched. + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length + ) + + # We take care to integrate alibi bias in the attention_mask here. + if attention_mask_2d is None: + attention_mask = alibi / math.sqrt(self.config.hidden_size // self.num_heads) + else: + attention_mask = torch.masked_fill( + alibi / math.sqrt(self.config.hidden_size // self.num_heads), + attention_mask < -1, + torch.finfo(alibi.dtype).min, + ) + + # From PyTorch 2.1 onwards, F.scaled_dot_product_attention with the memory-efficient attention backend + # produces nans if sequences are completely unattended in the attention mask. Details: https://github.com/pytorch/pytorch/issues/110213 + if seq_length > 1: + attention_mask = AttentionMaskConverter._unmask_unattended( + attention_mask, attention_mask_2d, unmasked_value=0.0 + ) + else: + # PyTorch SDPA does not support head_mask, we fall back on the eager implementation in this case. + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length + ) + else: + # 4d mask is passed through the layers + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length + ) + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape batch_size x num_heads x N x N + # head_mask has shape n_layer x batch x num_heads x N x N + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) start_idx, end_idx = stage_index[0], stage_index[1] for i, (block, layer_past) in enumerate( @@ -337,31 +325,23 @@ class FalconPipelineForwards: all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, use_cache=use_cache, output_attentions=output_attentions) - - return custom_forward - - outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), + outputs = self._gradient_checkpointing_func( + block.__call__, hidden_states, alibi, - causal_mask, + attention_mask, + position_ids, head_mask[i], + layer_past, + use_cache, + output_attentions, ) else: outputs = block( hidden_states, layer_past=layer_past, - attention_mask=causal_mask, + attention_mask=attention_mask, + position_ids=position_ids, head_mask=head_mask[i], use_cache=use_cache, output_attentions=output_attentions, @@ -382,9 +362,6 @@ class FalconPipelineForwards: if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if presents is not None: - presents = self._convert_cache_to_standard_format(presents, batch_size) - if stage_manager.is_last_stage(): if not return_dict: return tuple( diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py index 26088569a..17acdf7fc 100644 --- a/colossalai/shardformer/modeling/gpt2.py +++ b/colossalai/shardformer/modeling/gpt2.py @@ -177,11 +177,9 @@ class GPT2PipelineForwards: head_mask = self.get_head_mask(head_mask, self.config.n_layer) if stage_manager.is_first_stage(): - if position_ids is not None: - position_ids = position_ids.view(-1, input_shape[-1]) - else: + if position_ids is None: position_ids = torch.arange(0, input_shape[-1], dtype=torch.long, device=device) - position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) + position_ids = position_ids.unsqueeze(0) if inputs_embeds is None: inputs_embeds = self.wte(input_ids) @@ -239,22 +237,16 @@ class GPT2PipelineForwards: all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, use_cache, output_attentions) - - return custom_forward - - outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), + outputs = self._gradient_checkpointing_func( + block.__call__, hidden_states, None, attention_mask, head_mask[i], encoder_hidden_states, encoder_attention_mask, + use_cache, + output_attentions, ) else: outputs = block( diff --git a/colossalai/shardformer/modeling/gptj.py b/colossalai/shardformer/modeling/gptj.py index 5c254d1e7..4f4cec8bc 100644 --- a/colossalai/shardformer/modeling/gptj.py +++ b/colossalai/shardformer/modeling/gptj.py @@ -148,11 +148,9 @@ class GPTJPipelineForwards: head_mask = self.get_head_mask(head_mask, self.config.n_layer) # position id to be assigned not just for the first stage for attn input - if position_ids is not None: - position_ids = position_ids.view(-1, seq_length) - else: + if position_ids is None: position_ids = torch.arange(0, seq_length, dtype=torch.long, device=device) - position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) + position_ids = position_ids.unsqueeze(0) if stage_manager.is_first_stage(): if inputs_embeds is None: inputs_embeds = self.wte(input_ids) @@ -201,21 +199,15 @@ class GPTJPipelineForwards: all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, use_cache, output_attentions) - - return custom_forward - - outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), + outputs = self._gradient_checkpointing_func( + block.__call__, hidden_states, None, attention_mask, position_ids, head_mask[i], + use_cache, + output_attentions, ) else: outputs = block( @@ -627,7 +619,9 @@ def get_gptj_flash_attention_forward(): value = torch.cat((past_value, value), dim=-2) if use_cache is True: - present = (key, value) + # Note that this cast is quite ugly, but is not implemented before ROPE as the original codebase keeps the key in float32 all along the computation. + # Reference: https://github.com/kingoflolz/mesh-transformer-jax/blob/f8315e3003033b23f21d78361b288953064e0e76/mesh_transformer/layers.py#L128 + present = (key.to(hidden_states.dtype), value) else: present = None diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index c3b5426c2..0eb08a043 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -7,6 +7,7 @@ import torch.nn.functional as F import torch.utils.checkpoint from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from transformers.cache_utils import Cache from transformers.modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, @@ -16,6 +17,8 @@ from transformers.models.llama.modeling_llama import ( LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel, + _prepare_4d_causal_attention_mask, + _prepare_4d_causal_attention_mask_for_sdpa, apply_rotary_pos_emb, repeat_kv, ) @@ -31,13 +34,6 @@ from colossalai.shardformer.shard import ShardConfig from ..layer import ColoAttention, cross_entropy_1d -try: - from transformers.models.llama.modeling_llama import _prepare_4d_causal_attention_mask - - LATEST_VERSION = True -except ImportError: - LATEST_VERSION = False - class LlamaPipelineForwards: """ @@ -75,13 +71,13 @@ class LlamaPipelineForwards: # retrieve input_ids and inputs_embeds if stage_manager.is_first_stage(): if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") elif input_ids is not None: - batch_size, seq_length = input_ids.shape + batch_size, seq_length = input_ids.shape[:2] elif inputs_embeds is not None: - batch_size, seq_length, _ = inputs_embeds.shape + batch_size, seq_length, _ = inputs_embeds.shape[:2] else: - raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + raise ValueError("You have to specify either input_ids or inputs_embeds") device = input_ids.device if input_ids is not None else inputs_embeds.device if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) @@ -111,11 +107,12 @@ class LlamaPipelineForwards: if position_ids is None: position_ids = torch.arange( - past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + past_key_values_length, + seq_length + past_key_values_length, + dtype=torch.long, + device=device, ) - position_ids = position_ids.unsqueeze(0).view(-1, seq_length) - else: - position_ids = position_ids.view(-1, seq_length).long() + position_ids = position_ids.unsqueeze(0) # embed positions, for the first stage, hidden_states is the input embeddings, # for the other stages, hidden_states is the output of the previous stage @@ -123,20 +120,32 @@ class LlamaPipelineForwards: # in this case, attention_mask is a dict rather than a tensor mask_shape = (batch_size, 1, seq_length_with_past, seq_length_with_past) attention_mask = ColoAttention.prepare_attn_kwargs( - mask_shape, hidden_states.dtype, hidden_states.device, q_padding_mask=attention_mask, is_causal=True + mask_shape, + hidden_states.dtype, + hidden_states.device, + q_padding_mask=attention_mask, + is_causal=True, ) else: - if attention_mask is None: - attention_mask = torch.ones( - (batch_size, seq_length_with_past), dtype=torch.bool, device=hidden_states.device - ) - if LATEST_VERSION: - attention_mask = _prepare_4d_causal_attention_mask( - attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length + if self._use_flash_attention_2: + # 2d mask is passed through the layers + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + elif self._use_sdpa and not output_attentions: + # output_attentions=True can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask, + (batch_size, seq_length), + inputs_embeds, + past_key_values_length, ) else: - attention_mask = self._prepare_decoder_attention_mask( - attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length + # 4d mask is passed through the layers + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, + (batch_size, seq_length), + hidden_states, + past_key_values_length, ) if self.gradient_checkpointing and self.training: @@ -149,7 +158,7 @@ class LlamaPipelineForwards: # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None - next_decoder_cache = () if use_cache else None + next_decoder_cache = None start_idx, end_idx = stage_index[0], stage_index[1] num_ckpt_layers = 0 @@ -160,7 +169,7 @@ class LlamaPipelineForwards: num_ckpt_layers = shard_config.gradient_checkpoint_config.get_num_ckpt_layers( stage=stage_manager.stage, num_layers=end_idx - start_idx, - model_chunk_id=stage_manager.model_chunk_id if stage_manager.is_interleave else 0, + model_chunk_id=(stage_manager.model_chunk_id if stage_manager.is_interleave else 0), ) assert num_ckpt_layers <= end_idx - start_idx @@ -168,30 +177,22 @@ class LlamaPipelineForwards: if output_hidden_states: all_hidden_states += (hidden_states,) - past_key_value = past_key_values[idx] if past_key_values is not None else None - if idx - start_idx < num_ckpt_layers: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, None) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, hidden_states, attention_mask, position_ids, - None, + past_key_values, + output_attentions, + use_cache, ) else: layer_outputs = decoder_layer( hidden_states, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, ) @@ -199,7 +200,7 @@ class LlamaPipelineForwards: hidden_states = layer_outputs[0] if use_cache: - next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + next_decoder_cache = layer_outputs[2 if output_attentions else 1] if output_attentions: all_self_attns += (layer_outputs[1],) @@ -212,7 +213,16 @@ class LlamaPipelineForwards: next_cache = next_decoder_cache if use_cache else None if stage_manager.is_last_stage(): if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return tuple( + v + for v in [ + hidden_states, + next_cache, + all_hidden_states, + all_self_attns, + ] + if v is not None + ) return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=next_cache, @@ -458,23 +468,25 @@ class LlamaPipelineForwards: def get_llama_flash_attention_forward(shard_config, sp_mode, sp_group, sp_size): from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb - llama_version = 2 try: from transformers.models.llama.modeling_llama import repeat_kv except: warnings.warn("using llamav1, llamav1 hasn't repeat_kv function") - llama_version = 1 def forward( self: LlamaAttention, hidden_states: torch.Tensor, attention_mask: Optional[dict] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) bsz, q_len, _ = hidden_states.size() if sp_mode in ["split_gather", "ring"]: @@ -498,21 +510,23 @@ def get_llama_flash_attention_forward(shard_config, sp_mode, sp_group, sp_size): kv_seq_len = key_states.shape[-2] if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[-2] + if self.layer_idx is None: + raise ValueError( + f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " + "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " + "with a layer index." + ) + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: - # reuse k, v, self_attention - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) + cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - past_key_value = (key_states, value_states) if use_cache else None - - # repeat k/v heads if n_kv_heads < n_heads - if llama_version == 2: - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) assert isinstance(attention_mask, dict), "Flash Attention Error: attention_mask should be a dict." attn_output = ColoAttention.attention(query_states, key_states, value_states, **attention_mask) @@ -573,7 +587,10 @@ def get_llama_model_forward_for_flash_attn(shard_config: ShardConfig): if position_ids is None: device = input_ids.device if input_ids is not None else inputs_embeds.device position_ids = torch.arange( - past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + past_key_values_length, + seq_length + past_key_values_length, + dtype=torch.long, + device=device, ) position_ids = position_ids.unsqueeze(0).view(-1, seq_length) else: @@ -587,7 +604,11 @@ def get_llama_model_forward_for_flash_attn(shard_config: ShardConfig): # in this case, attention_mask is a dict rather than a tensor mask_shape = (batch_size, 1, seq_length_with_past, seq_length_with_past) attention_mask = ColoAttention.prepare_attn_kwargs( - mask_shape, hidden_states.dtype, hidden_states.device, q_padding_mask=attention_mask, is_causal=True + mask_shape, + hidden_states.dtype, + hidden_states.device, + q_padding_mask=attention_mask, + is_causal=True, ) if self.gradient_checkpointing and self.training: @@ -918,7 +939,10 @@ def get_llama_seq_parallel_model_forward(sp_mode, sp_size, sp_group): if position_ids is None: device = input_ids.device if input_ids is not None else inputs_embeds.device position_ids = torch.arange( - past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + past_key_values_length, + seq_length + past_key_values_length, + dtype=torch.long, + device=device, ) position_ids = position_ids.unsqueeze(0).view(-1, seq_length) else: @@ -934,10 +958,12 @@ def get_llama_seq_parallel_model_forward(sp_mode, sp_size, sp_group): if attention_mask is None: attention_mask = torch.ones( - (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device + (batch_size, seq_length_with_past), + dtype=torch.bool, + device=inputs_embeds.device, ) - attention_mask = self._prepare_decoder_attention_mask( + attention_mask = _prepare_4d_causal_attention_mask( attention_mask, attention_mask.shape, inputs_embeds, past_key_values_length ) diff --git a/colossalai/shardformer/modeling/mistral.py b/colossalai/shardformer/modeling/mistral.py index 0da1a35a0..ac7845400 100644 --- a/colossalai/shardformer/modeling/mistral.py +++ b/colossalai/shardformer/modeling/mistral.py @@ -1,70 +1,606 @@ -from typing import Optional, Tuple +import warnings +from typing import List, Optional, Tuple, Union import torch +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from transformers.cache_utils import Cache, DynamicCache +from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + SequenceClassifierOutputWithPast, +) +from transformers.models.mistral.modeling_mistral import MistralForCausalLM, MistralModel +from transformers.utils import logging + +from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.shardformer.shard import ShardConfig + +from ..layer import ColoAttention + +logger = logging.get_logger(__name__) -def get_mistral_flash_attention_forward(): +class MistralForwards: + @staticmethod + def mistral_model_forward( + self: MistralModel, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + if use_cache: + logger.warning_once("use_cache=True is not supported for Mistral models at the moment.") + use_cache = False + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if stage_manager.is_first_stage(): + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + inputs_embeds = self.embed_tokens(input_ids) + hidden_states = inputs_embeds + else: + input_shape = hidden_states.shape[:-1] + batch_size, seq_length = input_shape + device = hidden_states.device + + past_key_values_length = 0 + + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + ) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + if attention_mask is not None and self._use_flash_attention_2 and use_cache: + is_padding_right = attention_mask[:, -1].sum().item() != batch_size + if is_padding_right: + raise ValueError( + "You are attempting to perform batched generation with padding_side='right'" + " this may lead to unexpected behaviour for Flash Attention version of Mistral. Make sure to " + " call `tokenizer.padding_side = 'left'` before tokenizing the input. " + ) + + if shard_config.enable_flash_attention: + # in this case, attention_mask is a dict rather than a tensor + mask_shape = (batch_size, 1, seq_length, seq_length) + attention_mask = ColoAttention.prepare_attn_kwargs( + mask_shape, + hidden_states.dtype, + hidden_states.device, + q_padding_mask=attention_mask, + is_causal=True, + ) + else: + if self._use_flash_attention_2: + # 2d mask is passed through the layers + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + else: + # 4d mask is passed through the layers + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, + (batch_size, seq_length), + hidden_states, + past_key_values_length, + sliding_window=self.config.sliding_window, + ) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + + start_idx, end_idx = stage_index[0], stage_index[1] + num_ckpt_layers = 0 + if self.gradient_checkpointing and self.training: + num_ckpt_layers = end_idx - start_idx + # TODO: We can replace `gradient_checkpointing_enable` fn and initialize a gradient_checkpointing (List[bool]) for each layer + if shard_config.gradient_checkpoint_config is not None: + num_ckpt_layers = shard_config.gradient_checkpoint_config.get_num_ckpt_layers( + stage=stage_manager.stage, + num_layers=end_idx - start_idx, + model_chunk_id=(stage_manager.model_chunk_id if stage_manager.is_interleave else 0), + ) + assert num_ckpt_layers <= end_idx - start_idx + + for idx, decoder_layer in enumerate(self.layers[start_idx:end_idx], start=start_idx): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if idx - start_idx < num_ckpt_layers: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + attention_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if stage_manager.is_last_stage(): + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = None + if stage_manager.is_last_stage(): + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + else: + return {"hidden_states": hidden_states} + + @staticmethod + def mistral_for_causal_lm_forward( + self: MistralForCausalLM, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, MistralForCausalLM + + >>> model = MistralForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) + >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = MistralForwards.mistral_model_forward( + self.model, + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index, + shard_config=shard_config, + ) + + past_key_values = None + + if stage_manager.is_last_stage(): + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + logits = logits.float() + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + else: + hidden_states = outputs.get("hidden_states") + return {"hidden_states": hidden_states} + + @staticmethod + def mistral_for_sequence_classification_forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = MistralForwards.mistral_model_forward( + self.model, + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index, + shard_config=shard_config, + ) + + if input_ids is not None: + batch_size = input_ids.shape[0] + elif inputs_embeds is not None: + batch_size = inputs_embeds.shape[0] + else: + batch_size = hidden_states.shape[0] + + if stage_manager.is_last_stage(): + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1).to( + logits.device + ) + else: + sequence_lengths = -1 + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + + loss = None + if labels is not None: + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + else: + hidden_states = transformer_outputs.get("hidden_states") + return {"hidden_states": hidden_states} + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + +def get_mistral_model_forward_for_flash_attn(shard_config: ShardConfig): + logger = logging.get_logger(__name__) + assert shard_config.enable_flash_attention, "Flash Attention is not enabled." + + def forward( + self: MistralModel, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + past_key_values_length = 0 + + if use_cache: + use_legacy_cache = not isinstance(past_key_values, Cache) + if use_legacy_cache: + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + past_key_values_length = past_key_values.get_usable_length(seq_length) + + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + ) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if attention_mask is not None and self._use_flash_attention_2 and use_cache: + is_padding_right = attention_mask[:, -1].sum().item() != batch_size + if is_padding_right: + raise ValueError( + "You are attempting to perform batched generation with padding_side='right'" + " this may lead to unexpected behaviour for Flash Attention version of Mistral. Make sure to " + " call `tokenizer.padding_side = 'left'` before tokenizing the input. " + ) + if shard_config.enable_flash_attention: + # in this case, attention_mask is a dict rather than a tensor + mask_shape = (batch_size, 1, seq_length, seq_length) + attention_mask = ColoAttention.prepare_attn_kwargs( + mask_shape, + inputs_embeds.dtype, + inputs_embeds.device, + q_padding_mask=attention_mask, + is_causal=True, + ) + else: + if self._use_flash_attention_2: + # 2d mask is passed through the layers + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + else: + # 4d mask is passed through the layers + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, + (batch_size, seq_length), + inputs_embeds, + past_key_values_length, + sliding_window=self.config.sliding_window, + ) + + hidden_states = inputs_embeds + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + attention_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = None + if use_cache: + next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache + + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + return forward + + +def get_mistral_flash_attention_forward(shard_config: ShardConfig): from transformers.models.mistral.modeling_mistral import MistralAttention, apply_rotary_pos_emb, repeat_kv - from colossalai.nn.layer.colo_attention import AttnMaskType, ColoAttention - def forward( self: MistralAttention, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, + **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) bsz, q_len, _ = hidden_states.size() - assert q_len % 4 == 0, "Flash Attention Error: The sequence length should be a multiple of 4." - query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = ( - self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - ) - value_states = ( - self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - ) + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) kv_seq_len = key_states.shape[-2] if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[-2] - + if self.layer_idx is None: + raise ValueError( + f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " + "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " + "with a layer index." + ) + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: - # reuse k, v, self_attention - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) - - past_key_value = (key_states, value_states) if use_cache else None + cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + # repeat k/v heads if n_kv_heads < n_heads key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) - me_input_shape = (bsz, q_len, self.num_heads, self.head_dim) - query_states = query_states.transpose(1, 2).contiguous().view(*me_input_shape) - key_states = key_states.transpose(1, 2).contiguous().view(*me_input_shape) - value_states = value_states.transpose(1, 2).contiguous().view(*me_input_shape) + assert isinstance(attention_mask, dict), "Flash Attention Error: attention_mask should be a dict." + attn_output = ColoAttention.attention(query_states, key_states, value_states, **attention_mask) - flash_attention_mask = None - attn_mask_type = AttnMaskType.causal - if attention_mask != None: - if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" - ) - flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous() - attn_mask_type = AttnMaskType.paddedcausal - - attention = ColoAttention(embed_dim=self.hidden_size, num_heads=self.num_heads) - attn_output = attention( - query_states, key_states, value_states, attn_mask=flash_attention_mask, attn_mask_type=attn_mask_type - ) + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) attn_output = self.o_proj(attn_output) diff --git a/colossalai/shardformer/modeling/opt.py b/colossalai/shardformer/modeling/opt.py index a26526430..8f841c8a6 100644 --- a/colossalai/shardformer/modeling/opt.py +++ b/colossalai/shardformer/modeling/opt.py @@ -3,6 +3,7 @@ from typing import List, Optional, Tuple, Union import torch from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask from transformers.modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, @@ -42,7 +43,7 @@ def _get_attention_mask( is_causal=True, ) else: - attention_mask = self.decoder._prepare_decoder_attention_mask( + attention_mask = _prepare_4d_causal_attention_mask( attention_mask, (batch_size, seq_length), hidden_states, @@ -57,6 +58,20 @@ class OPTPipelineForwards: under pipeline setting. """ + @staticmethod + def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): + """ + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. + """ + bsz, src_len = mask.size() + tgt_len = tgt_len if tgt_len is not None else src_len + + expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) + + inverted_mask = 1.0 - expanded_mask + + return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) + @staticmethod def opt_model_forward( self: OPTModel, @@ -112,7 +127,7 @@ class OPTPipelineForwards: inputs_embeds = decoder.project_in(inputs_embeds) device = input_ids.device if input_ids is not None else inputs_embeds.device inputs_embeds.dtype - + hidden_states = inputs_embeds else: if hidden_states is None: raise ValueError("hidden_states shouldn't be None for intermediate stages.") @@ -125,12 +140,25 @@ class OPTPipelineForwards: # required mask seq length can be calculated via length of past mask_seq_length = past_key_values_length + seq_length # embed positions - if attention_mask is None: - attention_mask = torch.ones(batch_size, mask_seq_length, device=device) - elif attention_mask.shape[1] != mask_seq_length: - raise ValueError( - f"The provided attention mask has length {attention_mask.shape[1]}, but its length should be " - f"{mask_seq_length} (sum of the lengths of current and past inputs)" + if self.decoder._use_flash_attention_2: + # 2d mask is passed through the layers + causal_attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + attention_mask = ( + torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) + if attention_mask is None + else attention_mask + ) + else: + # 4d mask is passed through the layers + if attention_mask is None: + attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) + elif attention_mask.shape[1] != mask_seq_length: + raise ValueError( + f"The provided attention mask has length {attention_mask.shape[1]}, but its length should be " + f"{mask_seq_length} (sum of the lengths of current and past inputs)" + ) + causal_attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, input_shape, hidden_states, past_key_values_length ) if stage_manager.is_first_stage(): @@ -205,20 +233,14 @@ class OPTPipelineForwards: past_key_value = past_key_values[idx] if past_key_values is not None else None if decoder.gradient_checkpointing and decoder.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, None) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, hidden_states, causal_attention_mask, head_mask[idx] if head_mask is not None else None, None, + output_attentions, + use_cache, ) else: layer_outputs = decoder_layer( diff --git a/colossalai/shardformer/modeling/t5.py b/colossalai/shardformer/modeling/t5.py index 9c5ce3fb6..b35bb6b94 100644 --- a/colossalai/shardformer/modeling/t5.py +++ b/colossalai/shardformer/modeling/t5.py @@ -3,7 +3,6 @@ from typing import Dict, List, Optional, Tuple, Union import torch from torch.nn import CrossEntropyLoss -from torch.utils.checkpoint import checkpoint from transformers.modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -118,16 +117,13 @@ class T5PipelineForwards: # required mask seq length can be calculated via length of past mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length - if attention_mask is None: - attention_mask = torch.ones(batch_size, mask_seq_length, device=device) - if in_decoder and encoder_attention_mask is None and encoder_hidden_states is not None: - encoder_seq_length = encoder_hidden_states.shape[1] - encoder_attention_mask = torch.ones(batch_size, encoder_seq_length, device=device, dtype=torch.long) - # initialize past_key_values with `None` if past does not exist if past_key_values is None: past_key_values = [None] * len(self.block) + if attention_mask is None: + attention_mask = torch.ones(batch_size, mask_seq_length, device=device) + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] # ourselves in which case we just need to make it broadcastable to all heads. extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape) @@ -138,7 +134,7 @@ class T5PipelineForwards: encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) if encoder_attention_mask is None: - encoder_attention_mask = torch.ones(encoder_hidden_shape, device=inputs_embeds.device) + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=inputs_embeds.device, dtype=torch.long) encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) else: encoder_extended_attention_mask = None @@ -162,15 +158,8 @@ class T5PipelineForwards: torch.cuda.set_device(hidden_states.device) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return tuple(module(*inputs, use_cache, output_attentions)) - - return custom_forward - - layer_outputs = checkpoint( - create_custom_forward(layer_module), + layer_outputs = self._gradient_checkpointing_func( + layer_module.forward, hidden_states, extended_attention_mask, position_bias, @@ -180,6 +169,8 @@ class T5PipelineForwards: layer_head_mask, cross_attn_layer_head_mask, None, # past_key_value is always None with gradient checkpointing + use_cache, + output_attentions, ) else: layer_outputs = layer_module( diff --git a/colossalai/shardformer/modeling/vit.py b/colossalai/shardformer/modeling/vit.py index e9c256a13..67b10988d 100644 --- a/colossalai/shardformer/modeling/vit.py +++ b/colossalai/shardformer/modeling/vit.py @@ -14,6 +14,8 @@ def _encoder_forward( end_idx: int, hidden_states: torch.Tensor, head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + output_hidden_states: bool = False, return_dict: bool = True, stage_manager: PipelineStageManager = None, ) -> Union[tuple, BaseModelOutput]: @@ -23,20 +25,14 @@ def _encoder_forward( layer_head_mask = head_mask[i] if head_mask is not None else None if encoder.gradient_checkpointing and encoder.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, False) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = encoder._gradient_checkpointing_func( + layer_module.__call__, hidden_states, layer_head_mask, + output_attentions, ) else: - layer_outputs = layer_module(hidden_states, layer_head_mask, False) + layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) hidden_states = layer_outputs[0] if not stage_manager.is_last_stage(): @@ -114,6 +110,8 @@ def ViTModel_pipeline_forward(stage_manager: PipelineStageManager, stage_index: end_idx=stage_index[1], hidden_states=hidden_states, head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, return_dict=return_dict, stage_manager=stage_manager, ) diff --git a/colossalai/shardformer/modeling/whisper.py b/colossalai/shardformer/modeling/whisper.py index 7ccc79276..6d7df963a 100644 --- a/colossalai/shardformer/modeling/whisper.py +++ b/colossalai/shardformer/modeling/whisper.py @@ -5,6 +5,10 @@ from typing import List, Optional, Tuple, Union import torch from torch import nn from torch.nn import CrossEntropyLoss +from transformers.modeling_attn_mask_utils import ( + _prepare_4d_causal_attention_mask, + _prepare_4d_causal_attention_mask_for_sdpa, +) from transformers.modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -35,6 +39,8 @@ def _get_attention_mask( hidden_states: torch.Tensor, past_key_values_length: int, attention_mask: Optional[torch.FloatTensor], + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, ): batch_size, seq_length = hidden_states.shape[:2] mask_seq_length = past_key_values_length + seq_length @@ -47,12 +53,20 @@ def _get_attention_mask( is_causal=True, ) else: - attention_mask = self._prepare_decoder_attention_mask( - attention_mask, - (batch_size, seq_length), - hidden_states, - past_key_values_length, - ) + input_shape = (batch_size, seq_length) + if self._use_flash_attention_2: + # 2d mask is passed through the layers + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + elif self._use_sdpa and head_mask is None and not output_attentions: + # output_attentions=True & head_mask can not be supported when using SDPA. + attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask, input_shape, hidden_states, past_key_values_length + ) + else: + # 4d mask is passed through the layers + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, input_shape, hidden_states, past_key_values_length + ) return attention_mask @@ -539,18 +553,12 @@ class WhisperPipelineForwards: layer_outputs = (None, None) else: if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(encoder_layer), + layer_outputs = self._gradient_checkpointing_func( + encoder_layer.__call__, hidden_states, None, (head_mask[idx] if head_mask is not None else None), + output_attentions, ) else: layer_outputs = encoder_layer( @@ -702,20 +710,16 @@ class WhisperPipelineForwards: if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) + attention_mask = _get_attention_mask( + self, shard_config, inputs_embeds, past_key_values_length, attention_mask + ) + # embed positions if input_ids is not None: positions = self.embed_positions(input_ids, past_key_values_length=past_key_values_length) else: positions = self.embed_positions(inputs_embeds, past_key_values_length=past_key_values_length) - attention_mask = _get_attention_mask( - self, - shard_config, - inputs_embeds, - past_key_values_length, - attention_mask, - ) - hidden_states = inputs_embeds + positions hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) @@ -732,7 +736,6 @@ class WhisperPipelineForwards: "hidden_states shouldn't be None for stages other than the first stage of encoder/decoder." ) input_shape = hidden_states.size()[:-1] - attention_mask = _get_attention_mask( self, shard_config, @@ -756,16 +759,8 @@ class WhisperPipelineForwards: past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, use_cache) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, hidden_states, attention_mask, encoder_hidden_states, @@ -773,6 +768,8 @@ class WhisperPipelineForwards: head_mask[idx] if head_mask is not None else None, (cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None), None, # past_key_value + output_attentions, + use_cache, ) else: layer_outputs = decoder_layer( diff --git a/colossalai/shardformer/policies/bloom.py b/colossalai/shardformer/policies/bloom.py index 953592abc..4894bda35 100644 --- a/colossalai/shardformer/policies/bloom.py +++ b/colossalai/shardformer/policies/bloom.py @@ -24,12 +24,6 @@ from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDe class BloomPolicy(Policy): def __init__(self) -> None: super().__init__() - import transformers - from packaging.version import Version - - assert Version(transformers.__version__) <= Version( - "4.33.0" - ), "The Bloom model should run on a transformers version not greater than 4.33.0." def config_sanity_check(self): pass diff --git a/colossalai/shardformer/policies/falcon.py b/colossalai/shardformer/policies/falcon.py index a2f110a41..e72a97e4b 100644 --- a/colossalai/shardformer/policies/falcon.py +++ b/colossalai/shardformer/policies/falcon.py @@ -7,12 +7,7 @@ from torch.nn import Module import colossalai.shardformer.layer as col_nn -from ..modeling.falcon import ( - FalconPipelineForwards, - build_falcon_alibi_tensor_fn, - get_falcon_flash_attention_forward, - get_tp_falcon_decoder_layer_forward, -) +from ..modeling.falcon import FalconPipelineForwards, build_falcon_alibi_tensor_fn, get_tp_falcon_decoder_layer_forward from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription __all__ = ["FalconPolicy"] @@ -21,12 +16,6 @@ __all__ = ["FalconPolicy"] class FalconPolicy(Policy): def __init__(self) -> None: super().__init__() - import transformers - from packaging.version import Version - - assert Version(transformers.__version__) <= Version( - "4.33.0" - ), "The Falcon model should run on a transformers version not greater than 4.33.0." def config_sanity_check(self): pass @@ -36,7 +25,7 @@ class FalconPolicy(Policy): return self.model def module_policy(self): - from transformers.models.falcon.modeling_falcon import FalconAttention, FalconDecoderLayer, FalconModel + from transformers.models.falcon.modeling_falcon import FalconDecoderLayer, FalconModel if not self.model.config.new_decoder_architecture and self.model.config.multi_query: warnings.warn( @@ -147,11 +136,8 @@ class FalconPolicy(Policy): ) if self.shard_config.enable_flash_attention: - self.append_or_create_method_replacement( - description={"forward": get_falcon_flash_attention_forward()}, - policy=policy, - target_key=FalconAttention, - ) + warnings.warn("Falcon doesn't support flash attention now, fallback to transformers attention.") + return policy def postprocess(self): diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py index 98db7b948..6f4f835a8 100644 --- a/colossalai/shardformer/policies/gpt2.py +++ b/colossalai/shardformer/policies/gpt2.py @@ -35,13 +35,20 @@ class GPT2Policy(Policy): Reshape the Embedding layer to make the embedding dimension divisible by world_size """ self.tie_weight = self.tie_weight_check() + self.origin_attn_implement = self.model.config._attn_implementation return self.model def module_policy(self): from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Block, GPT2Model + ATTN_IMPLEMENTATION = { + "eager": GPT2Attention, + } + policy = {} + attn_cls = ATTN_IMPLEMENTATION[self.origin_attn_implement] + embedding_cls = None if self.shard_config.enable_tensor_parallelism: embedding_cls = col_nn.VocabParallelEmbedding1D @@ -186,7 +193,7 @@ class GPT2Policy(Policy): "forward": get_gpt2_flash_attention_forward(), }, policy=policy, - target_key=GPT2Attention, + target_key=attn_cls, ) if not self.shard_config.pipeline_stage_manager: policy[GPT2Model].method_replacement = { diff --git a/colossalai/shardformer/policies/gptj.py b/colossalai/shardformer/policies/gptj.py index 4b69137a6..1280efaec 100644 --- a/colossalai/shardformer/policies/gptj.py +++ b/colossalai/shardformer/policies/gptj.py @@ -30,13 +30,20 @@ class GPTJPolicy(Policy): def preprocess(self): self.tie_weight = self.tie_weight_check() + self.origin_attn_implement = self.model.config._attn_implementation return self.model def module_policy(self): from transformers.models.gptj.modeling_gptj import GPTJAttention, GPTJBlock, GPTJModel + ATTN_IMPLEMENTATION = { + "eager": GPTJAttention, + } + policy = {} + attn_cls = ATTN_IMPLEMENTATION[self.origin_attn_implement] + embedding_cls = None if self.shard_config.enable_tensor_parallelism: embedding_cls = col_nn.VocabParallelEmbedding1D @@ -160,7 +167,7 @@ class GPTJPolicy(Policy): "forward": get_gptj_flash_attention_forward(), }, policy=policy, - target_key=GPTJAttention, + target_key=attn_cls, ) if not self.shard_config.pipeline_stage_manager: self.append_or_create_method_replacement( diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index ff686a179..0a95284bc 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -36,13 +36,26 @@ class LlamaPolicy(Policy): def preprocess(self): self.tie_weight = self.tie_weight_check() + self.origin_attn_implement = self.model.config._attn_implementation return self.model def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: - from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel + from transformers.models.llama.modeling_llama import ( + LlamaAttention, + LlamaDecoderLayer, + LlamaFlashAttention2, + LlamaModel, + LlamaSdpaAttention, + ) + ATTN_IMPLEMENTATION = { + "eager": LlamaAttention, + "flash_attention_2": LlamaFlashAttention2, + "sdpa": LlamaSdpaAttention, + } policy = {} + attn_cls = ATTN_IMPLEMENTATION[self.origin_attn_implement] embedding_cls = None if self.shard_config.enable_tensor_parallelism: embedding_cls = VocabParallelEmbedding1D @@ -93,7 +106,7 @@ class LlamaPolicy(Policy): "forward": get_llama_seq_parallel_attention_forward(sp_mode, sp_size, sp_group), }, policy=policy, - target_key=LlamaAttention, + target_key=attn_cls, ) elif sp_mode == "all_to_all": decoder_attribute_replacement = { @@ -102,7 +115,7 @@ class LlamaPolicy(Policy): if getattr(self.model.config, "num_key_value_heads", False): decoder_attribute_replacement["num_key_value_heads"] = self.model.config.num_key_value_heads // sp_size - policy[LlamaAttention] = ModulePolicyDescription( + policy[attn_cls] = ModulePolicyDescription( attribute_replacement=decoder_attribute_replacement, ) self.append_or_create_method_replacement( @@ -110,7 +123,7 @@ class LlamaPolicy(Policy): "forward": get_llama_seq_parallel_attention_forward(sp_mode, sp_size, sp_group), }, policy=policy, - target_key=LlamaAttention, + target_key=attn_cls, ) self.append_or_create_method_replacement( description={ @@ -221,7 +234,7 @@ class LlamaPolicy(Policy): "forward": get_llama_flash_attention_forward(self.shard_config, sp_mode, sp_group, sp_size), }, policy=policy, - target_key=LlamaAttention, + target_key=attn_cls, ) if self.pipeline_stage_manager is None: # replace llama model forward method diff --git a/colossalai/shardformer/policies/mistral.py b/colossalai/shardformer/policies/mistral.py index b225fd2a9..b5018e47d 100644 --- a/colossalai/shardformer/policies/mistral.py +++ b/colossalai/shardformer/policies/mistral.py @@ -1,7 +1,10 @@ import warnings -from typing import Dict, Union +from functools import partial +from typing import Callable, Dict, List, Union import torch.nn as nn +from torch import Tensor +from torch.nn import Module from colossalai.shardformer.layer import ( FusedRMSNorm, @@ -13,7 +16,11 @@ from colossalai.shardformer.layer import ( VocabParallelLMHead1D, ) -from ..modeling.mistral import get_mistral_flash_attention_forward +from ..modeling.mistral import ( + MistralForwards, + get_mistral_flash_attention_forward, + get_mistral_model_forward_for_flash_attn, +) from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription __all__ = ["MistralPolicy", "MistralModelPolicy", "MistralForCausalLMPolicy", "MistralForSequenceClassificationPolicy"] @@ -25,13 +32,26 @@ class MistralPolicy(Policy): def preprocess(self): self.tie_weight = self.tie_weight_check() + self.origin_attn_implement = self.model.config._attn_implementation return self.model def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: - from transformers.models.mistral.modeling_mistral import MistralAttention, MistralDecoderLayer, MistralModel + from transformers.models.mistral.modeling_mistral import ( + MistralAttention, + MistralDecoderLayer, + MistralFlashAttention2, + MistralModel, + ) + + ATTN_IMPLEMENTATION = { + "eager": MistralAttention, + "flash_attention_2": MistralFlashAttention2, + } policy = {} + attn_cls = ATTN_IMPLEMENTATION[self.model.config._attn_implementation] + embedding_cls = None if self.shard_config.enable_tensor_parallelism: embedding_cls = VocabParallelEmbedding1D @@ -127,27 +147,112 @@ class MistralPolicy(Policy): if self.shard_config.enable_flash_attention: self.append_or_create_method_replacement( description={ - "forward": get_mistral_flash_attention_forward(), + "forward": get_mistral_flash_attention_forward(self.shard_config), }, policy=policy, - target_key=MistralAttention, + target_key=attn_cls, ) + if self.pipeline_stage_manager is None: + # replace llama model forward method + self.append_or_create_method_replacement( + description={ + "forward": get_mistral_model_forward_for_flash_attn(self.shard_config), + }, + policy=policy, + target_key=MistralModel, + ) return policy def postprocess(self): return self.model + def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None: + """If under pipeline parallel setting, replacing the original forward method of huggingface + to customized forward method, and add this changing to policy.""" + if self.pipeline_stage_manager is None: + return + + stage_manager = self.pipeline_stage_manager + if self.model.__class__.__name__ == "MistralModel": + module = self.model + else: + module = self.model.model + + if stage_manager.is_interleave: + layers_per_stage = stage_manager.distribute_layers(len(module.layers)) + stage_manager.stage_indices = stage_manager.get_stage_index(layers_per_stage) + method_replacement = { + "forward": partial(new_forward, stage_manager=stage_manager, shard_config=self.shard_config) + } + + else: + layers_per_stage = stage_manager.distribute_layers(len(module.layers)) + stage_index = stage_manager.get_stage_index(layers_per_stage) + method_replacement = { + "forward": partial( + new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config + ) + } + + self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls) + + def get_held_layers(self) -> List[Module]: + """Get pipeline layers for current stage.""" + assert self.pipeline_stage_manager is not None + + if self.model.__class__.__name__ == "MistralModel": + module = self.model + else: + module = self.model.model + stage_manager = self.pipeline_stage_manager + + held_layers = [] + if stage_manager.is_interleave: + assert stage_manager.num_model_chunks is not None + layers_per_stage = stage_manager.distribute_layers(len(module.layers)) + stage_indices = stage_manager.get_stage_index(layers_per_stage) + if stage_manager.is_first_stage(ignore_chunk=True): + held_layers.append(module.embed_tokens) + for start_idx, end_idx in stage_indices: + held_layers.extend(module.layers[start_idx:end_idx]) + if stage_manager.is_last_stage(ignore_chunk=True): + held_layers.append(module.norm) + + else: + layers_per_stage = stage_manager.distribute_layers(len(module.layers)) + if stage_manager.is_first_stage(): + held_layers.append(module.embed_tokens) + start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage) + held_layers.extend(module.layers[start_idx:end_idx]) + if stage_manager.is_last_stage(): + held_layers.append(module.norm) + return held_layers + class MistralModelPolicy(MistralPolicy): def __init__(self) -> None: super().__init__() def module_policy(self): - if self.pipeline_stage_manager: - warnings.warn("Mistral doesn't support pipeline parallelism now.") + policy = super().module_policy() + from transformers.models.mistral.modeling_mistral import MistralModel - return super().module_policy() + if self.pipeline_stage_manager: + self.set_pipeline_forward( + model_cls=MistralModel, new_forward=MistralForwards.mistral_model_forward, policy=policy + ) + + return policy + + def get_held_layers(self) -> List[Module]: + """Get pipeline layers for current stage.""" + held_layers = super().get_held_layers() + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + """No shared params in mistral model""" + return [] class MistralForCausalLMPolicy(MistralPolicy): @@ -155,8 +260,6 @@ class MistralForCausalLMPolicy(MistralPolicy): from transformers import MistralForCausalLM policy = super().module_policy() - if self.pipeline_stage_manager: - warnings.warn("Mistral doesn't support pipeline parallelism now.") if self.shard_config.enable_tensor_parallelism: # add a new item for casual lm @@ -189,8 +292,38 @@ class MistralForCausalLMPolicy(MistralPolicy): policy.update(new_item) + if self.pipeline_stage_manager: + # set None as default + self.set_pipeline_forward( + model_cls=MistralForCausalLM, new_forward=MistralForwards.mistral_for_causal_lm_forward, policy=policy + ) + return policy + def get_held_layers(self) -> List[Module]: + """Get pipeline layers for current stage.""" + stage_manager = self.pipeline_stage_manager + held_layers = super().get_held_layers() + if stage_manager.is_last_stage(ignore_chunk=True): + held_layers.append(self.model.lm_head) + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + mistral_model = self.model.model + if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1: + if ( + id(mistral_model.embed_tokens.weight) == id(self.model.lm_head.weight) + and self.pipeline_stage_manager.num_stages > 1 + ): + # tie weights + return [ + { + 0: mistral_model.embed_tokens.weight, + self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight, + } + ] + return [] + class MistralForSequenceClassificationPolicy(MistralPolicy): def module_policy(self): @@ -209,9 +342,26 @@ class MistralForSequenceClassificationPolicy(MistralPolicy): ] ) } - - if self.pipeline_stage_manager: - warnings.warn("Mistral doesn't support pipeline parallelism now.") - policy.update(new_item) + + if self.pipeline_stage_manager: + # set None as default + self.set_pipeline_forward( + model_cls=MistralForSequenceClassification, + new_forward=MistralForwards.mistral_for_sequence_classification_forward, + policy=policy, + ) + return policy + + def get_held_layers(self) -> List[Module]: + """Get pipeline layers for current stage.""" + stage_manager = self.pipeline_stage_manager + held_layers = super().get_held_layers() + if stage_manager.is_last_stage(ignore_chunk=True): + held_layers.append(self.model.score) + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + """No shared params in llama for sequence classification model""" + return [] diff --git a/colossalai/shardformer/policies/opt.py b/colossalai/shardformer/policies/opt.py index ac78ff6a7..2f6eabd5f 100644 --- a/colossalai/shardformer/policies/opt.py +++ b/colossalai/shardformer/policies/opt.py @@ -38,26 +38,27 @@ __all__ = [ class OPTPolicy(Policy): def __init__(self) -> None: super().__init__() - import transformers - from packaging.version import Version - - # TODO: remove this version check when transformers>=4.36.0 - assert Version(transformers.__version__) <= Version( - "4.33.0" - ), "The OPT model should run on a transformers version not greater than 4.33.0." def config_sanity_check(self): pass def preprocess(self): self.tie_weight = self.tie_weight_check() + self.origin_attn_implement = self.model.config._attn_implementation return self.model def module_policy(self): - from transformers.models.opt.modeling_opt import OPTAttention, OPTDecoder, OPTDecoderLayer + from transformers.models.opt.modeling_opt import OPTAttention, OPTDecoder, OPTDecoderLayer, OptFlashAttention2 + + ATTN_IMPLEMENTATION = { + "eager": OPTAttention, + "flash_attention_2": OptFlashAttention2, + } policy = {} + attn_cls = ATTN_IMPLEMENTATION[self.model.config._attn_implementation] + embedding_cls = None if self.shard_config.enable_tensor_parallelism: embedding_cls = VocabParallelEmbedding1D @@ -88,7 +89,7 @@ class OPTPolicy(Policy): ] ) - policy[OPTAttention] = ModulePolicyDescription( + policy[attn_cls] = ModulePolicyDescription( attribute_replacement={ "embed_dim": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, "num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, @@ -158,7 +159,7 @@ class OPTPolicy(Policy): "forward": get_opt_flash_attention_forward(self.shard_config), }, policy=policy, - target_key=OPTAttention, + target_key=attn_cls, ) if not self.shard_config.pipeline_stage_manager: self.append_or_create_method_replacement( diff --git a/colossalai/shardformer/policies/sam.py b/colossalai/shardformer/policies/sam.py index 498e62164..ce33925ff 100644 --- a/colossalai/shardformer/policies/sam.py +++ b/colossalai/shardformer/policies/sam.py @@ -1,6 +1,8 @@ +import warnings + import colossalai.shardformer.layer as col_nn -from ..modeling.sam import forward_fn, get_sam_flash_attention_forward, get_sam_vision_flash_attention_forward +from ..modeling.sam import forward_fn from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription __all__ = ["SamPolicy", "SamModelPolicy"] @@ -15,7 +17,6 @@ class SamPolicy(Policy): def module_policy(self): from transformers.models.sam.modeling_sam import ( - SamAttention, SamTwoWayAttentionBlock, SamTwoWayTransformer, SamVisionAttention, @@ -210,20 +211,21 @@ class SamPolicy(Policy): # use flash attention if self.shard_config.enable_flash_attention: - self.append_or_create_method_replacement( - description={ - "forward": get_sam_flash_attention_forward(), - }, - policy=policy, - target_key=SamAttention, - ) - self.append_or_create_method_replacement( - description={ - "forward": get_sam_vision_flash_attention_forward(), - }, - policy=policy, - target_key=SamVisionAttention, - ) + warnings.warn("Flash attention is not supported in SAM model. Fallback to normal attention.") + # self.append_or_create_method_replacement( + # description={ + # "forward": get_sam_flash_attention_forward(), + # }, + # policy=policy, + # target_key=SamAttention, + # ) + # self.append_or_create_method_replacement( + # description={ + # "forward": get_sam_vision_flash_attention_forward(), + # }, + # policy=policy, + # target_key=SamVisionAttention, + # ) return policy diff --git a/colossalai/shardformer/policies/whisper.py b/colossalai/shardformer/policies/whisper.py index 0b5114fa6..aeb668797 100644 --- a/colossalai/shardformer/policies/whisper.py +++ b/colossalai/shardformer/policies/whisper.py @@ -29,13 +29,6 @@ __all__ = [ class WhisperPolicy(Policy): def __init__(self) -> None: super().__init__() - import transformers - from packaging.version import Version - - # TODO: remove this version check when transformers>=4.36.0 - assert Version(transformers.__version__) <= Version( - "4.33.0" - ), "The Whisper model should run on a transformers version not greater than 4.33.0." def config_sanity_check(self): pass @@ -55,6 +48,8 @@ class WhisperPolicy(Policy): WhisperDecoderLayer, WhisperEncoder, WhisperEncoderLayer, + WhisperFlashAttention2, + WhisperSdpaAttention, ) policy = {} @@ -249,6 +244,20 @@ class WhisperPolicy(Policy): policy=policy, target_key=WhisperAttention, ) + self.append_or_create_method_replacement( + description={ + "forward": get_whisper_flash_attention_forward(), + }, + policy=policy, + target_key=WhisperFlashAttention2, + ) + self.append_or_create_method_replacement( + description={ + "forward": get_whisper_flash_attention_forward(), + }, + policy=policy, + target_key=WhisperSdpaAttention, + ) if not self.shard_config.pipeline_stage_manager: self.append_or_create_method_replacement( description={ diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py index c79422171..b25de1d68 100644 --- a/colossalai/zero/gemini/gemini_ddp.py +++ b/colossalai/zero/gemini/gemini_ddp.py @@ -840,6 +840,7 @@ class GeminiDDP(ModelWrapper): for buffer in self.module.buffers(): if isinstance(buffer, LazyTensor): buffer.materialize() + for buffer in self.module.buffers(): buffer.data = buffer.to(get_accelerator().get_current_device()) if torch.is_floating_point(buffer): buffer.data = buffer.to(self.mixed_precision) diff --git a/requirements/requirements-test.txt b/requirements/requirements-test.txt index 4136cefc3..0b15b9311 100644 --- a/requirements/requirements-test.txt +++ b/requirements/requirements-test.txt @@ -3,7 +3,6 @@ pytest coverage==7.2.3 git+https://github.com/hpcaitech/pytest-testmon torchvision -transformers==4.33.0 timm titans torchaudio diff --git a/requirements/requirements.txt b/requirements/requirements.txt index fd97f5c5a..d307312de 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -16,3 +16,4 @@ ray sentencepiece google protobuf +transformers==4.36.2 diff --git a/tests/kit/model_zoo/transformers/llama.py b/tests/kit/model_zoo/transformers/llama.py index 58b5b0487..61fa56050 100644 --- a/tests/kit/model_zoo/transformers/llama.py +++ b/tests/kit/model_zoo/transformers/llama.py @@ -64,7 +64,6 @@ if HAS_LLAMA: intermediate_size=64, num_attention_heads=4, max_position_embeddings=128, - num_labels=16, ) if hasattr(config, "pad_token_id"): diff --git a/tests/kit/model_zoo/transformers/mistral.py b/tests/kit/model_zoo/transformers/mistral.py index 37f875857..ae5a97002 100644 --- a/tests/kit/model_zoo/transformers/mistral.py +++ b/tests/kit/model_zoo/transformers/mistral.py @@ -52,6 +52,9 @@ config = MistralConfig( hidden_size=256, intermediate_size=256, num_attention_heads=64, num_hidden_layers=2, vocab_size=50258 ) +if hasattr(config, "pad_token_id"): + config.pad_token_id = config.eos_token_id + model_zoo.register( name="transformers_mistral", model_fn=lambda: transformers.MistralModel(config), diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index 27f904292..2a10d86c7 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -32,7 +32,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, model_fn, loss_fn, test_config ) if enable_gradient_checkpointing: - org_model.gradient_checkpointing_enable() + # org_model.gradient_checkpointing_enable() sharded_model.unwrap().gradient_checkpointing_enable() org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin( diff --git a/tests/test_shardformer/test_model/test_shard_mistral.py b/tests/test_shardformer/test_model/test_shard_mistral.py index 07bc91b33..05c199814 100644 --- a/tests/test_shardformer/test_model/test_shard_mistral.py +++ b/tests/test_shardformer/test_model/test_shard_mistral.py @@ -91,7 +91,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, # check weights if stage_manager is None or stage_manager.is_first_stage(): if test_config["precision"] == "fp32": - atol, rtol = 1e-4, 1e-3 + atol, rtol = 2e-4, 1e-3 else: atol, rtol = 5e-3, 5e-3 check_weight( @@ -114,6 +114,24 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, @parameterize( "test_config", [ + { + "tp_size": 1, + "pp_size": 2, + "num_microbatches": 2, + "enable_all_optimization": True, + "use_lazy_init": False, + "precision": "fp16", + "initial_scale": 1, + }, + { + "tp_size": 2, + "pp_size": 2, + "num_microbatches": 2, + "enable_all_optimization": True, + "use_lazy_init": True, + "precision": "fp16", + "initial_scale": 1, + }, { "tp_size": 4, "pp_size": 1, @@ -156,7 +174,6 @@ def check_mistral(rank, world_size, port): run_mistral_test() -@pytest.mark.skip("This test should be run on a version of transformers not less than 4.35.2.") @pytest.mark.dist @rerun_if_address_is_in_use() @clear_cache_before_run() diff --git a/tests/test_shardformer/test_model/test_shard_whisper.py b/tests/test_shardformer/test_model/test_shard_whisper.py index 6efb8a922..af61e4640 100644 --- a/tests/test_shardformer/test_model/test_shard_whisper.py +++ b/tests/test_shardformer/test_model/test_shard_whisper.py @@ -116,7 +116,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "num_microbatches": 2, "enable_metadata_cache": False, "enable_all_optimization": True, - "use_lazy_init": True, + "use_lazy_init": False, "precision": "fp32", "initial_scale": 1, }, From 7ee569b05fcc94b2f286567ae79f3c338db1a508 Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Wed, 24 Apr 2024 23:04:06 +0800 Subject: [PATCH 10/28] [hotfix] Fixed fused layernorm bug without apex (#5609) * fixed fused layernorm bug without apex * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * same for flash attn * remove flash attn check --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- colossalai/shardformer/layer/normalization.py | 8 +++++++- colossalai/shardformer/shard/shard_config.py | 10 +++++++++- 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/colossalai/shardformer/layer/normalization.py b/colossalai/shardformer/layer/normalization.py index 43dd153af..bba4bd070 100644 --- a/colossalai/shardformer/layer/normalization.py +++ b/colossalai/shardformer/layer/normalization.py @@ -225,7 +225,13 @@ class FusedLayerNorm(BaseLayerNorm): # fall back to the normal fused layernorm is not built ApexFusedLayerNorm = FusedLayerNormWithHook else: - ApexFusedLayerNorm = FusedLayerNormWithHook + try: + ApexFusedLayerNorm = FusedLayerNormWithHook + except NameError: + warnings.warn( + "Please install Apex from source to use fused kernels, or set self.enable_fused_normalization = False. Using vanilla layernorm instead." + ) + return module layernorm = ( ApexFusedLayerNorm(normalized_shape, eps=eps, elementwise_affine=elementwise_affine).to(dtype).to(device) diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py index 597dd9c26..e20b8e239 100644 --- a/colossalai/shardformer/shard/shard_config.py +++ b/colossalai/shardformer/shard/shard_config.py @@ -120,7 +120,15 @@ class ShardConfig: Turn on all optimization. """ # you can add all the optimization flag here - self.enable_fused_normalization = True + try: + from apex.normalization import FusedLayerNorm as ApexFusedLayerNorm # noqa + + apex_avail = True + except ImportError: + apex_avail = False + warnings.warn("You set enable_all_optimization=True, but apex is not installed.") + + self.enable_fused_normalization = apex_avail self.enable_flash_attention = True self.enable_jit_fused = True # This can cause non-in-place param sharding when used without ZeRO. From 148506c828fefe5da60d89dd4ae993abeff9c78a Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Thu, 25 Apr 2024 10:47:14 +0800 Subject: [PATCH 11/28] [coloattention]modify coloattention (#5627) * modify coloattention * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * fix * fix fxi * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- colossalai/kernel/kernel_loader.py | 4 ---- colossalai/shardformer/layer/attn.py | 23 ++++++++++++------- .../test_shardformer/test_flash_attention.py | 11 +-------- 3 files changed, 16 insertions(+), 22 deletions(-) diff --git a/colossalai/kernel/kernel_loader.py b/colossalai/kernel/kernel_loader.py index 353e29b3d..2dff3bcbc 100644 --- a/colossalai/kernel/kernel_loader.py +++ b/colossalai/kernel/kernel_loader.py @@ -113,10 +113,6 @@ class FlashAttentionLoader(KernelLoader): ] -class FlashAttentionWithPaddingMaskLoader(KernelLoader): - REGISTRY = [FlashAttentionNpuExtension, FlashAttentionDaoCudaExtension] - - class FlashAttentionWithCustomMaskLoader(KernelLoader): REGISTRY = [FlashAttentionNpuExtension, FlashAttentionSdpaCudaExtension] diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py index f3f6e59d3..abc865a34 100644 --- a/colossalai/shardformer/layer/attn.py +++ b/colossalai/shardformer/layer/attn.py @@ -8,7 +8,6 @@ from colossalai.kernel.kernel_loader import ( FlashAttentionForFloatAndCustomMaskLoader, FlashAttentionLoader, FlashAttentionWithCustomMaskLoader, - FlashAttentionWithPaddingMaskLoader, KernelLoader, ) @@ -65,15 +64,17 @@ class ColoAttention: half_dispatch_map = { None: FlashAttentionLoader(), AttnMaskType.CUSTOM: FlashAttentionWithCustomMaskLoader(), - AttnMaskType.PADDED: FlashAttentionWithPaddingMaskLoader(), + AttnMaskType.PADDED: FlashAttentionLoader(), AttnMaskType.CAUSAL: FlashAttentionLoader(), - AttnMaskType.PADDED_CAUSAL: FlashAttentionWithPaddingMaskLoader(), + AttnMaskType.PADDED_CAUSAL: FlashAttentionLoader(), } # fp32 float_dispatch_map = { None: FlashAttentionForFloatAndCustomMaskLoader(), AttnMaskType.CUSTOM: FlashAttentionForFloatAndCustomMaskLoader(), + AttnMaskType.PADDED: FlashAttentionForFloatAndCustomMaskLoader(), AttnMaskType.CAUSAL: FlashAttentionForFloatAndCustomMaskLoader(), + AttnMaskType.PADDED_CAUSAL: FlashAttentionForFloatAndCustomMaskLoader(), } ColoAttention._kernel_dispatch_map = { torch.float16: half_dispatch_map, @@ -140,16 +141,22 @@ class ColoAttention: outputs["attention_mask_type"] = AttnMaskType.CAUSAL attention_mask = torch.ones(s_q, s_kv, dtype=dtype, device=device).tril(diagonal=0).expand(b, s_q, s_kv) else: + assert q_padding_mask.shape == ( + b, + s_q, + ), f"q_padding_mask shape {q_padding_mask.shape} should be the same. ({shape_4d})" + max_seqlen_q, cu_seqlens_q, q_indices = get_pad_info(q_padding_mask) if kv_padding_mask is None: # self attention kv_padding_mask = q_padding_mask - assert q_padding_mask.shape == (b, s_q) and kv_padding_mask.shape == ( + max_seqlen_kv, cu_seqlens_kv, kv_indices = max_seqlen_q, cu_seqlens_q, q_indices + else: + max_seqlen_kv, cu_seqlens_kv, kv_indices = get_pad_info(kv_padding_mask) + assert kv_padding_mask.shape == ( b, s_kv, - ), f"q_padding_mask shape {q_padding_mask.shape} and kv_padding_mask shape {kv_padding_mask.shape} should be the same. ({shape_4d})" - attention_mask = torch.einsum("bi,bj->bij", q_padding_mask, kv_padding_mask).to(dtype=dtype, device=device) - max_seqlen_q, cu_seqlens_q, q_indices = get_pad_info(q_padding_mask) - max_seqlen_kv, cu_seqlens_kv, kv_indices = get_pad_info(kv_padding_mask) + ), f"q_padding_mask shape {kv_padding_mask.shape} should be the same. ({shape_4d})" + attention_mask = q_padding_mask[:, None, :].expand(b, s_kv, s_q).to(dtype=dtype, device=device) outputs.update( { "cu_seqlens_q": cu_seqlens_q, diff --git a/tests/test_shardformer/test_flash_attention.py b/tests/test_shardformer/test_flash_attention.py index f9eab132f..9aa24a166 100644 --- a/tests/test_shardformer/test_flash_attention.py +++ b/tests/test_shardformer/test_flash_attention.py @@ -4,11 +4,7 @@ from copy import copy import torch from torch.testing import assert_close -from colossalai.kernel.kernel_loader import ( - FlashAttentionLoader, - FlashAttentionWithCustomMaskLoader, - FlashAttentionWithPaddingMaskLoader, -) +from colossalai.kernel.kernel_loader import FlashAttentionLoader, FlashAttentionWithCustomMaskLoader from colossalai.shardformer.layer import AttnMaskType, ColoAttention from colossalai.shardformer.layer.attn import invert_mask from colossalai.testing import clear_cache_before_run, parameterize @@ -119,11 +115,6 @@ def test_flash_attn_func(dtype: torch.dtype): if ext.is_available(): ext.assert_compatible() avail_custom_mask_attn_funcs.append((ext.load(), ext.name, True)) - for ext_cls in FlashAttentionWithPaddingMaskLoader.REGISTRY: - ext = ext_cls() - if ext.is_available(): - ext.assert_compatible() - avail_padding_mask_attn_funcs.append((ext.load(), ext.name, True)) test_sets = { "none": (lambda dtype: ({}, None), avail_attn_funcs), From 5d88ef1aaf5b1af4423f1f7a3a3bbec5cde13e17 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Thu, 25 Apr 2024 13:46:39 +0800 Subject: [PATCH 12/28] [shardformer] remove useless code (#5645) --- colossalai/shardformer/modeling/opt.py | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/colossalai/shardformer/modeling/opt.py b/colossalai/shardformer/modeling/opt.py index 8f841c8a6..81521c30b 100644 --- a/colossalai/shardformer/modeling/opt.py +++ b/colossalai/shardformer/modeling/opt.py @@ -58,20 +58,6 @@ class OPTPipelineForwards: under pipeline setting. """ - @staticmethod - def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): - """ - Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. - """ - bsz, src_len = mask.size() - tgt_len = tgt_len if tgt_len is not None else src_len - - expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) - - inverted_mask = 1.0 - expanded_mask - - return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) - @staticmethod def opt_model_forward( self: OPTModel, From bbb2c21f16c16c0ab789f046a62f5bd2dfde57c1 Mon Sep 17 00:00:00 2001 From: Hongxin Liu Date: Thu, 25 Apr 2024 14:41:17 +0800 Subject: [PATCH 13/28] [shardformer] fix chatglm implementation (#5644) * [shardformer] fix chatglm policy * [shardformer] fix chatglm flash attn * [shardformer] update readme * [shardformer] fix chatglm init * [shardformer] fix chatglm test * [pipeline] fix chatglm merge batch --- colossalai/pipeline/schedule/one_f_one_b.py | 12 +- colossalai/shardformer/README.md | 121 ++++++++++++------ colossalai/shardformer/layer/normalization.py | 19 ++- colossalai/shardformer/modeling/chatglm2.py | 21 ++- .../shardformer/policies/auto_policy.py | 13 +- colossalai/shardformer/policies/chatglm2.py | 65 +++++++--- docs/source/en/features/shardformer.md | 7 - docs/source/zh-Hans/features/shardformer.md | 7 - tests/kit/model_zoo/transformers/chatglm2.py | 34 +++-- tests/test_shardformer/test_model/_utils.py | 2 +- .../test_model/test_shard_chatglm2.py | 9 +- 11 files changed, 193 insertions(+), 117 deletions(-) diff --git a/colossalai/pipeline/schedule/one_f_one_b.py b/colossalai/pipeline/schedule/one_f_one_b.py index 58008b98f..bfea8b67d 100644 --- a/colossalai/pipeline/schedule/one_f_one_b.py +++ b/colossalai/pipeline/schedule/one_f_one_b.py @@ -7,7 +7,7 @@ from torch.nn import Module from torch.utils._pytree import tree_map from colossalai.accelerator import get_accelerator -from colossalai.interface import OptimizerWrapper +from colossalai.interface import ModelWrapper, OptimizerWrapper from colossalai.pipeline.p2p import PipelineP2PCommunication, create_send_metadata from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.utils import get_current_device @@ -327,7 +327,10 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): self.send_forward(output_obj) if outputs is not None: - outputs = merge_batch(outputs) + if isinstance(model, ModelWrapper): + model = model.unwrap() + batch_size_dim = getattr(model, "batch_size_dim", 0) + outputs = merge_batch(outputs, batch_size_dim) return {"loss": accum_loss, "outputs": outputs} def run_forward_backward( @@ -410,7 +413,10 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): assert all(len(v) == 0 for v in input_objs) and all(len(v) == 0 for v in output_objs) if outputs is not None: - outputs = merge_batch(outputs) + if isinstance(model, ModelWrapper): + model = model.unwrap() + batch_size_dim = getattr(model, "batch_size_dim", 0) + outputs = merge_batch(outputs, batch_size_dim) return {"loss": accum_loss, "outputs": outputs} def forward_backward_step( diff --git a/colossalai/shardformer/README.md b/colossalai/shardformer/README.md index c8670affb..d45421868 100644 --- a/colossalai/shardformer/README.md +++ b/colossalai/shardformer/README.md @@ -114,30 +114,30 @@ We will follow this roadmap to develop Shardformer: - [x] Unit Testing - [ ] Policy Implementation -| model | tensor parallel | pipeline parallel | lazy initialization | xformer | flash attn2 | jit fused operator | fused layernorm | sequence parallel | overlap | -| :------: | :-----: | :-----: | :--------: | :---------: | :------: | :-----: | :-----: | :--------: | :---------: | -| bert | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] | -| t5 | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [ ] | [ ] | -| llama V1/V2 | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [ ] | [ ] | -| gpt2 | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] | -| opt | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [ ] | [ ] | -| bloom | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] | -| chatglm2 | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] | -| vit | [√] | [√] | [ ] | [√] | [√] | [√] | [√] | [ ] | [ ] | -| whisper | [√] | [√] | [√] | [√] | [√] | [ ] | [√] | [ ] | [ ] | -| sam | [√] | [ ] | [ ] | [√] | [√] | [√] | [√] | [ ] | [ ] | -| blip2 | [√] | [ ] | [ ] | [√] | [√] | [√] | [√] | [ ] | [ ] | -| falcon | [√] | [√] | [√] | [√] | [√] | [ ] | [√] | [ ] | [ ] | -| roberta | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | -| albert | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | -| ernie | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | -| gpt-neo | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | -| gpt-j | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | -| beit | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | -| swin | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | -| swin V2 | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | -| qwen | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | -| mistral | [√] | [ ] | [ ] | [√] | [√] | [√] | [√] | [ ] | [ ] | +| model | tensor parallel | pipeline parallel | lazy initialization | xformer | flash attn2 | jit fused operator | fused layernorm | sequence parallel | overlap | +|:-----------:|:---------------:|:-----------------:|:-------------------:|:-------:|:-----------:|:------------------:|:---------------:|:-----------------:|:-------:| +| bert | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] | +| t5 | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [ ] | [ ] | +| llama V1/V2 | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [ ] | [ ] | +| gpt2 | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] | +| opt | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [ ] | [ ] | +| bloom | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] | +| chatglm2 | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] | +| vit | [√] | [√] | [ ] | [√] | [√] | [√] | [√] | [ ] | [ ] | +| whisper | [√] | [√] | [√] | [√] | [√] | [ ] | [√] | [ ] | [ ] | +| sam | [√] | [ ] | [ ] | [√] | [√] | [√] | [√] | [ ] | [ ] | +| blip2 | [√] | [ ] | [ ] | [√] | [√] | [√] | [√] | [ ] | [ ] | +| falcon | [√] | [√] | [√] | [√] | [√] | [ ] | [√] | [ ] | [ ] | +| roberta | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | +| albert | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | +| ernie | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | +| gpt-neo | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | +| gpt-j | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | +| beit | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | +| swin | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | +| swin V2 | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | +| qwen | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | +| mistral | [√] | [ ] | [ ] | [√] | [√] | [√] | [√] | [ ] | [ ] | ## 💡 API Design @@ -391,6 +391,43 @@ _POLICY_LIST = { } ``` +#### How to support those models in huggingface model hub but not in the transformers library + +There are two cases: + +1. the modeling file is in the `transformers` library but the model weight is not in the `transformers` library. E.g. model structure of "01-ai/Yi-34B" is the same as LLaMA but the weight is not in the `transformers` library. In this case, we should support llama as usual and Yi-34B is also supported by the llama policy. We do not need to add a new policy for Yi-34B. +2. the modeling file is not in the `transformers` library, such as the "THUDM/chatglm2-6b". + +Take "THUDM/chatglm2-6b" as an example, we clearly illustrate how to support this model in the `shardformer`. + +Unlike llama which is in `transformers` library, we cannot import chatglm2 model directly. Thus, the key in policy should be str of class name, rather than class itself. + +E.g. for llama: +```python +policy[LlamaDecoderLayer] = ModulePolicyDescription(...) +``` + +for chatglm2: +```python +policy["GLMBlock"] = ModulePolicyDescription(...) +``` + +Then when registering such models in the autopolicy, we should follow below format: +```python +"transformers_modules..": PolicyLocation( + file_name="", class_name="" +) +``` + +As for chatglm2 model, it should be: +```python +"transformers_modules.modeling_chatglm.ChatGLMForConditionalGeneration": PolicyLocation( + file_name="chatglm2", class_name="ChatGLMForConditionalGenerationPolicy" +) +``` + +When using such models, `AutoModel` is supported as usual. The policy will be automatically loaded by the autopolicy. + ### Write Your Unit Testing This section serves as the guideline for testing the `shardformer` module. @@ -424,13 +461,13 @@ We conducted [benchmark tests](./examples/performance_benchmark.py) to evaluate We set the batch size to 4, the number of attention heads to 8, and the head dimension to 64. 'N_CTX' refers to the sequence length. In the case of using 2 GPUs, the training times are as follows. -| N_CTX | org_model | shard_model | -| :------: | :-----: | :-----: | -| 256 | 11.2ms | 17.2ms | -| 512 | 9.8ms | 19.5ms | -| 1024 | 19.6ms | 18.9ms | -| 2048 | 46.6ms | 30.8ms | -| 4096 | 160.5ms | 90.4ms | +| N_CTX | org_model | shard_model | +|:-----:|:---------:|:-----------:| +| 256 | 11.2ms | 17.2ms | +| 512 | 9.8ms | 19.5ms | +| 1024 | 19.6ms | 18.9ms | +| 2048 | 46.6ms | 30.8ms | +| 4096 | 160.5ms | 90.4ms |

        @@ -440,13 +477,13 @@ In the case of using 2 GPUs, the training times are as follows. In the case of using 4 GPUs, the training times are as follows. -| N_CTX | org_model | shard_model | -| :------: | :-----: | :-----: | -| 256 | 10.0ms | 21.1ms | -| 512 | 11.5ms | 20.2ms | -| 1024 | 22.1ms | 20.6ms | -| 2048 | 46.9ms | 24.8ms | -| 4096 | 160.4ms | 68.0ms | +| N_CTX | org_model | shard_model | +|:-----:|:---------:|:-----------:| +| 256 | 10.0ms | 21.1ms | +| 512 | 11.5ms | 20.2ms | +| 1024 | 22.1ms | 20.6ms | +| 2048 | 46.9ms | 24.8ms | +| 4096 | 160.4ms | 68.0ms | @@ -475,10 +512,10 @@ warmup_fraction = 0.03 | accuracy | f1 | loss | GPU number | model sharded | -| :------: | :-----: | :-----: | :--------: | :---------: | -| 0.82971 | 0.87713 | 0.23194 | 4 | True | -| 0.83797 | 0.88006 | 0.22683 | 2 | True | -| 0.84521 | 0.88700 | 0.21822 | 1 | False | +|:--------:|:-------:|:-------:|:----------:|:-------------:| +| 0.82971 | 0.87713 | 0.23194 | 4 | True | +| 0.83797 | 0.88006 | 0.22683 | 2 | True | +| 0.84521 | 0.88700 | 0.21822 | 1 | False | Overall, the results demonstrate that using shardformers during model training does not affect the convergence. diff --git a/colossalai/shardformer/layer/normalization.py b/colossalai/shardformer/layer/normalization.py index bba4bd070..5aa212600 100644 --- a/colossalai/shardformer/layer/normalization.py +++ b/colossalai/shardformer/layer/normalization.py @@ -281,19 +281,16 @@ class FusedRMSNorm(BaseLayerNorm): ) LazyInitContext.materialize(module) - # to check if it is huggingface LlamaRMSNorm or MistralRMSNorm - if module.__class__.__name__ in ["LlamaRMSNorm", "MistralRMSNorm"]: - normalized_shape = module.weight.shape[0] - eps = module.variance_epsilon - elementwise_affine = True - else: - # get the attributes of the module - normalized_shape = module.normalized_shape - eps = module.eps - elementwise_affine = module.elementwise_affine + + # try to get normalized_shape, eps, elementwise_affine from the module + normalized_shape = getattr(module, "normalized_shape", module.weight.shape[0]) + eps = module.variance_epsilon if hasattr(module, "variance_epsilon") else module.eps + elementwise_affine = getattr(module, "elementwise_affine", True) rmsnorm = FusedRMSNormWithHook( - normalized_shape=normalized_shape, eps=eps, elementwise_affine=elementwise_affine + normalized_shape=normalized_shape, + eps=eps, + elementwise_affine=elementwise_affine, ) rmsnorm.weight = module.weight diff --git a/colossalai/shardformer/modeling/chatglm2.py b/colossalai/shardformer/modeling/chatglm2.py index 9207b34d0..53c151f02 100644 --- a/colossalai/shardformer/modeling/chatglm2.py +++ b/colossalai/shardformer/modeling/chatglm2.py @@ -12,7 +12,6 @@ from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer import ShardConfig from colossalai.shardformer.layer import AttnMaskType, ColoAttention from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward -from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration, ChatGLMModel def get_flash_core_attention_forward(): @@ -31,7 +30,12 @@ def get_flash_core_attention_forward(): device=query_layer.device, ) temp_mask = ( - torch.ones(query_layer.shape[2], key_layer.shape[2], dtype=torch.bool, device=query_layer.device) + torch.ones( + query_layer.shape[2], + key_layer.shape[2], + dtype=torch.bool, + device=query_layer.device, + ) .tril(diagonal=0) .expand(query_layer.shape[0], 1, -1, -1) ) @@ -49,6 +53,7 @@ def get_flash_core_attention_forward(): attention_mask=attn_bias, attention_mask_type=attention_mask_type, dropout_p=dropout_p, + scale=1.0 / self.norm_factor, ) context_layer = context_layer.permute(2, 0, 1, 3) new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,) @@ -115,7 +120,7 @@ class ChatGLMPipelineForwards: @staticmethod def chatglm_model_forward( - self: ChatGLMModel, + self: "ChatGLMModel", input_ids, position_ids: Optional[torch.Tensor] = None, attention_mask: Optional[torch.BoolTensor] = None, @@ -194,7 +199,9 @@ class ChatGLMPipelineForwards: if shard_config and shard_config.enable_sequence_parallelism: if shard_config.sequence_parallelism_mode == "split_gather": hidden_states = split_forward_gather_backward( - hidden_states, dim=0, process_group=shard_config.tensor_parallel_process_group + hidden_states, + dim=0, + process_group=shard_config.tensor_parallel_process_group, ) for idx in range(start_idx, end_idx): layer = self.encoder._get_layer(idx) @@ -224,7 +231,9 @@ class ChatGLMPipelineForwards: if shard_config and shard_config.enable_sequence_parallelism: if shard_config.sequence_parallelism_mode == "split_gather": hidden_states = gather_forward_split_backward( - hidden_states, dim=0, process_group=shard_config.tensor_parallel_process_group + hidden_states, + dim=0, + process_group=shard_config.tensor_parallel_process_group, ) if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) @@ -254,7 +263,7 @@ class ChatGLMPipelineForwards: @staticmethod def chatglm_for_conditional_generation_forward( - self: ChatGLMForConditionalGeneration, + self: "ChatGLMForConditionalGeneration", input_ids: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, diff --git a/colossalai/shardformer/policies/auto_policy.py b/colossalai/shardformer/policies/auto_policy.py index 0991ace2c..d2b582af5 100644 --- a/colossalai/shardformer/policies/auto_policy.py +++ b/colossalai/shardformer/policies/auto_policy.py @@ -151,10 +151,10 @@ _POLICY_LIST = { file_name="blip2", class_name="Blip2ForConditionalGenerationPolicy" ), # ChatGLM - "colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm.ChatGLMModel": PolicyLocation( + "transformers_modules.modeling_chatglm.ChatGLMModel": PolicyLocation( file_name="chatglm2", class_name="ChatGLMModelPolicy" ), - "colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm.ChatGLMForConditionalGeneration": PolicyLocation( + "transformers_modules.modeling_chatglm.ChatGLMForConditionalGeneration": PolicyLocation( file_name="chatglm2", class_name="ChatGLMForConditionalGenerationPolicy" ), # Falcon @@ -202,6 +202,13 @@ def _fullname(obj): module = klass.__module__ if module == "builtins": return klass.__qualname__ # avoid outputs like 'builtins.str' + # patch custom models which are not in transformers + # it can be like 'transformers_modules.THUDM.chatglm3-6b.103caa40027ebfd8450289ca2f278eac4ff26405.modeling_chatglm' (from huggingface hub) + # or like 'transformers_modules.chatglm.modeling_chatglm' (from local directory) + if module.startswith("transformers_modules"): + split_module = module.split(".") + if len(split_module) >= 2: + module = f"{split_module[0]}.{split_module[-1]}" return module + "." + klass.__qualname__ @@ -220,7 +227,7 @@ def get_autopolicy(model: nn.Module) -> Policy: if policy_location is None: raise NotImplementedError( - f"Auto policy for {model.__class__.__qualname__} is not implemented\n. Supported models are {list(_POLICY_LIST.keys())}" + f"Auto policy for {model.__class__.__qualname__} ({full_name}) is not implemented\n. Supported models are {list(_POLICY_LIST.keys())}" ) else: policy = import_policy(policy_location) diff --git a/colossalai/shardformer/policies/chatglm2.py b/colossalai/shardformer/policies/chatglm2.py index f205835e7..4baf89f6a 100644 --- a/colossalai/shardformer/policies/chatglm2.py +++ b/colossalai/shardformer/policies/chatglm2.py @@ -7,7 +7,6 @@ from torch import Tensor import colossalai.shardformer.layer as col_nn from colossalai.shardformer.modeling.chatglm2 import ChatGLMPipelineForwards -from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration, ChatGLMModel from ..modeling.chatglm2 import ( get_chatglm_sequence_parallel_forward_fn, @@ -17,7 +16,11 @@ from ..modeling.chatglm2 import ( from ..modeling.jit import get_jit_fused_dropout_add_func from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription -__all__ = ["ChatGLMPolicy", "ChatGLMModelPolicy", "ChatGLMForConditionalGenerationPolicy"] +__all__ = [ + "ChatGLMPolicy", + "ChatGLMModelPolicy", + "ChatGLMForConditionalGenerationPolicy", +] class ChatGLMPolicy(Policy): @@ -34,8 +37,6 @@ class ChatGLMPolicy(Policy): return self.model def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: - from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMModel, CoreAttention, GLMBlock - policy = {} embedding_cls = None @@ -67,7 +68,27 @@ class ChatGLMPolicy(Policy): sp_partial_derived = sp_mode == "split_gather" if self.shard_config.enable_tensor_parallelism: - policy[GLMBlock] = ModulePolicyDescription( + assert ( + self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0 + ), f"num_attention_heads {self.model.config.num_attention_heads} should be divisible by tensor_parallel_size {self.shard_config.tensor_parallel_size}" + attn_kwargs = { + "self_attention.qkv_hidden_size": ( + self.model.config.kv_channels * self.model.config.num_attention_heads * 3 + ) + // self.shard_config.tensor_parallel_size, + } + if self.model.config.multi_query_attention: + assert ( + self.model.config.multi_query_group_num % self.shard_config.tensor_parallel_size == 0 + ), f"multi_query_group_num {self.model.config.multi_query_group_num} should be divisible by tensor_parallel_size {self.shard_config.tensor_parallel_size}" + attn_kwargs["self_attention.num_multi_query_groups_per_partition"] = ( + self.model.config.multi_query_group_num // self.shard_config.tensor_parallel_size + ) + attn_kwargs["self_attention.qkv_hidden_size"] = ( + self.model.config.kv_channels * self.model.config.num_attention_heads + + 2 * self.model.config.kv_channels * self.model.config.multi_query_group_num + ) // self.shard_config.tensor_parallel_size + policy["GLMBlock"] = ModulePolicyDescription( attribute_replacement={ "self_attention.num_attention_heads_per_partition": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, @@ -75,22 +96,23 @@ class ChatGLMPolicy(Policy): self.model.config.kv_channels * self.model.config.num_attention_heads ) // self.shard_config.tensor_parallel_size, - "self_attention.qkv_hidden_size": ( - self.model.config.kv_channels * self.model.config.num_attention_heads * 3 - ) - // self.shard_config.tensor_parallel_size, "self_attention.core_attention.num_attention_heads_per_partition": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, "self_attention.core_attention.hidden_size_per_partition": self.model.config.kv_channels * self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, + **attn_kwargs, }, param_replacement=[], sub_module_replacement=[ SubModuleReplacementDescription( suffix="self_attention.query_key_value", target_module=col_nn.Linear1D_Col, - kwargs={"seq_parallel_mode": sp_mode, "seq_parallel_dim": 0, "overlap": overlap}, + kwargs={ + "seq_parallel_mode": sp_mode, + "seq_parallel_dim": 0, + "overlap": overlap, + }, ), SubModuleReplacementDescription( suffix="self_attention.dense", @@ -114,7 +136,7 @@ class ChatGLMPolicy(Policy): ), ], policy=policy, - target_key=ChatGLMModel, + target_key="ChatGLMModel", ) # optimization configuration self.append_or_create_submodule_replacement( @@ -131,7 +153,7 @@ class ChatGLMPolicy(Policy): ), ], policy=policy, - target_key=GLMBlock, + target_key="GLMBlock", ) if self.model.config.post_layer_norm: @@ -143,7 +165,7 @@ class ChatGLMPolicy(Policy): ) ], policy=policy, - target_key=ChatGLMModel, + target_key="ChatGLMModel", ) # use flash attention @@ -153,7 +175,7 @@ class ChatGLMPolicy(Policy): "forward": get_flash_core_attention_forward(), }, policy=policy, - target_key=CoreAttention, + target_key="CoreAttention", ) # use sequence parallel @@ -161,7 +183,7 @@ class ChatGLMPolicy(Policy): self.append_or_create_method_replacement( description={"forward": get_chatglm_sequence_parallel_forward_fn(self.shard_config)}, policy=policy, - target_key=ChatGLMModel, + target_key="ChatGLMModel", ) # use jit fused operator @@ -172,7 +194,7 @@ class ChatGLMPolicy(Policy): "dropout_add": get_jit_fused_dropout_add_func(), }, policy=policy, - target_key=GLMBlock, + target_key="GLMBlock", ) return policy @@ -220,7 +242,10 @@ class ChatGLMPolicy(Policy): stage_index = stage_manager.get_stage_index(layers_per_stage) method_replacement = { "forward": partial( - new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config + new_forward, + stage_manager=stage_manager, + stage_index=stage_index, + shard_config=self.shard_config, ) } self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls) @@ -234,7 +259,9 @@ class ChatGLMModelPolicy(ChatGLMPolicy): if self.pipeline_stage_manager is not None: self.set_pipeline_forward( - model_cls=ChatGLMModel, new_forward=ChatGLMPipelineForwards.chatglm_model_forward, policy=policy + model_cls="ChatGLMModel", + new_forward=ChatGLMPipelineForwards.chatglm_model_forward, + policy=policy, ) return policy @@ -252,7 +279,7 @@ class ChatGLMForConditionalGenerationPolicy(ChatGLMModelPolicy): if self.pipeline_stage_manager is not None: self.set_pipeline_forward( - model_cls=ChatGLMForConditionalGeneration, + model_cls="ChatGLMForConditionalGeneration", new_forward=ChatGLMPipelineForwards.chatglm_for_conditional_generation_forward, policy=policy, ) diff --git a/docs/source/en/features/shardformer.md b/docs/source/en/features/shardformer.md index 672945ea2..68d310f5c 100644 --- a/docs/source/en/features/shardformer.md +++ b/docs/source/en/features/shardformer.md @@ -310,13 +310,6 @@ if dist.get_world_size() > 1: 2. When you use Shardformer to process classification models such as `GPT2ForSequenceClassification`, `ViTForImageClassification`, please ensure that the total number of labels should be integer multiple of tensor parallel size, otherwise Shardformer can't process the classifier layer correctly. A simple fix could be appending dummy labels in transformers config. This bug will be fixed in future version of Shardformer. -3. The case of training ChatGLM-2 6B is a little special: since Huggingface transformers doesn't officially support ChatGLM at present, please import the configuration/model classes through - ```python - from colossalai.shardformer.modeling.chatglm2_6b.configuration_chatglm import ChatGLMConfig - from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration, ChatGLMModel - ``` - when training ChatGLM-2 with Shardformer, and initialize your model with these imported classes. - ## How Shardformer Works ### Main Idea diff --git a/docs/source/zh-Hans/features/shardformer.md b/docs/source/zh-Hans/features/shardformer.md index a7bcbd9f2..a42c7cc2e 100644 --- a/docs/source/zh-Hans/features/shardformer.md +++ b/docs/source/zh-Hans/features/shardformer.md @@ -303,13 +303,6 @@ if dist.get_world_size() > 1: 2. 当使用Shardformer处理`GPT2ForSequenceClassification`、`ViTForImageClassification`等分类模型时,请确保labels的总数为张量并行度的整数倍,否则Shardformer无法正确地处理classifier层。一个简单的修复方法就是在transformers的config中添加虚拟的标签。这一bug将在 Shardformer的未来版本中修复。 -3. 训练ChatGLM-2 6B的情况有点特殊:由于Huggingface Transformers 目前尚未正式支持ChatGLM。在使用Shardformer训练ChatGLM-2时,请通过以下方式导入config/model的类: - ```python - from colossalai.shardformer.modeling.chatglm2_6b.configuration_chatglm import ChatGLMConfig - from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration, ChatGLMModel - ``` - 并且使用这些导入的类初始化模型。 - ## Shardformer的工作原理 diff --git a/tests/kit/model_zoo/transformers/chatglm2.py b/tests/kit/model_zoo/transformers/chatglm2.py index 0b178d58c..f443553bb 100644 --- a/tests/kit/model_zoo/transformers/chatglm2.py +++ b/tests/kit/model_zoo/transformers/chatglm2.py @@ -1,7 +1,6 @@ import torch - -from colossalai.shardformer.modeling.chatglm2_6b.configuration_chatglm import ChatGLMConfig -from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration, ChatGLMModel +from torch.nn import init +from transformers import AutoConfig, AutoModelForCausalLM from ..registry import ModelAttribute, model_zoo @@ -34,19 +33,26 @@ loss_fn_for_chatglm_model = lambda x: torch.nn.functional.mse_loss( ) loss_fn = lambda x: x["loss"] -config = ChatGLMConfig( +config = AutoConfig.from_pretrained( + "THUDM/chatglm2-6b", + trust_remote_code=True, num_layers=2, padded_vocab_size=65024, hidden_size=64, + ffn_hidden_size=214, num_attention_heads=8, kv_channels=16, rmsnorm=True, original_rope=True, use_cache=True, + multi_query_attention=False, torch_dtype=torch.float32, ) -infer_config = ChatGLMConfig( + +infer_config = AutoConfig.from_pretrained( + "THUDM/chatglm2-6b", + trust_remote_code=True, num_layers=2, padded_vocab_size=65024, hidden_size=128, @@ -60,18 +66,18 @@ infer_config = ChatGLMConfig( torch_dtype=torch.float32, ) -model_zoo.register( - name="transformers_chatglm", - model_fn=lambda: ChatGLMModel(config, empty_init=False), - data_gen_fn=data_gen, - output_transform_fn=output_transform_fn, - loss_fn=loss_fn_for_chatglm_model, - model_attribute=ModelAttribute(has_control_flow=True), -) + +def init_chatglm(): + model = AutoModelForCausalLM.from_config(config, empty_init=False, trust_remote_code=True) + for m in model.modules(): + if m.__class__.__name__ == "RMSNorm": + init.ones_(m.weight) + return model + model_zoo.register( name="transformers_chatglm_for_conditional_generation", - model_fn=lambda: ChatGLMForConditionalGeneration(config, empty_init=False), + model_fn=init_chatglm, data_gen_fn=data_gen_for_conditional_generation, output_transform_fn=output_transform_fn, loss_fn=loss_fn, diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py index a77ba39a1..1835a5c8e 100644 --- a/tests/test_shardformer/test_model/_utils.py +++ b/tests/test_shardformer/test_model/_utils.py @@ -227,7 +227,7 @@ def check_output_hidden_state( def check_loss(org_loss: Tensor, sharded_loss: Tensor, atol: float = 1e-5, rtol: float = 1e-3): - assert torch.allclose(org_loss.float(), sharded_loss.float(), atol=atol, rtol=rtol) + assert_close(org_loss.float(), sharded_loss.float(), atol=atol, rtol=rtol) def check_weight( diff --git a/tests/test_shardformer/test_model/test_shard_chatglm2.py b/tests/test_shardformer/test_model/test_shard_chatglm2.py index 405ceba32..376d315c1 100644 --- a/tests/test_shardformer/test_model/test_shard_chatglm2.py +++ b/tests/test_shardformer/test_model/test_shard_chatglm2.py @@ -11,6 +11,7 @@ from tests.test_shardformer.test_model._utils import ( build_model_from_hybrid_plugin, check_all_grad_tensors, check_loss, + check_output_hidden_state, check_weight, get_grad_tensors_for_check, run_forward_backward_with_hybrid_plugin, @@ -103,8 +104,8 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, atol, rtol = 5e-3, 5e-3 # TODO: ChatGLMModel output is [S, B, H], merging batch of pipeline is wrong - # if org_model.__class__.__name__ == "ChatGLMModel": - # check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol, dim=1) + if org_model.__class__.__name__ == "ChatGLMModel": + check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol, dim=1) check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) @@ -177,14 +178,14 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, { "tp_size": 4, "pp_size": 1, - "enable_all_optimization": True, + "enable_all_optimization": False, "use_lazy_init": False, "precision": "fp32", }, { "tp_size": 2, "pp_size": 1, - "enable_all_optimization": True, + "enable_all_optimization": False, "use_lazy_init": False, "precision": "fp32", }, From 7ef91606e17cc1e991496c6cc74f73cbd42313ae Mon Sep 17 00:00:00 2001 From: Season Date: Thu, 25 Apr 2024 14:45:52 +0800 Subject: [PATCH 14/28] [Fix]: implement thread-safety singleton to avoid deadlock for very large-scale training scenarios (#5625) * implement thread-safety singleton * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * refactor singleton implementation * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- colossalai/context/singleton_meta.py | 25 +++++++++++++++---------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/colossalai/context/singleton_meta.py b/colossalai/context/singleton_meta.py index 3088b0dff..86a8aa5d8 100644 --- a/colossalai/context/singleton_meta.py +++ b/colossalai/context/singleton_meta.py @@ -1,22 +1,27 @@ +import threading + + class SingletonMeta(type): """ - The Singleton class can be implemented in different ways in Python. Some - possible methods include: base class, decorator, metaclass. We will use the - metaclass because it is best suited for this purpose. + Thread-safe Singleton Meta with double-checked locking. + Reference: https://en.wikipedia.org/wiki/Double-checked_locking """ _instances = {} + _lock = threading.Lock() def __call__(cls, *args, **kwargs): - """ - Possible changes to the value of the `__init__` argument do not affect - the returned instance. - """ + # First check (without locking) for performance reasons if cls not in cls._instances: - instance = super().__call__(*args, **kwargs) - cls._instances[cls] = instance + # Acquire a lock before proceeding to the second check + with cls._lock: + # Second check with lock held to ensure thread safety + if cls not in cls._instances: + instance = super().__call__(*args, **kwargs) + cls._instances[cls] = instance else: assert ( len(args) == 0 and len(kwargs) == 0 - ), f"{cls.__name__} is a singleton class and a instance has been created." + ), f"{cls.__name__} is a singleton class and an instance has been created." + return cls._instances[cls] From 1b387ca9fe2fe7f90459537b0cc19d5bb4edbdc5 Mon Sep 17 00:00:00 2001 From: Hongxin Liu Date: Thu, 25 Apr 2024 15:19:30 +0800 Subject: [PATCH 15/28] [shardformer] refactor pipeline grad ckpt config (#5646) * [shardformer] refactor pipeline grad ckpt config * [shardformer] refactor pipeline grad ckpt config * [pipeline] fix stage manager --- .../booster/plugin/hybrid_parallel_plugin.py | 2 + colossalai/pipeline/stage_manager.py | 82 +++++++------------ colossalai/shardformer/modeling/llama.py | 2 + colossalai/shardformer/modeling/mistral.py | 2 + .../shardformer/policies/base_policy.py | 2 + .../shardformer/shard/grad_ckpt_config.py | 31 ++----- colossalai/shardformer/shard/shard_config.py | 12 +-- examples/language/llama/benchmark.py | 19 ++--- .../test_t5_pipeline_utils.py | 1 + .../test_whisper_pipeline_utils.py | 1 + .../test_model/test_shard_llama.py | 7 +- 11 files changed, 59 insertions(+), 102 deletions(-) diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index 95fb2def1..5237734f0 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -983,6 +983,7 @@ class HybridParallelPlugin(PipelinePluginBase): custom_policy: Policy = None, pp_style: str = "1f1b", num_model_chunks: int = 1, + num_layers_per_stage: Optional[List[int]] = None, gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None, enable_metadata_cache: bool = True, make_vocab_size_divisible_by: int = 64, @@ -1056,6 +1057,7 @@ class HybridParallelPlugin(PipelinePluginBase): pipeline_axis=self.pp_axis, enable_interleave=pp_style == "interleaved", num_model_chunks=num_model_chunks, + num_layers_per_stage=num_layers_per_stage, ) if pp_style == "interleaved": diff --git a/colossalai/pipeline/stage_manager.py b/colossalai/pipeline/stage_manager.py index b0556669b..b7cbd67ab 100644 --- a/colossalai/pipeline/stage_manager.py +++ b/colossalai/pipeline/stage_manager.py @@ -27,16 +27,18 @@ class PipelineStageManager: pipeline_axis: int, enable_interleave: bool = False, num_model_chunks: int = 1, + num_layers_per_stage: Optional[List[int]] = None, ) -> None: assert enable_interleave or num_model_chunks == 1, "num_model_chunks must be 1 when enable_interleave is False" - self.num_layers_per_stage = None - self.pg_mesh = pg_mesh self.pipeline_axis = pipeline_axis self.prev_rank: Optional[Tuple[int, ...]] = None self.next_rank: Optional[Tuple[int, ...]] = None self.p2p_groups: Dict[Tuple[int, int], ProcessGroup] = {} + if num_layers_per_stage is not None: + assert len(num_layers_per_stage) == self.num_stages + self.num_layers_per_stage = num_layers_per_stage # init prev and next coord coord = self.pg_mesh.coordinate() @@ -56,6 +58,8 @@ class PipelineStageManager: self.p2p_groups[tuple(ranks_in_group)] = group self.is_interleave = enable_interleave + # for interleaved pipeline parallel, each device is responsible for multiple chunk of layers + self.num_model_chunks: int = num_model_chunks if enable_interleave: # use circle p2p communication # add the process group of the first rank and the last rank @@ -64,59 +68,11 @@ class PipelineStageManager: ranks_in_group = self.pg_mesh.get_ranks_in_group(group) self.p2p_groups[tuple(ranks_in_group)] = group - # for interleaved pipeline parallel, each device is responsible for multiple chunk of layers - self.num_model_chunks: int = num_model_chunks - # for shardformer, hold stage indices of model self.stage_indices: List[Tuple[int, int]] # for shardformer, hold model chunk id self.model_chunk_id: Optional[int] = None - @property - def control_distribute_layers(self) -> bool: - return self.num_layers_per_stage is not None - - def set_distribution_config(self, num_model_layers: int, num_layers_per_stage: List[int]) -> None: - """Set the distribution configuration. - This allows user to customize the number of layers for each stage. - - Args: - num_model_layers (int): Number of layers in the model. - num_layers_per_stage (List[int]): Number of layers for each stage. - """ - assert all([0 < num_layers < num_model_layers for num_layers in num_layers_per_stage]) - assert sum(num_layers_per_stage) == num_model_layers - assert len(num_layers_per_stage) == self.num_stages * (self.num_model_chunks if self.is_interleave else 1) - self.num_model_layers = num_model_layers - self.num_layers_per_stage = num_layers_per_stage - - def distribute_layers( - self, num_layers: int, num_stages: Optional[int] = None, num_model_chunks: Optional[int] = None - ) -> List[int]: - """Divide layers into stages""" - num_stages = self.num_stages if num_stages is None else num_stages - num_model_chunks = ( - (self.num_model_chunks if self.is_interleave else 1) if num_model_chunks is None else num_model_chunks - ) - - if self.control_distribute_layers: - assert num_layers == self.num_model_layers - return self.num_layers_per_stage - - else: - quotient = num_layers // (num_stages * num_model_chunks) - remainder = num_layers % (num_stages * num_model_chunks) - - # calculate the num_layers per stage - layers_per_stage = [quotient] * num_stages * num_model_chunks - - # deal with the rest layers - if remainder > 0: - start_position = (num_stages * num_model_chunks) // 2 - remainder // 2 - for i in range(start_position, start_position + remainder): - layers_per_stage[i] += 1 - return layers_per_stage - def get_stage_index( self, layers_per_stage: List[int], @@ -139,9 +95,7 @@ class PipelineStageManager: """ stage = self.stage if stage is None else stage - num_model_chunks = ( - (self.num_model_chunks if self.is_interleave else 1) if num_model_chunks is None else num_model_chunks - ) + num_model_chunks = self.num_model_chunks if num_model_chunks is None else num_model_chunks num_stages = self.num_stages if num_stages is None else num_stages num_layers_per_stage_accumulated = np.insert(np.cumsum(layers_per_stage), 0, 0) @@ -261,3 +215,25 @@ class PipelineStageManager: self.model_chunk_id = model_chunk_id yield self.model_chunk_id = old_model_chunk_id + + def distribute_layers( + self, num_layers: int, num_stages: Optional[int] = None, num_model_chunks: Optional[int] = None + ) -> List[int]: + if self.num_layers_per_stage is not None: + assert sum(self.num_layers_per_stage) == num_layers + return self.num_layers_per_stage + + num_stages = self.num_stages if num_stages is None else num_stages + num_model_chunks = self.num_model_chunks if num_model_chunks is None else num_model_chunks + quotient = num_layers // (num_stages * num_model_chunks) + remainder = num_layers % (num_stages * num_model_chunks) + + # calculate the num_layers per stage + layers_per_stage = [quotient] * num_stages * num_model_chunks + + # deal with the rest layers + if remainder > 0: + start_position = (num_stages * num_model_chunks) // 2 - remainder // 2 + for i in range(start_position, start_position + remainder): + layers_per_stage[i] += 1 + return layers_per_stage diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index 0eb08a043..8a6a7cf17 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -168,8 +168,10 @@ class LlamaPipelineForwards: if shard_config.gradient_checkpoint_config is not None: num_ckpt_layers = shard_config.gradient_checkpoint_config.get_num_ckpt_layers( stage=stage_manager.stage, + num_stages=stage_manager.num_stages, num_layers=end_idx - start_idx, model_chunk_id=(stage_manager.model_chunk_id if stage_manager.is_interleave else 0), + num_model_chunks=stage_manager.num_model_chunks, ) assert num_ckpt_layers <= end_idx - start_idx diff --git a/colossalai/shardformer/modeling/mistral.py b/colossalai/shardformer/modeling/mistral.py index ac7845400..d5f00fc9f 100644 --- a/colossalai/shardformer/modeling/mistral.py +++ b/colossalai/shardformer/modeling/mistral.py @@ -129,8 +129,10 @@ class MistralForwards: if shard_config.gradient_checkpoint_config is not None: num_ckpt_layers = shard_config.gradient_checkpoint_config.get_num_ckpt_layers( stage=stage_manager.stage, + num_stages=stage_manager.num_stages, num_layers=end_idx - start_idx, model_chunk_id=(stage_manager.model_chunk_id if stage_manager.is_interleave else 0), + num_model_chunks=stage_manager.num_model_chunks, ) assert num_ckpt_layers <= end_idx - start_idx diff --git a/colossalai/shardformer/policies/base_policy.py b/colossalai/shardformer/policies/base_policy.py index e976672bb..282cf0464 100644 --- a/colossalai/shardformer/policies/base_policy.py +++ b/colossalai/shardformer/policies/base_policy.py @@ -28,6 +28,7 @@ class SubModuleReplacementDescription: kwargs (Dict[str, Any]): the dictionary used to pass extra arguments to the `ParallelModule.from_native_module` method. ignore_if_not_exist (bool): if the submodule does not exist, ignore it or raise an exception """ + suffix: str target_module: Union[ParallelModule, BaseLayerNorm] kwargs: Dict[str, Any] = None @@ -54,6 +55,7 @@ class ModulePolicyDescription: object which specifies the module to be replaced and the target module used to replacement. method_replace (Dict[str, Callable]): key is the method name, value is the method for replacement """ + attribute_replacement: Dict[str, Any] = None param_replacement: List[Callable] = None sub_module_replacement: List[SubModuleReplacementDescription] = None diff --git a/colossalai/shardformer/shard/grad_ckpt_config.py b/colossalai/shardformer/shard/grad_ckpt_config.py index 9fc857d19..9167da795 100644 --- a/colossalai/shardformer/shard/grad_ckpt_config.py +++ b/colossalai/shardformer/shard/grad_ckpt_config.py @@ -47,46 +47,33 @@ class PipelineGradientCheckpointConfig(GradientCheckpointConfig): ... """ - num_stages: Optional[int] = None - num_model_chunks: Optional[int] = None - num_model_layers: Optional[int] = None - num_layers_per_stage: Optional[List[int]] = None num_ckpt_layers_per_stage: Optional[List[int]] = None def __post_init__(self): - if self._enable_gradient_checkpointing_ratio: + if self._enable_customized_ckpt_layers_per_stage: + assert all([num_ckpt_layers >= 0 for num_ckpt_layers in self.num_ckpt_layers_per_stage]) + elif self._enable_gradient_checkpointing_ratio: if not (0 <= self.gradient_checkpointing_ratio <= 1): raise ValueError("gradient_checkpointing_ratio should be in 0% to 100%") - if self._enable_customized_ckpt_layers_per_stage: - assert ( - self.num_stages is not None and self.num_model_chunks is not None and self.num_model_layers is not None - ) - assert len(self.num_ckpt_layers_per_stage) == self.num_stages * self.num_model_chunks - assert all( - [0 <= num_ckpt_layers < self.num_model_layers for num_ckpt_layers in self.num_ckpt_layers_per_stage] - ) - self.gradient_checkpointing_ratio = sum(self.num_ckpt_layers_per_stage) / self.num_model_layers - @property def _enable_gradient_checkpointing_ratio(self) -> bool: return self.gradient_checkpointing_ratio is not None - @property - def _customize_num_layers_per_stage(self) -> bool: - return self.num_layers_per_stage is not None and self.num_model_layers is not None - @property def _enable_customized_ckpt_layers_per_stage(self) -> bool: return self.num_ckpt_layers_per_stage is not None - def get_num_ckpt_layers(self, stage: int, num_layers: int, model_chunk_id: int = 0) -> int: + def get_num_ckpt_layers( + self, stage: int, num_stages: int, num_layers: int, model_chunk_id: int = 0, num_model_chunks: int = 1 + ) -> int: if not self._enable_gradient_checkpointing_ratio and not self._enable_customized_ckpt_layers_per_stage: raise RuntimeError("No checkpointed layers information is provided") if self._enable_customized_ckpt_layers_per_stage: - assert stage <= self.num_stages and model_chunk_id <= self.num_model_chunks - num_ckpt_layers = self.num_ckpt_layers_per_stage[stage + model_chunk_id * self.num_stages] + assert len(self.num_ckpt_layers_per_stage) == num_stages * num_model_chunks + assert stage <= num_stages and model_chunk_id <= num_model_chunks + num_ckpt_layers = self.num_ckpt_layers_per_stage[stage + model_chunk_id * num_stages] assert num_ckpt_layers <= num_layers return num_ckpt_layers else: diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py index e20b8e239..98e72d8b3 100644 --- a/colossalai/shardformer/shard/shard_config.py +++ b/colossalai/shardformer/shard/shard_config.py @@ -7,7 +7,7 @@ from torch.distributed import ProcessGroup from colossalai.pipeline.stage_manager import PipelineStageManager -from .grad_ckpt_config import GradientCheckpointConfig, PipelineGradientCheckpointConfig +from .grad_ckpt_config import GradientCheckpointConfig __all__ = ["ShardConfig"] SUPPORT_SP_MODE = ["split_gather", "ring", "all_to_all"] @@ -105,16 +105,6 @@ class ShardConfig: else: self._sequence_parallel_size = dist.get_world_size(self.sequence_parallel_process_group) - if ( - self.pipeline_stage_manager is not None - and isinstance(self.gradient_checkpoint_config, PipelineGradientCheckpointConfig) - and self.gradient_checkpoint_config._customize_num_layers_per_stage - ): - self.pipeline_stage_manager.set_distribution_config( - self.gradient_checkpoint_config.num_model_layers, - self.gradient_checkpoint_config.num_layers_per_stage, - ) - def _turn_on_all_optimization(self): """ Turn on all optimization. diff --git a/examples/language/llama/benchmark.py b/examples/language/llama/benchmark.py index ff94891f5..d26975fc5 100644 --- a/examples/language/llama/benchmark.py +++ b/examples/language/llama/benchmark.py @@ -88,16 +88,15 @@ def main(): pass # ckpt config for LLaMA3-70B on 64 H100 GPUs - ckpt_config = ( - PipelineGradientCheckpointConfig( - num_stages=args.pp, - num_model_chunks=1, - num_model_layers=80, - num_layers_per_stage=[19, 20, 20, 21], - num_ckpt_layers_per_stage=[19, 19, 19, 13], - ) + hybrid_kwargs = ( + { + "gradient_checkpoint_config": PipelineGradientCheckpointConfig( + num_ckpt_layers_per_stage=[19, 19, 19, 13], + ), + "num_layers_per_stage": [19, 20, 20, 21], + } if args.custom_ckpt - else None + else {} ) # ============================== @@ -173,7 +172,7 @@ def main(): microbatch_size=args.mbs, precision="bf16", dp_outside=False, - gradient_checkpoint_config=ckpt_config, + **hybrid_kwargs, ) elif args.plugin == "3d_cpu": plugin = HybridParallelPlugin( diff --git a/tests/test_pipeline/test_pipeline_utils/test_t5_pipeline_utils.py b/tests/test_pipeline/test_pipeline_utils/test_t5_pipeline_utils.py index 1b7b0073f..e2f71ff89 100644 --- a/tests/test_pipeline/test_pipeline_utils/test_t5_pipeline_utils.py +++ b/tests/test_pipeline/test_pipeline_utils/test_t5_pipeline_utils.py @@ -14,6 +14,7 @@ class _PipelineStageManager(PipelineStageManager): def __init__(self): self.is_interleave = False self.num_layers_per_stage = None + self.num_model_chunks = 1 @property def num_stages(self): diff --git a/tests/test_pipeline/test_pipeline_utils/test_whisper_pipeline_utils.py b/tests/test_pipeline/test_pipeline_utils/test_whisper_pipeline_utils.py index 9f8c1ad32..d39c5ea91 100644 --- a/tests/test_pipeline/test_pipeline_utils/test_whisper_pipeline_utils.py +++ b/tests/test_pipeline/test_pipeline_utils/test_whisper_pipeline_utils.py @@ -14,6 +14,7 @@ class _PipelineStageManager(PipelineStageManager): def __init__(self): self.is_interleave = False self.num_layers_per_stage = None + self.num_model_chunks = 1 @property def num_stages(self): diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index 2a10d86c7..394592688 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -217,9 +217,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "use_lazy_init": False, "precision": "fp32", "enable_gradient_checkpointing": True, - "gradient_checkpoint_config": PipelineGradientCheckpointConfig( - num_stages=2, num_model_chunks=1, num_model_layers=8, num_ckpt_layers_per_stage=[4, 0] - ), + "gradient_checkpoint_config": PipelineGradientCheckpointConfig(num_ckpt_layers_per_stage=[4, 0]), }, { "tp_size": 4, @@ -303,9 +301,6 @@ def run_llama_test(test_config): "initial_scale": 1, "enable_gradient_checkpointing": True, "gradient_checkpoint_config": PipelineGradientCheckpointConfig( - num_stages=2, - num_model_chunks=2, - num_model_layers=8, num_ckpt_layers_per_stage=[0, 1, 2, 2], ), }, From 8b7d535977bf5b243741a8cdeb437cfdaf16c15e Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Fri, 26 Apr 2024 11:52:27 +0800 Subject: [PATCH 16/28] fix gptj (#5652) --- colossalai/shardformer/policies/gptj.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/colossalai/shardformer/policies/gptj.py b/colossalai/shardformer/policies/gptj.py index 1280efaec..25e5b66dc 100644 --- a/colossalai/shardformer/policies/gptj.py +++ b/colossalai/shardformer/policies/gptj.py @@ -54,7 +54,6 @@ class GPTJPolicy(Policy): if self.shard_config.enable_sequence_parallelism: self.shard_config.enable_sequence_parallelism = False warnings.warn("GPTJ doesn't support sequence parallelism now, will ignore the sequence parallelism flag.") - use_sequence_parallel = self.shard_config.enable_sequence_parallelism overlap = self.shard_config.enable_sequence_overlap if self.shard_config.enable_tensor_parallelism: @@ -78,7 +77,6 @@ class GPTJPolicy(Policy): suffix="attn.k_proj", target_module=col_nn.Linear1D_Col, kwargs={ - "seq_parallel": use_sequence_parallel, "overlap": overlap, }, ), @@ -86,7 +84,6 @@ class GPTJPolicy(Policy): suffix="attn.q_proj", target_module=col_nn.Linear1D_Col, kwargs={ - "seq_parallel": use_sequence_parallel, "overlap": overlap, }, ), @@ -94,24 +91,20 @@ class GPTJPolicy(Policy): suffix="attn.v_proj", target_module=col_nn.Linear1D_Col, kwargs={ - "seq_parallel": use_sequence_parallel, "overlap": overlap, }, ), SubModuleReplacementDescription( suffix="attn.out_proj", target_module=col_nn.Linear1D_Row, - kwargs={"seq_parallel": use_sequence_parallel}, ), SubModuleReplacementDescription( suffix="mlp.fc_in", target_module=col_nn.Linear1D_Col, - kwargs={"seq_parallel": use_sequence_parallel}, ), SubModuleReplacementDescription( suffix="mlp.fc_out", target_module=col_nn.Linear1D_Row, - kwargs={"seq_parallel": use_sequence_parallel}, ), SubModuleReplacementDescription( suffix="attn.attn_dropout", From 2082852f3f14742013fbff18affd4cff3ccfa2b3 Mon Sep 17 00:00:00 2001 From: Hongxin Liu Date: Fri, 26 Apr 2024 14:03:12 +0800 Subject: [PATCH 17/28] [lazyinit] skip whisper test (#5653) --- tests/test_lazy/test_models.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/tests/test_lazy/test_models.py b/tests/test_lazy/test_models.py index d0c4cd0a7..c85860a8d 100644 --- a/tests/test_lazy/test_models.py +++ b/tests/test_lazy/test_models.py @@ -7,21 +7,23 @@ from tests.kit.model_zoo import COMMON_MODELS, IS_FAST_TEST, model_zoo @pytest.mark.skipif(not SUPPORT_LAZY, reason="requires torch >= 1.12.0") @pytest.mark.parametrize( "subset", - [COMMON_MODELS] - if IS_FAST_TEST - else ["torchvision", "diffusers", "timm", "transformers", "torchaudio", "deepfm", "dlrm"], + ( + [COMMON_MODELS] + if IS_FAST_TEST + else ["torchvision", "diffusers", "timm", "transformers", "torchaudio", "deepfm", "dlrm"] + ), ) @pytest.mark.parametrize("default_device", ["cpu", "cuda"]) -def test_torchvision_models_lazy_init(subset, default_device): +def test_models_lazy_init(subset, default_device): sub_model_zoo = model_zoo.get_sub_registry(subset, allow_empty=True) for name, entry in sub_model_zoo.items(): # TODO(ver217): lazy init does not support weight norm, skip these models if name in ("torchaudio_wav2vec2_base", "torchaudio_hubert_base") or name.startswith( - ("transformers_vit", "transformers_blip2") + ("transformers_vit", "transformers_blip2", "transformers_whisper") ): continue check_lazy_init(entry, verbose=True, default_device=default_device) if __name__ == "__main__": - test_torchvision_models_lazy_init("transformers", "cpu") + test_models_lazy_init("transformers", "cpu") From b8a711aa2df86450c980f4b647199df2375dce33 Mon Sep 17 00:00:00 2001 From: binmakeswell Date: Fri, 26 Apr 2024 15:36:37 +0800 Subject: [PATCH 18/28] [news] llama3 and open-sora v1.1 (#5655) * [news] llama3 and open-sora v1.1 * [news] llama3 and open-sora v1.1 --- README.md | 4 +++- docs/README-zh-Hans.md | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index c1e2da0d4..9e215df63 100644 --- a/README.md +++ b/README.md @@ -25,6 +25,8 @@ ## Latest News +* [2024/04] [Open-Sora Unveils Major Upgrade: Embracing Open Source with Single-Shot 16-Second Video Generation and 720p Resolution](https://hpc-ai.com/blog/open-soras-comprehensive-upgrade-unveiled-embracing-16-second-video-generation-and-720p-resolution-in-open-source) +* [2024/04] [Most cost-effective solutions for inference, fine-tuning and pretraining, tailored to LLaMA3 series](https://hpc-ai.com/blog/most-cost-effective-solutions-for-inference-fine-tuning-and-pretraining-tailored-to-llama3-series) * [2024/03] [314 Billion Parameter Grok-1 Inference Accelerated by 3.8x, Efficient and Easy-to-Use PyTorch+HuggingFace version is Here](https://hpc-ai.com/blog/314-billion-parameter-grok-1-inference-accelerated-by-3.8x-efficient-and-easy-to-use-pytorchhuggingface-version-is-here) * [2024/03] [Open-Sora: Revealing Complete Model Parameters, Training Details, and Everything for Sora-like Video Generation Models](https://hpc-ai.com/blog/open-sora-v1.0) * [2024/03] [Open-Sora:Sora Replication Solution with 46% Cost Reduction, Sequence Expansion to Nearly a Million](https://hpc-ai.com/blog/open-sora) @@ -131,7 +133,7 @@ distributed training and inference in a few lines. [Open-Sora](https://github.com/hpcaitech/Open-Sora):Revealing Complete Model Parameters, Training Details, and Everything for Sora-like Video Generation Models [[code]](https://github.com/hpcaitech/Open-Sora) -[[blog]](https://hpc-ai.com/blog/open-sora-v1.0) +[[blog]](https://hpc-ai.com/blog/open-soras-comprehensive-upgrade-unveiled-embracing-16-second-video-generation-and-720p-resolution-in-open-source) [[HuggingFace model weights]](https://huggingface.co/hpcai-tech/Open-Sora) [[Demo]](https://github.com/hpcaitech/Open-Sora?tab=readme-ov-file#-latest-demo) diff --git a/docs/README-zh-Hans.md b/docs/README-zh-Hans.md index 7e0ed07fe..2e5437752 100644 --- a/docs/README-zh-Hans.md +++ b/docs/README-zh-Hans.md @@ -24,6 +24,8 @@ ## 新闻 +* [2024/04] [Open-Sora Unveils Major Upgrade: Embracing Open Source with Single-Shot 16-Second Video Generation and 720p Resolution](https://hpc-ai.com/blog/open-soras-comprehensive-upgrade-unveiled-embracing-16-second-video-generation-and-720p-resolution-in-open-source) +* [2024/04] [Most cost-effective solutions for inference, fine-tuning and pretraining, tailored to LLaMA3 series](https://hpc-ai.com/blog/most-cost-effective-solutions-for-inference-fine-tuning-and-pretraining-tailored-to-llama3-series) * [2024/03] [314 Billion Parameter Grok-1 Inference Accelerated by 3.8x, Efficient and Easy-to-Use PyTorch+HuggingFace version is Here](https://hpc-ai.com/blog/314-billion-parameter-grok-1-inference-accelerated-by-3.8x-efficient-and-easy-to-use-pytorchhuggingface-version-is-here) * [2024/03] [Open-Sora: Revealing Complete Model Parameters, Training Details, and Everything for Sora-like Video Generation Models](https://hpc-ai.com/blog/open-sora-v1.0) * [2024/03] [Open-Sora:Sora Replication Solution with 46% Cost Reduction, Sequence Expansion to Nearly a Million](https://hpc-ai.com/blog/open-sora) @@ -126,7 +128,7 @@ Colossal-AI 为您提供了一系列并行组件。我们的目标是让您的 [Open-Sora](https://github.com/hpcaitech/Open-Sora):全面开源类Sora模型参数和所有训练细节 [[代码]](https://github.com/hpcaitech/Open-Sora) -[[博客]](https://hpc-ai.com/blog/open-sora-v1.0) +[[博客]](https://hpc-ai.com/blog/open-soras-comprehensive-upgrade-unveiled-embracing-16-second-video-generation-and-720p-resolution-in-open-source) [[模型权重]](https://huggingface.co/hpcai-tech/Open-Sora) [[演示样例]](https://github.com/hpcaitech/Open-Sora?tab=readme-ov-file#-latest-demo) From 68ec99e946129298b2e6d8e6463886fe6b22a5df Mon Sep 17 00:00:00 2001 From: Tong Li Date: Fri, 26 Apr 2024 21:12:04 +0800 Subject: [PATCH 19/28] [hotfix] add soft link to support required files (#5661) --- examples/language/llama/benchmark.py | 3 --- examples/language/llama/data_utils.py | 1 + examples/language/llama/model_utils.py | 1 + examples/language/llama/performance_evaluator.py | 1 + 4 files changed, 3 insertions(+), 3 deletions(-) create mode 120000 examples/language/llama/data_utils.py create mode 120000 examples/language/llama/model_utils.py create mode 120000 examples/language/llama/performance_evaluator.py diff --git a/examples/language/llama/benchmark.py b/examples/language/llama/benchmark.py index d26975fc5..f457c08cd 100644 --- a/examples/language/llama/benchmark.py +++ b/examples/language/llama/benchmark.py @@ -19,9 +19,6 @@ from colossalai.cluster import DistCoordinator from colossalai.lazy import LazyInitContext from colossalai.nn.optimizer import HybridAdam from colossalai.shardformer import PipelineGradientCheckpointConfig -from examples.language.data_utils import RandomDataset -from examples.language.model_utils import format_numel_str, get_model_numel -from examples.language.performance_evaluator import PerformanceEvaluator # ============================== # Constants diff --git a/examples/language/llama/data_utils.py b/examples/language/llama/data_utils.py new file mode 120000 index 000000000..2da9822df --- /dev/null +++ b/examples/language/llama/data_utils.py @@ -0,0 +1 @@ +../data_utils.py \ No newline at end of file diff --git a/examples/language/llama/model_utils.py b/examples/language/llama/model_utils.py new file mode 120000 index 000000000..73c6818a8 --- /dev/null +++ b/examples/language/llama/model_utils.py @@ -0,0 +1 @@ +../model_utils.py \ No newline at end of file diff --git a/examples/language/llama/performance_evaluator.py b/examples/language/llama/performance_evaluator.py new file mode 120000 index 000000000..f4736354b --- /dev/null +++ b/examples/language/llama/performance_evaluator.py @@ -0,0 +1 @@ +../performance_evaluator.py \ No newline at end of file From 4cfbf30a5e0d960ab31ebfda432986e15992fd36 Mon Sep 17 00:00:00 2001 From: Hongxin Liu Date: Sat, 27 Apr 2024 18:59:47 +0800 Subject: [PATCH 20/28] [release] update version (#5654) --- version.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/version.txt b/version.txt index 449d7e73a..0f8268533 100644 --- a/version.txt +++ b/version.txt @@ -1 +1 @@ -0.3.6 +0.3.7 From c1594e4bad5056d5500b7dbf1218241bb7e8eb84 Mon Sep 17 00:00:00 2001 From: Hongxin Liu Date: Sat, 27 Apr 2024 19:11:57 +0800 Subject: [PATCH 21/28] [devops] fix release docker ci (#5665) --- .github/workflows/release_docker_after_publish.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/release_docker_after_publish.yml b/.github/workflows/release_docker_after_publish.yml index 6c8df9730..0792544bf 100644 --- a/.github/workflows/release_docker_after_publish.yml +++ b/.github/workflows/release_docker_after_publish.yml @@ -24,7 +24,7 @@ jobs: version=$(cat version.txt) tag=hpcaitech/colossalai:$version latest=hpcaitech/colossalai:latest - docker build --build-arg http_proxy=http://172.17.0.1:7890 --build-arg https_proxy=http://172.17.0.1:7890 --build-arg VERSION=v${version} -t $tag ./docker + docker build --build-arg VERSION=v${version} -t $tag ./docker docker tag $tag $latest echo "tag=${tag}" >> $GITHUB_OUTPUT echo "latest=${latest}" >> $GITHUB_OUTPUT From 14b0d4c7e5340b475d75319a43bbdb77b7fcc7a5 Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Tue, 31 Oct 2023 15:19:37 +0800 Subject: [PATCH 22/28] [lora] add lora APIs for booster, support lora for TorchDDP (#4981) * add apis and peft requirement * add liscense and implement apis * add checkpointio apis * add torchddp fwd_bwd test * add support_lora methods * add checkpointio test and debug * delete unneeded codes * remove peft from LICENSE * add concrete methods for enable_lora * simplify enable_lora api * fix requirements --- colossalai/booster/booster.py | 57 +++++++++ colossalai/booster/plugin/gemini_plugin.py | 10 +- .../booster/plugin/hybrid_parallel_plugin.py | 10 +- .../booster/plugin/low_level_zero_plugin.py | 10 +- colossalai/booster/plugin/plugin_base.py | 12 +- colossalai/booster/plugin/torch_ddp_plugin.py | 32 +++++- .../booster/plugin/torch_fsdp_plugin.py | 10 +- .../checkpoint_io/checkpoint_io_base.py | 17 +++ .../checkpoint_io/general_checkpoint_io.py | 3 + requirements/requirements-test.txt | 3 +- tests/test_lora/test_torch_ddp_lora.py | 108 ++++++++++++++++++ 11 files changed, 265 insertions(+), 7 deletions(-) create mode 100644 tests/test_lora/test_torch_ddp_lora.py diff --git a/colossalai/booster/booster.py b/colossalai/booster/booster.py index d73bc5bab..c2a724084 100644 --- a/colossalai/booster/booster.py +++ b/colossalai/booster/booster.py @@ -8,6 +8,14 @@ from torch.optim import Optimizer from torch.optim.lr_scheduler import _LRScheduler as LRScheduler from torch.utils.data import DataLoader +SUPPORT_PEFT = False +try: + import peft + + SUPPORT_PEFT = True +except ImportError: + pass + import colossalai.interface.pretrained as pretrained_utils from colossalai.checkpoint_io import GeneralCheckpointIO from colossalai.interface import ModelWrapper, OptimizerWrapper @@ -221,6 +229,38 @@ class Booster: assert self.plugin.support_no_sync(), f"The plugin {self.plugin.__class__.__name__} does not support no_sync." return self.plugin.no_sync(model, optimizer) + def enable_lora( + self, model: nn.Module, pretrained_dir: Optional[str] = None, lora_config: "peft.LoraConfig" = None + ) -> nn.Module: + """ + Wrap the passed in model with LoRA modules for training. If pretrained directory is provided, lora configs and weights are loaded from that directory. + Lora in ColossalAI is implemented using Huggingface peft library, so the arguments for Lora configuration are same as those of peft. + + Args: + model (nn.Module): The model to be appended with LoRA modules. + pretrained_dir(str, optional): The path to the pretrained directory, can be a local directory + or model_id of a PEFT configuration hosted inside a model repo on the Hugging Face Hub. + When set to None, create new lora configs and weights for the model using the passed in lora_config. Defaults to None. + lora_config: (peft.LoraConfig, optional): Passed in LoraConfig for peft. Defaults to None. + """ + if not SUPPORT_PEFT: + raise ImportError("Please install Huggingface Peft library to enable lora features in ColossalAI!") + + assert self.plugin is not None, f"Lora can only be enabled when a plugin is provided." + assert self.plugin.support_lora(), f"The plugin {self.plugin.__class__.__name__} does not support lora." + if pretrained_dir is None: + assert ( + lora_config is not None + ), "Please provide configuration for Lora when pretrained directory path isn't passed in." + assert isinstance( + lora_config, peft.LoraConfig + ), "The passed in configuration should be an instance of peft.LoraConfig." + if lora_config is None: + assert ( + pretrained_dir is not None + ), "Please provide pretrained directory path if not passing in lora configuration." + return self.plugin.enable_lora(model, pretrained_dir, lora_config) + def load_model(self, model: Union[nn.Module, ModelWrapper], checkpoint: str, strict: bool = True) -> None: """Load model from checkpoint. @@ -323,3 +363,20 @@ class Booster: checkpoint (str): Path to the checkpoint. It must be a local file path. """ self.checkpoint_io.load_lr_scheduler(lr_scheduler, checkpoint) + + def save_lora_as_pretrained( + self, model: Union[nn.Module, ModelWrapper], checkpoint: str, use_safetensors: bool = False + ) -> None: + """ + Save the lora adapters and adapter configuration file to a pretrained checkpoint directory. + + Args: + model (Union[nn.Module, ModelWrapper]): A model boosted by Booster. + checkpoint (str): Path to the checkpoint directory. It must be a local path. + use_safetensors (bool, optional): Whether to use safe tensors when saving. Defaults to False. + """ + if not SUPPORT_PEFT: + raise ImportError("Please install Huggingface Peft library to enable lora features in ColossalAI!") + assert self.plugin is not None, f"Lora can only be enabled when a plugin is provided." + assert self.plugin.support_lora(), f"The plugin {self.plugin.__class__.__name__} does not support lora." + self.checkpoint_io.save_lora_as_pretrained(model, checkpoint, use_safetensors) diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py index a67ca18a3..964cd302a 100644 --- a/colossalai/booster/plugin/gemini_plugin.py +++ b/colossalai/booster/plugin/gemini_plugin.py @@ -3,7 +3,7 @@ import logging import os import random from pathlib import Path -from typing import Callable, Iterator, List, Optional, Tuple +from typing import Callable, Dict, Iterator, List, Optional, Tuple import numpy as np import torch @@ -444,6 +444,9 @@ class GeminiPlugin(DPPluginBase): def support_no_sync(self) -> bool: return False + def support_lora(self) -> bool: + return False + def control_precision(self) -> bool: return True @@ -573,3 +576,8 @@ class GeminiPlugin(DPPluginBase): def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]: raise NotImplementedError + + def enable_lora( + self, model: nn.Module, pretrained_dir: Optional[str] = None, lora_config: Optional[Dict] = None + ) -> nn.Module: + raise NotImplementedError diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index 5237734f0..97057481e 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -4,7 +4,7 @@ import warnings from contextlib import contextmanager from functools import partial from types import MethodType -from typing import Any, Callable, Iterator, List, Optional, OrderedDict, Tuple, Union +from typing import Any, Callable, Dict, Iterator, List, Optional, OrderedDict, Tuple, Union import numpy as np import torch @@ -1156,6 +1156,9 @@ class HybridParallelPlugin(PipelinePluginBase): def support_no_sync(self) -> bool: return True + def support_lora(self) -> bool: + return False + def control_checkpoint_io(self) -> bool: return True @@ -1356,3 +1359,8 @@ class HybridParallelPlugin(PipelinePluginBase): self.zero_stage != 2 ), "ZERO2 is not compatible with no_sync function, please run gradient accumulation with gradient synchronization allowed." return optimizer.no_sync() if isinstance(optimizer, HybridParallelZeroOptimizer) else model.no_sync() + + def enable_lora( + self, model: Module, pretrained_dir: Optional[str] = None, lora_config: Optional[Dict] = None + ) -> Module: + raise NotImplementedError diff --git a/colossalai/booster/plugin/low_level_zero_plugin.py b/colossalai/booster/plugin/low_level_zero_plugin.py index d21496f0b..243051895 100644 --- a/colossalai/booster/plugin/low_level_zero_plugin.py +++ b/colossalai/booster/plugin/low_level_zero_plugin.py @@ -3,7 +3,7 @@ import os from functools import partial from pathlib import Path from types import MethodType -from typing import Callable, Iterator, List, Optional, Tuple +from typing import Callable, Dict, Iterator, List, Optional, Tuple import torch import torch.nn as nn @@ -296,6 +296,9 @@ class LowLevelZeroPlugin(DPPluginBase): def support_no_sync(self) -> bool: return self.stage == 1 + def support_lora(self) -> bool: + return False + def control_precision(self) -> bool: return True @@ -337,3 +340,8 @@ class LowLevelZeroPlugin(DPPluginBase): def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]: assert isinstance(optimizer, LowLevelZeroOptimizer) return optimizer.no_sync() + + def enable_lora( + self, model: nn.Module, pretrained_dir: Optional[str] = None, lora_config: Optional[Dict] = None + ) -> nn.Module: + raise NotImplementedError diff --git a/colossalai/booster/plugin/plugin_base.py b/colossalai/booster/plugin/plugin_base.py index 4e570cbe8..6dc0c560d 100644 --- a/colossalai/booster/plugin/plugin_base.py +++ b/colossalai/booster/plugin/plugin_base.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Callable, Iterator, List, Optional, Tuple +from typing import Callable, Dict, Iterator, List, Optional, Tuple import torch.nn as nn from torch.optim import Optimizer @@ -33,6 +33,10 @@ class Plugin(ABC): def support_no_sync(self) -> bool: pass + @abstractmethod + def support_lora(self) -> bool: + pass + @abstractmethod def configure( self, @@ -63,6 +67,12 @@ class Plugin(ABC): Context manager to disable gradient synchronization. """ + @abstractmethod + def enable_lora(self, model: nn.Module, pretrained_dir: str, lora_config: Dict) -> nn.Module: + """ + Add LoRA modules to the model passed in. Should only be called in booster.enable_lora(). + """ + @abstractmethod def prepare_dataloader( self, diff --git a/colossalai/booster/plugin/torch_ddp_plugin.py b/colossalai/booster/plugin/torch_ddp_plugin.py index 738634473..9ba520de2 100644 --- a/colossalai/booster/plugin/torch_ddp_plugin.py +++ b/colossalai/booster/plugin/torch_ddp_plugin.py @@ -1,4 +1,4 @@ -from typing import Callable, Iterator, List, Optional, Tuple +from typing import Callable, Dict, Iterator, List, Optional, Tuple, Union import torch.nn as nn from torch.nn.parallel import DistributedDataParallel as DDP @@ -116,6 +116,22 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO): assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!" super().load_sharded_optimizer(optimizer.unwrap(), index_file_path, prefix) + def save_lora_as_pretrained( + self, model: Union[nn.Module, ModelWrapper], checkpoint: str, use_safetensors: bool = False + ) -> None: + """ + Save the lora adapters and adapter configuration file to checkpoint directory. + """ + from peft import PeftModel + + assert isinstance(model, ModelWrapper), "Please boost the model before saving!" + if self.coordinator.is_master(): + peft_model = model.unwrap() + assert isinstance( + peft_model, PeftModel + ), "The model doesn't have lora adapters, please enable lora before saving." + peft_model.save_pretrained(save_directory=checkpoint, safe_serialization=use_safetensors) + class TorchDDPModel(ModelWrapper): def __init__(self, module: nn.Module, *args, **kwargs) -> None: @@ -173,6 +189,9 @@ class TorchDDPPlugin(DPPluginBase): def support_no_sync(self) -> bool: return True + def support_lora(self) -> bool: + return True + def control_precision(self) -> bool: return False @@ -216,3 +235,14 @@ class TorchDDPPlugin(DPPluginBase): def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]: assert isinstance(model, TorchDDPModel), "Model is not boosted by TorchDDPPlugin." return model.module.no_sync() + + def enable_lora( + self, model: nn.Module, pretrained_dir: Optional[str] = None, lora_config: Optional[Dict] = None + ) -> nn.Module: + from peft import PeftModel, get_peft_model + + assert not isinstance(model, TorchDDPModel), "Lora should be enabled before boosting the model." + if pretrained_dir is None: + return get_peft_model(model, lora_config) + else: + return PeftModel.from_pretrained(model, pretrained_dir, is_trainable=True) diff --git a/colossalai/booster/plugin/torch_fsdp_plugin.py b/colossalai/booster/plugin/torch_fsdp_plugin.py index 0aa0caa9a..cd2f9e840 100644 --- a/colossalai/booster/plugin/torch_fsdp_plugin.py +++ b/colossalai/booster/plugin/torch_fsdp_plugin.py @@ -2,7 +2,7 @@ import logging import os import warnings from pathlib import Path -from typing import Callable, Iterable, Iterator, List, Optional, Tuple +from typing import Callable, Dict, Iterable, Iterator, List, Optional, Tuple import torch import torch.nn as nn @@ -318,6 +318,9 @@ class TorchFSDPPlugin(DPPluginBase): def support_no_sync(self) -> bool: return False + def support_lora(self) -> bool: + return False + def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]: raise NotImplementedError("Torch fsdp no_sync func not supported yet.") @@ -361,3 +364,8 @@ class TorchFSDPPlugin(DPPluginBase): def get_checkpoint_io(self) -> CheckpointIO: return TorchFSDPCheckpointIO() + + def enable_lora( + self, model: nn.Module, pretrained_dir: Optional[str] = None, lora_config: Optional[Dict] = None + ) -> nn.Module: + raise NotImplementedError diff --git a/colossalai/checkpoint_io/checkpoint_io_base.py b/colossalai/checkpoint_io/checkpoint_io_base.py index 712324215..949ba4d44 100644 --- a/colossalai/checkpoint_io/checkpoint_io_base.py +++ b/colossalai/checkpoint_io/checkpoint_io_base.py @@ -335,3 +335,20 @@ class CheckpointIO(ABC): """ state_dict = torch.load(checkpoint) lr_scheduler.load_state_dict(state_dict) + + # ================================================================================ + # Abstract method for lora saving implementation. + # ================================================================================ + + @abstractmethod + def save_lora_as_pretrained( + self, model: Union[nn.Module, ModelWrapper], checkpoint: str, use_safetensors: bool = False + ) -> None: + """ + Save the lora adapters and adapter configuration file to a pretrained checkpoint directory. + + Args: + model (Union[nn.Module, ModelWrapper]): A model boosted by Booster. + checkpoint (str): Path to the checkpoint directory. It must be a local path. + use_safetensors (bool, optional): Whether to use safe tensors when saving. Defaults to False. + """ diff --git a/colossalai/checkpoint_io/general_checkpoint_io.py b/colossalai/checkpoint_io/general_checkpoint_io.py index a652d9b45..b9253a56d 100644 --- a/colossalai/checkpoint_io/general_checkpoint_io.py +++ b/colossalai/checkpoint_io/general_checkpoint_io.py @@ -228,3 +228,6 @@ class GeneralCheckpointIO(CheckpointIO): self.__class__.__name__, "\n\t".join(error_msgs) ) ) + + def save_lora_as_pretrained(self, model: nn.Module, checkpoint: str, use_safetensors: bool = False) -> None: + raise NotImplementedError diff --git a/requirements/requirements-test.txt b/requirements/requirements-test.txt index 0b15b9311..de7fe8a21 100644 --- a/requirements/requirements-test.txt +++ b/requirements/requirements-test.txt @@ -5,7 +5,7 @@ git+https://github.com/hpcaitech/pytest-testmon torchvision timm titans -torchaudio +torchaudio>=0.13.1 torchx-nightly==2022.6.29 # torchrec 0.2.0 requires torchx-nightly. This package is updated every day. We fix the version to a specific date to avoid breaking changes. torchrec==0.2.0 contexttimer @@ -18,4 +18,5 @@ flash_attn datasets pydantic ray +peft #auto-gptq now not support torch1.12 diff --git a/tests/test_lora/test_torch_ddp_lora.py b/tests/test_lora/test_torch_ddp_lora.py new file mode 100644 index 000000000..b3169bf86 --- /dev/null +++ b/tests/test_lora/test_torch_ddp_lora.py @@ -0,0 +1,108 @@ +import copy +import os + +import torch +from peft import LoraConfig +from torch import distributed as dist +from torch.optim import AdamW + +import colossalai +from colossalai.booster import Booster +from colossalai.booster.plugin import TorchDDPPlugin +from colossalai.testing import ( + assert_equal, + assert_not_equal, + check_state_dict_equal, + clear_cache_before_run, + rerun_if_address_is_in_use, + spawn, +) +from tests.kit.model_zoo import model_zoo +from tests.test_checkpoint_io.utils import shared_tempdir + + +@clear_cache_before_run() +def check_fwd_bwd(model_fn, data_gen_fn, output_transform_fn, loss_fn, task_type): + model = model_fn() + lora_config = LoraConfig(task_type=task_type, r=8, lora_alpha=32, lora_dropout=0.1) + + plugin = TorchDDPPlugin() + booster = Booster(plugin=plugin) + + model = booster.enable_lora(model, lora_config=lora_config) + model_copy = copy.deepcopy(model) + + optimizer = AdamW(model.parameters(), lr=0.001) + criterion = loss_fn + + model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion) + + data = data_gen_fn() + data = {k: v.to("cuda") if torch.is_tensor(v) or "Tensor" in v.__class__.__name__ else v for k, v in data.items()} + + output = model(**data) + output = output_transform_fn(output) + loss = criterion(output) + + booster.backward(loss, optimizer) + optimizer.clip_grad_by_norm(1.0) + optimizer.step() + + for (n1, p1), (n2, p2) in zip(model.named_parameters(), model_copy.named_parameters()): + if "lora_" in n1: + # lora modules require gradients, thus updated + assert p1.requires_grad + assert_not_equal(p1.to(p2.device), p2) + else: + if not p1.requires_grad: + assert_equal(p1.to(p2.device), p2) + + +@clear_cache_before_run() +def check_checkpoint(model_fn, data_gen_fn, output_transform_fn, loss_fn, task_type): + plugin = TorchDDPPlugin() + lora_config = LoraConfig(task_type=task_type, r=8, lora_alpha=32, lora_dropout=0.1) + + model_save = model_fn() + model_load = copy.deepcopy(model_save) + + booster = Booster(plugin=plugin) + model_save = booster.enable_lora(model_save, lora_config=lora_config) + model_save, _, _, _, _ = booster.boost(model_save) + + with shared_tempdir() as tempdir: + lora_ckpt_path = os.path.join(tempdir, "ckpt") + booster.save_lora_as_pretrained(model_save, lora_ckpt_path) + dist.barrier() + + # The Lora checkpoint should be small in size + checkpoint_size_mb = os.path.getsize(os.path.join(lora_ckpt_path, "adapter_model.bin")) / (1024 * 1024) + assert checkpoint_size_mb < 1 + + model_load = booster.enable_lora(model_load, pretrained_dir=lora_ckpt_path) + model_load, _, _, _, _ = booster.boost(model_load) + + check_state_dict_equal(model_save.state_dict(), model_load.state_dict()) + + +def run_lora_test(): + sub_model_zoo = model_zoo.get_sub_registry("transformers_llama") + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + task_type = None + if name == "transformers_llama_for_casual_lm": + task_type = "CAUSAL_LM" + if name == "transformers_llama_for_sequence_classification": + task_type = "SEQ_CLS" + check_fwd_bwd(model_fn, data_gen_fn, output_transform_fn, loss_fn, task_type) + check_checkpoint(model_fn, data_gen_fn, output_transform_fn, loss_fn, task_type) + + +def run_dist(rank, world_size, port): + config = {} + colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + run_lora_test() + + +@rerun_if_address_is_in_use() +def test_torch_ddp_lora(): + spawn(run_dist, 2) From 8954a0c2e2c43e04d281853f2ac771f30ec41053 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Thu, 21 Dec 2023 17:01:01 +0800 Subject: [PATCH 23/28] [LowLevelZero] low level zero support lora (#5153) * low level zero support lora low level zero support lora * add checkpoint test * add checkpoint test * fix * fix * fix * fix fix fix fix * fix * fix fix fix fix fix fix fix * fix * fix fix fix fix fix fix fix * fix * test ci * git # This is a combination of 3 commits. Update low_level_zero_plugin.py Update low_level_zero_plugin.py fix fix fix * fix naming fix naming fix naming fix --- .../booster/plugin/low_level_zero_plugin.py | 103 +++++++++++++++++- colossalai/pipeline/p2p.py | 12 ++ .../low_level/bookkeeping/gradient_store.py | 3 + requirements/requirements-test.txt | 2 +- requirements/requirements.txt | 1 + .../test_plugin/test_dp_plugin_base.py | 8 +- .../test_plugin/test_low_level_zero_plugin.py | 40 ++++++- .../test_low_level_zero_checkpoint_io.py | 103 ++++++++++++++++++ 8 files changed, 264 insertions(+), 8 deletions(-) diff --git a/colossalai/booster/plugin/low_level_zero_plugin.py b/colossalai/booster/plugin/low_level_zero_plugin.py index 243051895..6bc9ba0e7 100644 --- a/colossalai/booster/plugin/low_level_zero_plugin.py +++ b/colossalai/booster/plugin/low_level_zero_plugin.py @@ -1,5 +1,7 @@ +import enum import logging import os +import warnings from functools import partial from pathlib import Path from types import MethodType @@ -7,6 +9,7 @@ from typing import Callable, Dict, Iterator, List, Optional, Tuple import torch import torch.nn as nn +from torch.nn import Parameter from torch.optim import Optimizer from torch.optim.lr_scheduler import _LRScheduler as LRScheduler from torch.utils._pytree import tree_map @@ -42,6 +45,12 @@ def _convert_floating_point(x, dtype: torch.dtype = torch.float16): SUPPORTED_PRECISION = ["fp16", "bf16", "fp32"] +class OptimizerParamCheckState(enum.Enum): + ORIGIN_PARAM_FINDED = 0 + ORIGIN_PARAM_NOT_FIND = -1 + LORA_PARM_EXISTED = -2 + + class LowLevelZeroModel(ModelWrapper, AMPModelMixin): def __init__(self, module: nn.Module, precision: str) -> None: super().__init__(module) @@ -209,6 +218,19 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO): super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors, load_sub_module) model.update_master_params() + def save_lora_as_pretrained(self, model, checkpoint, use_safetensors): + if os.path.isfile(checkpoint): + logging.error(f"Provided path ({checkpoint}) should be a directory, not a file") + return + from peft import PeftModel + + assert isinstance(model, ModelWrapper), "Please boost the model before saving!" + peft_model = model.unwrap() + assert isinstance( + peft_model, PeftModel + ), "The model doesn't have lora adapters, please enable lora before saving." + return peft_model.save_pretrained(checkpoint, safe_serialization=use_safetensors) + class LowLevelZeroPlugin(DPPluginBase): """ @@ -288,6 +310,7 @@ class LowLevelZeroPlugin(DPPluginBase): cpu_offload=cpu_offload, master_weights=master_weights, ) + self.lora_enabled = False self.verbose = verbose # set class name with stage, for better error message @@ -311,6 +334,72 @@ class LowLevelZeroPlugin(DPPluginBase): def supported_devices(self) -> List[str]: return ["cuda", "npu"] + def support_lora(self) -> bool: + return True + + def enable_lora( + self, model: nn.Module, pretrained_dir: Optional[str] = None, lora_config: Optional[Dict] = None + ) -> nn.Module: + from peft import PeftModel, get_peft_model + + assert not isinstance(model, LowLevelZeroModel), "Lora should be enabled before boosting the model." + self.lora_enabled = True + warnings.warn("You have enabled LoRa training. Please check the hyperparameters such as lr") + + if pretrained_dir is None: + peft_model = get_peft_model(model, lora_config) + else: + peft_model = PeftModel.from_pretrained(model, pretrained_dir, is_trainable=True) + return peft_model + + def get_param_group_id(self, optimizer: Optimizer, origin_param: Parameter): + origin_param_id = id(origin_param) + for group_id, param_group in enumerate(optimizer.param_groups): + for p in param_group["params"]: + if id(p) == origin_param_id: + return group_id + return -1 + + def get_param_group_id(self, optimizer: Optimizer, origin_param: Parameter, lora_param: Parameter): + origin_param_id = id(origin_param) + lora_param_id = id(lora_param) + target_group_id = None + for group_id, param_group in enumerate(optimizer.param_groups): + for p in param_group["params"]: + if id(p) == lora_param_id: + # check if the lora parameter exists. + return target_group_id, OptimizerParamCheckState.LORA_PARM_EXISTED + if id(p) == origin_param_id: + target_group_id = group_id + if target_group_id is not None: + return target_group_id, OptimizerParamCheckState.ORIGIN_PARAM_FINDED + else: + return target_group_id, OptimizerParamCheckState.ORIGIN_PARAM_NOT_FIND + + def add_lora_params_to_optimizer(self, model, optimizer): + """add lora parameters to optimizer""" + name2param = {} + for name, param in model.named_parameters(): + name2param[name] = param + + for name, param in name2param.items(): + if "lora_A" in name or "lora_B" in name: + origin_key = name.replace("lora_A.", "") + origin_key = origin_key.replace("lora_B.", "") + origin_key = origin_key.replace(f"{model.active_adapter}", "base_layer") + origin_param = name2param[origin_key] + group_id, check_state = self.get_param_group_id(optimizer, origin_param, param) + if check_state == OptimizerParamCheckState.ORIGIN_PARAM_NOT_FIND: + warnings.warn( + "Origin parameter {origin_key} related to {name} doesn't exist in optimizer param_groups." + ) + elif ( + check_state == OptimizerParamCheckState.ORIGIN_PARAM_FINDED + and group_id is not None + and group_id >= 0 + ): + optimizer.param_groups[group_id]["params"].append(param) + def configure( self, model: nn.Module, @@ -319,6 +408,15 @@ class LowLevelZeroPlugin(DPPluginBase): dataloader: Optional[DataLoader] = None, lr_scheduler: Optional[LRScheduler] = None, ) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]: + if self.lora_enabled: + from peft import PeftModel + + assert isinstance( + model, PeftModel + ), "The model should have been wrapped as a PeftModel when self.lora_enabled is True" + if optimizer is not None: + self.add_lora_params_to_optimizer(model, optimizer) + if not isinstance(model, ModelWrapper): model = LowLevelZeroModel(model, self.precision) @@ -340,8 +438,3 @@ class LowLevelZeroPlugin(DPPluginBase): def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]: assert isinstance(optimizer, LowLevelZeroOptimizer) return optimizer.no_sync() - - def enable_lora( - self, model: nn.Module, pretrained_dir: Optional[str] = None, lora_config: Optional[Dict] = None - ) -> nn.Module: - raise NotImplementedError diff --git a/colossalai/pipeline/p2p.py b/colossalai/pipeline/p2p.py index 5588aa578..1b55b140c 100644 --- a/colossalai/pipeline/p2p.py +++ b/colossalai/pipeline/p2p.py @@ -45,6 +45,18 @@ def _cuda_safe_tensor_to_object(tensor: torch.Tensor, tensor_size: torch.Size) - return unpickle +def check_for_nccl_backend(group): + pg = group or c10d._get_default_group() + # Gate PG wrapper check on Gloo availability. + if c10d._GLOO_AVAILABLE: + # It is not expected for PG to be wrapped many times, but support it just + # in case + while isinstance(pg, c10d._ProcessGroupWrapper): + pg = pg.wrapped_pg + + return c10d.is_nccl_available() and pg.name() == c10d.Backend.NCCL + + # NOTE: FIXME: NPU DOES NOT support isend nor irecv, so broadcast is kept for future use def _broadcast_object_list( object_list: List[Any], src: int, group: ProcessGroup, device: Optional[Union[torch.device, str, int]] = None diff --git a/colossalai/zero/low_level/bookkeeping/gradient_store.py b/colossalai/zero/low_level/bookkeeping/gradient_store.py index 73a1db5a0..6d4fcbb86 100644 --- a/colossalai/zero/low_level/bookkeeping/gradient_store.py +++ b/colossalai/zero/low_level/bookkeeping/gradient_store.py @@ -82,6 +82,9 @@ class GradientStore(BaseStore): """ grad_list = [] + # When using LoRa and the user sets multiple param_groups, it is possible that some param_groups have no parameters with gradients. + if group_id not in self._grads_of_params.keys(): + return grad_list for param_grads in self._grads_of_params[group_id].values(): grad_list.append(param_grads[self._working_index]) diff --git a/requirements/requirements-test.txt b/requirements/requirements-test.txt index de7fe8a21..58c7f780f 100644 --- a/requirements/requirements-test.txt +++ b/requirements/requirements-test.txt @@ -18,5 +18,5 @@ flash_attn datasets pydantic ray -peft +peft>=0.7.1 #auto-gptq now not support torch1.12 diff --git a/requirements/requirements.txt b/requirements/requirements.txt index d307312de..815b23fc7 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -17,3 +17,4 @@ sentencepiece google protobuf transformers==4.36.2 +peft>=0.7.1 diff --git a/tests/test_booster/test_plugin/test_dp_plugin_base.py b/tests/test_booster/test_plugin/test_dp_plugin_base.py index 0ac9d0f6d..fceb623fe 100644 --- a/tests/test_booster/test_plugin/test_dp_plugin_base.py +++ b/tests/test_booster/test_plugin/test_dp_plugin_base.py @@ -1,4 +1,4 @@ -from typing import Callable, Iterator, List, Tuple, Union +from typing import Callable, Dict, Iterator, List, Tuple, Union import torch import torch.distributed as dist @@ -51,6 +51,12 @@ class DPPluginWrapper(DPPluginBase): def no_sync(self, model: nn.Module) -> Iterator[None]: pass + def enable_lora(self, model: nn.Module, pretrained_dir: str, lora_config: Dict) -> nn.Module: + pass + + def support_lora(self) -> bool: + pass + def check_dataloader_sharding(): plugin = DPPluginWrapper() diff --git a/tests/test_booster/test_plugin/test_low_level_zero_plugin.py b/tests/test_booster/test_plugin/test_low_level_zero_plugin.py index 861fa0131..cbfad6ef7 100644 --- a/tests/test_booster/test_plugin/test_low_level_zero_plugin.py +++ b/tests/test_booster/test_plugin/test_low_level_zero_plugin.py @@ -2,6 +2,7 @@ from typing import Optional import torch import torch.distributed as dist +from peft import LoraConfig from torch.optim import Adam import colossalai @@ -22,13 +23,17 @@ _STUCK_MODELS = ["transformers_albert_for_multiple_choice"] @clear_cache_before_run() -def run_fn(stage, model_fn, data_gen_fn, output_transform_fn) -> Optional[str]: +def run_fn(stage, model_fn, data_gen_fn, output_transform_fn, lora_config=None) -> Optional[str]: device = get_accelerator().get_current_device() try: plugin = LowLevelZeroPlugin(stage=stage, max_norm=1.0, initial_scale=2**5) booster = Booster(plugin=plugin) model = model_fn() optimizer = Adam(model.parameters(), lr=1e-3) + + if lora_config is not None: + model = booster.enable_lora(model, lora_config=lora_config) + criterion = lambda x: x.mean() data = data_gen_fn() @@ -48,6 +53,7 @@ def run_fn(stage, model_fn, data_gen_fn, output_transform_fn) -> Optional[str]: except Exception as e: return repr(e) + # raise e @parameterize("stage", [2]) @@ -91,10 +97,42 @@ def check_low_level_zero_plugin(stage: int, early_stop: bool = True): assert len(failed_info) == 0, "\n".join([f"{k}: {v}" for k, v in failed_info.items()]) +@parameterize("stage", [2]) +@parameterize("model_name", ["transformers_llama"]) +def check_low_level_zero_lora(stage, model_name, early_stop: bool = True): + passed_models = [] + failed_info = {} # (model_name, error) pair + + sub_model_zoo = model_zoo.get_sub_registry(model_name) + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + task_type = None + if name == "transformers_llama_for_casual_lm": + task_type = "CAUSAL_LM" + if name == "transformers_llama_for_sequence_classification": + task_type = "SEQ_CLS" + lora_config = LoraConfig(task_type=task_type, r=8, lora_alpha=32, lora_dropout=0.1) + err = run_fn(stage, model_fn, data_gen_fn, output_transform_fn, lora_config) + + torch.cuda.empty_cache() + + if err is None: + passed_models.append(name) + else: + failed_info[name] = err + if early_stop: + break + + if dist.get_rank() == 0: + print(f"Passed models({len(passed_models)}): {passed_models}\n\n") + print(f"Failed models({len(failed_info)}): {list(failed_info.keys())}\n\n") + assert len(failed_info) == 0, "\n".join([f"{k}: {v}" for k, v in failed_info.items()]) + + def run_dist(rank, world_size, port, early_stop: bool = True): # init dist env colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost") check_low_level_zero_plugin(early_stop=early_stop) + check_low_level_zero_lora(early_stop=early_stop) @rerun_if_address_is_in_use() diff --git a/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py b/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py index e7f44f97e..4073cae0c 100644 --- a/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py @@ -1,5 +1,9 @@ +from copy import deepcopy +from typing import Optional + import torch import torch.distributed as dist +from peft import LoraConfig from torchvision.models import resnet18 from utils import shared_tempdir @@ -15,6 +19,7 @@ from colossalai.testing import ( spawn, ) from colossalai.zero import LowLevelZeroOptimizer +from tests.kit.model_zoo import model_zoo # stage 1 and 2 process the optimizer/mode the same way @@ -69,9 +74,107 @@ def check_low_level_zero_checkpointIO(stage: int, shard: bool, offload: bool): torch.cuda.empty_cache() +def run_fn(stage, shard, offload, model_fn, data_gen_fn, output_transform_fn, lora_config=None) -> Optional[str]: + try: + plugin = LowLevelZeroPlugin(stage=stage, max_norm=1.0, initial_scale=2**5, cpu_offload=offload) + new_plugin = LowLevelZeroPlugin(stage=stage, max_norm=1.0, initial_scale=2**5, cpu_offload=offload) + booster = Booster(plugin=plugin) + new_booster = Booster(plugin=new_plugin) + model = model_fn() + optimizer = HybridAdam(model.parameters(), lr=1e-3) + new_model = deepcopy(model) + new_optimizer = HybridAdam(new_model.parameters(), lr=1e-3) + model = booster.enable_lora(model, lora_config=lora_config) + criterion = lambda x: x.mean() + data = data_gen_fn() + + data = { + k: v.to("cuda") if torch.is_tensor(v) or "Tensor" in v.__class__.__name__ else v for k, v in data.items() + } + + model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion) + + output = model(**data) + output = output_transform_fn(output) + output_key = list(output.keys())[0] + loss = criterion(output[output_key]) + + booster.backward(loss, optimizer) + optimizer.step() + + with shared_tempdir() as tempdir: + model_ckpt_path = f"{tempdir}/model" + optimizer_ckpt_path = f"{tempdir}/optimizer" + + booster.save_lora_as_pretrained(model, model_ckpt_path) + booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=False) + new_model = new_booster.enable_lora(new_model, pretrained_dir=model_ckpt_path, lora_config=lora_config) + new_model, new_optimizer, criterion, _, _ = new_booster.boost(new_model, new_optimizer, criterion) + check_state_dict_equal(model.state_dict(), new_model.state_dict(), False) + + # check master weight + assert isinstance(new_optimizer, LowLevelZeroOptimizer) + working_param_id_set = set(id(p) for p in new_model.parameters()) + for p_id, master_param in new_optimizer._param_store.working_to_master_param.items(): + assert p_id in working_param_id_set + working_param = new_optimizer._param_store.master_to_working_param[id(master_param)] + padding = new_optimizer._param_store.get_param_padding_size(working_param) + padded_param = torch.nn.functional.pad(working_param.data.view(-1), (0, padding)) + working_shard = padded_param.chunk(dist.get_world_size())[dist.get_rank()] + assert torch.equal( + working_shard, master_param.data.view(-1).to(dtype=padded_param.dtype, device=padded_param.device) + ) + + new_booster.load_optimizer(new_optimizer, optimizer_ckpt_path) + check_state_dict_equal(optimizer.optim.state_dict(), new_optimizer.optim.state_dict(), False) + + except Exception as e: + # return repr(e) + raise e + + +@clear_cache_before_run() +@parameterize("stage", [2]) +@parameterize("shard", [True, False]) +@parameterize("offload", [False, True]) +@parameterize("model_name", ["transformers_llama"]) +def check_low_level_zero_lora_checkpointIO( + stage: int, shard: bool, offload: bool, model_name: str, early_stop: bool = True +): + passed_models = [] + failed_info = {} # (model_name, error) pair + + sub_model_zoo = model_zoo.get_sub_registry(model_name) + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + if name != "transformers_llama": + continue + task_type = None + if name == "transformers_llama_for_casual_lm": + task_type = "CAUSAL_LM" + if name == "transformers_llama_for_sequence_classification": + task_type = "SEQ_CLS" + lora_config = LoraConfig(task_type=task_type, r=8, lora_alpha=32, lora_dropout=0.1) + err = run_fn(stage, shard, offload, model_fn, data_gen_fn, output_transform_fn, lora_config) + + torch.cuda.empty_cache() + + if err is None: + passed_models.append(name) + else: + failed_info[name] = err + if early_stop: + break + + if dist.get_rank() == 0: + print(f"Passed models({len(passed_models)}): {passed_models}\n\n") + print(f"Failed models({len(failed_info)}): {list(failed_info.keys())}\n\n") + assert len(failed_info) == 0, "\n".join([f"{k}: {v}" for k, v in failed_info.items()]) + + def run_dist(rank, world_size, port): colossalai.launch(config=(dict()), rank=rank, world_size=world_size, port=port, host="localhost") check_low_level_zero_checkpointIO() + check_low_level_zero_lora_checkpointIO() torch.cuda.empty_cache() From 91fa55377505721efc68dbf750dc01aa5c142d3e Mon Sep 17 00:00:00 2001 From: linsj20 Date: Wed, 17 Apr 2024 15:03:31 +0800 Subject: [PATCH 24/28] [Feature] qlora support (#5586) * [feature] qlora support * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * qlora follow commit * migrate qutization folder to colossalai/ * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * minor fixes --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- LICENSE | 15 + .../colossal_llama/dataset/loader.py | 16 +- applications/Colossal-LLaMA/train.py | 8 +- colossalai/booster/booster.py | 23 +- .../booster/plugin/low_level_zero_plugin.py | 10 +- colossalai/booster/plugin/torch_ddp_plugin.py | 10 +- colossalai/inference/README.md | 38 +-- colossalai/quantization/__init__.py | 7 + colossalai/quantization/bnb.py | 321 ++++++++++++++++++ colossalai/quantization/bnb_config.py | 113 ++++++ colossalai/zero/low_level/low_level_optim.py | 7 +- requirements/requirements.txt | 1 + tests/test_lora/test_lora.py | 106 ++++++ tests/test_lora/test_torch_ddp_lora.py | 108 ------ 14 files changed, 640 insertions(+), 143 deletions(-) create mode 100644 colossalai/quantization/__init__.py create mode 100644 colossalai/quantization/bnb.py create mode 100644 colossalai/quantization/bnb_config.py create mode 100644 tests/test_lora/test_lora.py delete mode 100644 tests/test_lora/test_torch_ddp_lora.py diff --git a/LICENSE b/LICENSE index 47197afe6..f0b2ffa97 100644 --- a/LICENSE +++ b/LICENSE @@ -552,3 +552,18 @@ Copyright 2021- HPC-AI Technology Inc. All rights reserved. LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + ---------------- LICENSE FOR Hugging Face accelerate ---------------- + + Copyright 2021 The HuggingFace Team + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/applications/Colossal-LLaMA/colossal_llama/dataset/loader.py b/applications/Colossal-LLaMA/colossal_llama/dataset/loader.py index 327651f4e..abe0fd51a 100644 --- a/applications/Colossal-LLaMA/colossal_llama/dataset/loader.py +++ b/applications/Colossal-LLaMA/colossal_llama/dataset/loader.py @@ -80,15 +80,19 @@ class DataCollatorForSupervisedDataset(object): # `List[torch.Tensor]` batch_input_ids = [ - torch.LongTensor(instance["input_ids"][: self.max_length]) - if len(instance["input_ids"]) > self.max_length - else torch.LongTensor(instance["input_ids"]) + ( + torch.LongTensor(instance["input_ids"][: self.max_length]) + if len(instance["input_ids"]) > self.max_length + else torch.LongTensor(instance["input_ids"]) + ) for instance in instances ] batch_labels = [ - torch.LongTensor(instance["labels"][: self.max_length]) - if len(instance["labels"]) > self.max_length - else torch.LongTensor(instance["labels"]) + ( + torch.LongTensor(instance["labels"][: self.max_length]) + if len(instance["labels"]) > self.max_length + else torch.LongTensor(instance["labels"]) + ) for instance in instances ] diff --git a/applications/Colossal-LLaMA/train.py b/applications/Colossal-LLaMA/train.py index dcd7be9f4..37e4fcc80 100644 --- a/applications/Colossal-LLaMA/train.py +++ b/applications/Colossal-LLaMA/train.py @@ -253,9 +253,11 @@ def main() -> None: coordinator.print_on_master(f"Model params: {format_numel_str(model_numel)}") optimizer = HybridAdam( - model_params=filter(lambda p: p.requires_grad, model.parameters()) - if args.freeze_non_embeds_params - else model.parameters(), + model_params=( + filter(lambda p: p.requires_grad, model.parameters()) + if args.freeze_non_embeds_params + else model.parameters() + ), lr=args.lr, betas=(0.9, 0.95), weight_decay=args.weight_decay, diff --git a/colossalai/booster/booster.py b/colossalai/booster/booster.py index c2a724084..56d8a0935 100644 --- a/colossalai/booster/booster.py +++ b/colossalai/booster/booster.py @@ -19,6 +19,7 @@ except ImportError: import colossalai.interface.pretrained as pretrained_utils from colossalai.checkpoint_io import GeneralCheckpointIO from colossalai.interface import ModelWrapper, OptimizerWrapper +from colossalai.quantization import BnbQuantizationConfig from .accelerator import Accelerator from .mixed_precision import MixedPrecision, mixed_precision_factory @@ -230,7 +231,12 @@ class Booster: return self.plugin.no_sync(model, optimizer) def enable_lora( - self, model: nn.Module, pretrained_dir: Optional[str] = None, lora_config: "peft.LoraConfig" = None + self, + model: nn.Module, + pretrained_dir: Optional[str] = None, + lora_config: "peft.LoraConfig" = None, + bnb_quantization_config: Optional[BnbQuantizationConfig] = None, + quantize=False, ) -> nn.Module: """ Wrap the passed in model with LoRA modules for training. If pretrained directory is provided, lora configs and weights are loaded from that directory. @@ -259,7 +265,20 @@ class Booster: assert ( pretrained_dir is not None ), "Please provide pretrained directory path if not passing in lora configuration." - return self.plugin.enable_lora(model, pretrained_dir, lora_config) + if quantize is True: + if bnb_quantization_config is not None: + warnings.warn( + "User defined BnbQuantizationConfig is not fully tested in ColossalAI. Use it at your own risk." + ) + else: + bnb_quantization_config = BnbQuantizationConfig( + load_in_4bit=True, + bnb_4bit_compute_dtype=torch.bfloat16, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type="nf4", + ) + + return self.plugin.enable_lora(model, pretrained_dir, lora_config, bnb_quantization_config) def load_model(self, model: Union[nn.Module, ModelWrapper], checkpoint: str, strict: bool = True) -> None: """Load model from checkpoint. diff --git a/colossalai/booster/plugin/low_level_zero_plugin.py b/colossalai/booster/plugin/low_level_zero_plugin.py index 6bc9ba0e7..be75bebac 100644 --- a/colossalai/booster/plugin/low_level_zero_plugin.py +++ b/colossalai/booster/plugin/low_level_zero_plugin.py @@ -28,6 +28,7 @@ from colossalai.checkpoint_io.utils import ( sharded_optimizer_loading_epilogue, ) from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper +from colossalai.quantization import BnbQuantizationConfig, quantize_model from colossalai.zero import LowLevelZeroOptimizer from .dp_plugin_base import DPPluginBase @@ -338,7 +339,11 @@ class LowLevelZeroPlugin(DPPluginBase): return True def enable_lora( - self, model: nn.Module, pretrained_dir: Optional[str] = None, lora_config: Optional[Dict] = None + self, + model: nn.Module, + pretrained_dir: Optional[str] = None, + lora_config: Optional[Dict] = None, + bnb_quantization_config: Optional[BnbQuantizationConfig] = None, ) -> nn.Module: from peft import PeftModel, get_peft_model @@ -346,6 +351,9 @@ class LowLevelZeroPlugin(DPPluginBase): self.lora_enabled = True warnings.warn("You have enabled LoRa training. Please check the hyperparameters such as lr") + if bnb_quantization_config is not None: + model = quantize_model(model, bnb_quantization_config) + if pretrained_dir is None: peft_model = get_peft_model(model, lora_config) else: diff --git a/colossalai/booster/plugin/torch_ddp_plugin.py b/colossalai/booster/plugin/torch_ddp_plugin.py index 9ba520de2..482cc4e98 100644 --- a/colossalai/booster/plugin/torch_ddp_plugin.py +++ b/colossalai/booster/plugin/torch_ddp_plugin.py @@ -9,6 +9,7 @@ from torch.utils.data import DataLoader from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO from colossalai.cluster import DistCoordinator from colossalai.interface import ModelWrapper, OptimizerWrapper +from colossalai.quantization import BnbQuantizationConfig, quantize_model from .dp_plugin_base import DPPluginBase @@ -237,10 +238,17 @@ class TorchDDPPlugin(DPPluginBase): return model.module.no_sync() def enable_lora( - self, model: nn.Module, pretrained_dir: Optional[str] = None, lora_config: Optional[Dict] = None + self, + model: nn.Module, + pretrained_dir: Optional[str] = None, + lora_config: Optional[Dict] = None, + bnb_quantization_config: Optional[BnbQuantizationConfig] = None, ) -> nn.Module: from peft import PeftModel, get_peft_model + if bnb_quantization_config is not None: + model = quantize_model(model, bnb_quantization_config) + assert not isinstance(model, TorchDDPModel), "Lora should be enabled before boosting the model." if pretrained_dir is None: return get_peft_model(model, lora_config) diff --git a/colossalai/inference/README.md b/colossalai/inference/README.md index 287853a86..c2b808155 100644 --- a/colossalai/inference/README.md +++ b/colossalai/inference/README.md @@ -165,7 +165,7 @@ Currently the stats below are calculated based on A100 (single GPU), and we calc ##### Llama | batch_size | 8 | 16 | 32 | -| :---------------------: | :----: | :----: | :----: | +|:-----------------------:|:------:|:------:|:------:| | hugging-face torch fp16 | 199.12 | 246.56 | 278.4 | | colossal-inference | 326.4 | 582.72 | 816.64 | @@ -174,7 +174,7 @@ Currently the stats below are calculated based on A100 (single GPU), and we calc #### Bloom | batch_size | 8 | 16 | 32 | -| :---------------------: | :----: | :----: | :----: | +|:-----------------------:|:------:|:------:|:------:| | hugging-face torch fp16 | 189.68 | 226.66 | 249.61 | | colossal-inference | 323.28 | 538.52 | 611.64 | @@ -187,40 +187,40 @@ We conducted multiple benchmark tests to evaluate the performance. We compared t #### A10 7b, fp16 -| batch_size(micro_batch size)| 2(1) | 4(2) | 8(4) | 16(8) | 32(8) | 32(16)| -| :-------------------------: | :---: | :---:| :---: | :---: | :---: | :---: | -| Pipeline Inference | 40.35 | 77.10| 139.03| 232.70| 257.81| OOM | -| Hugging Face | 41.43 | 65.30| 91.93 | 114.62| OOM | OOM | +| batch_size(micro_batch size) | 2(1) | 4(2) | 8(4) | 16(8) | 32(8) | 32(16) | +|:----------------------------:|:-----:|:-----:|:------:|:------:|:------:|:------:| +| Pipeline Inference | 40.35 | 77.10 | 139.03 | 232.70 | 257.81 | OOM | +| Hugging Face | 41.43 | 65.30 | 91.93 | 114.62 | OOM | OOM | ![ppllama7b](https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/inference/pp-a10-llama7b.png) #### A10 13b, fp16 -| batch_size(micro_batch size)| 2(1) | 4(2) | 8(4) | 16(4) | -| :---: | :---: | :---: | :---: | :---: | -| Pipeline Inference | 25.39 | 47.09 | 83.7 | 89.46 | -| Hugging Face | 23.48 | 37.59 | 53.44 | OOM | +| batch_size(micro_batch size) | 2(1) | 4(2) | 8(4) | 16(4) | +|:----------------------------:|:-----:|:-----:|:-----:|:-----:| +| Pipeline Inference | 25.39 | 47.09 | 83.7 | 89.46 | +| Hugging Face | 23.48 | 37.59 | 53.44 | OOM | ![ppllama13](https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/inference/pp-a10-llama13b.png) #### A800 7b, fp16 -| batch_size(micro_batch size) | 2(1) | 4(2) | 8(4) | 16(8) | 32(16) | -| :---: | :---: | :---: | :---: | :---: | :---: | -| Pipeline Inference| 57.97 | 110.13 | 213.33 | 389.86 | 670.12 | -| Hugging Face | 42.44 | 76.5 | 151.97 | 212.88 | 256.13 | +| batch_size(micro_batch size) | 2(1) | 4(2) | 8(4) | 16(8) | 32(16) | +|:----------------------------:|:-----:|:------:|:------:|:------:|:------:| +| Pipeline Inference | 57.97 | 110.13 | 213.33 | 389.86 | 670.12 | +| Hugging Face | 42.44 | 76.5 | 151.97 | 212.88 | 256.13 | ![ppllama7b_a800](https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/inference/pp-a800-llama7b.png) ### Quantization LLama -| batch_size | 8 | 16 | 32 | -| :---------------------: | :----: | :----: | :----: | -| auto-gptq | 199.20 | 232.56 | 253.26 | -| smooth-quant | 142.28 | 222.96 | 300.59 | -| colossal-gptq | 231.98 | 388.87 | 573.03 | +| batch_size | 8 | 16 | 32 | +|:-------------:|:------:|:------:|:------:| +| auto-gptq | 199.20 | 232.56 | 253.26 | +| smooth-quant | 142.28 | 222.96 | 300.59 | +| colossal-gptq | 231.98 | 388.87 | 573.03 | ![bloom](https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/inference/inference-quant.png) diff --git a/colossalai/quantization/__init__.py b/colossalai/quantization/__init__.py new file mode 100644 index 000000000..e9707b479 --- /dev/null +++ b/colossalai/quantization/__init__.py @@ -0,0 +1,7 @@ +from .bnb import quantize_model +from .bnb_config import BnbQuantizationConfig + +__all__ = [ + "BnbQuantizationConfig", + "quantize_model", +] diff --git a/colossalai/quantization/bnb.py b/colossalai/quantization/bnb.py new file mode 100644 index 000000000..fa214116a --- /dev/null +++ b/colossalai/quantization/bnb.py @@ -0,0 +1,321 @@ +# adapted from Hugging Face accelerate/utils/bnb.py accelerate/utils/modeling.py + +import logging + +import torch +import torch.nn as nn + +from .bnb_config import BnbQuantizationConfig + +try: + import bitsandbytes as bnb + + IS_4BIT_BNB_AVAILABLE = bnb.__version__ >= "0.39.0" + IS_8BIT_BNB_AVAILABLE = bnb.__version__ >= "0.37.2" +except ImportError: + pass + + +logger = logging.getLogger(__name__) + + +def quantize_model( + model: torch.nn.Module, + bnb_quantization_config: BnbQuantizationConfig, +): + """ + This function will quantize the input loaded model with the associated config passed in `bnb_quantization_config`. + We will quantize the model and put the model on the GPU. + + Args: + model (`torch.nn.Module`): + Input model. The model already loaded + bnb_quantization_config (`BnbQuantizationConfig`): + The bitsandbytes quantization parameters + + Returns: + `torch.nn.Module`: The quantized model + """ + + load_in_4bit = bnb_quantization_config.load_in_4bit + load_in_8bit = bnb_quantization_config.load_in_8bit + + if load_in_8bit and not IS_8BIT_BNB_AVAILABLE: + raise ImportError( + "You have a version of `bitsandbytes` that is not compatible with 8bit quantization," + " make sure you have the latest version of `bitsandbytes` installed." + ) + if load_in_4bit and not IS_4BIT_BNB_AVAILABLE: + raise ValueError( + "You have a version of `bitsandbytes` that is not compatible with 4bit quantization," + "make sure you have the latest version of `bitsandbytes` installed." + ) + + # We keep some modules such as the lm_head in their original dtype for numerical stability reasons + if bnb_quantization_config.skip_modules is None: + bnb_quantization_config.skip_modules = get_keys_to_not_convert(model) + + modules_to_not_convert = bnb_quantization_config.skip_modules + + # We add the modules we want to keep in full precision + if bnb_quantization_config.keep_in_fp32_modules is None: + bnb_quantization_config.keep_in_fp32_modules = [] + keep_in_fp32_modules = bnb_quantization_config.keep_in_fp32_modules + + # compatibility with peft + model.is_loaded_in_4bit = load_in_4bit + model.is_loaded_in_8bit = load_in_8bit + + # assert model_device is cuda + model_device = next(model.parameters()).device + + model = replace_with_bnb_layers(model, bnb_quantization_config, modules_to_not_convert=modules_to_not_convert) + + # convert param to the right dtype + dtype = bnb_quantization_config.torch_dtype + for name, param in model.state_dict().items(): + if any(module_to_keep_in_fp32 in name for module_to_keep_in_fp32 in keep_in_fp32_modules): + param.to(torch.float32) + if param.dtype != torch.float32: + name = name.replace(".weight", "").replace(".bias", "") + param = getattr(model, name, None) + if param is not None: + param.to(torch.float32) + elif torch.is_floating_point(param): + param.to(dtype) + if model_device.type == "cuda": + # move everything to cpu in the first place because we can't do quantization if the weights are already on cuda + model.cuda(torch.cuda.current_device()) + torch.cuda.empty_cache() + elif torch.cuda.is_available(): + model.to(torch.cuda.current_device()) + logger.info( + f"The model device type is {model_device.type}. However, cuda is needed for quantization." + "We move the model to cuda." + ) + else: + raise RuntimeError("No GPU found. A GPU is needed for quantization.") + return model + + +def replace_with_bnb_layers(model, bnb_quantization_config, modules_to_not_convert=None, current_key_name=None): + """ + A helper function to replace all `torch.nn.Linear` modules by `bnb.nn.Linear8bit` modules or by `bnb.nn.Linear4bit` + modules from the `bitsandbytes`library. The function will be run recursively and replace `torch.nn.Linear` modules. + + Parameters: + model (`torch.nn.Module`): + Input model or `torch.nn.Module` as the function is run recursively. + modules_to_not_convert (`List[str]`): + Names of the modules to not quantize convert. In practice we keep the `lm_head` in full precision for + numerical stability reasons. + current_key_name (`List[str]`, *optional*): + An array to track the current key of the recursion. This is used to check whether the current key (part of + it) is not in the list of modules to not convert. + """ + + if modules_to_not_convert is None: + modules_to_not_convert = [] + + model, has_been_replaced = _replace_with_bnb_layers( + model, bnb_quantization_config, modules_to_not_convert, current_key_name + ) + if not has_been_replaced: + logger.warning( + "You are loading your model in 8bit or 4bit but no linear modules were found in your model." + " this can happen for some architectures such as gpt2 that uses Conv1D instead of Linear layers." + " Please double check your model architecture, or submit an issue on github if you think this is" + " a bug." + ) + return model + + +def _replace_with_bnb_layers( + model, + bnb_quantization_config, + modules_to_not_convert=None, + current_key_name=None, +): + """ + Private method that wraps the recursion for module replacement. + + Returns the converted model and a boolean that indicates if the conversion has been successfull or not. + """ + # bitsandbytes will initialize CUDA on import, so it needs to be imported lazily + + has_been_replaced = False + for name, module in model.named_children(): + if current_key_name is None: + current_key_name = [] + current_key_name.append(name) + if isinstance(module, nn.Linear) and name not in modules_to_not_convert: + # Check if the current key is not in the `modules_to_not_convert` + current_key_name_str = ".".join(current_key_name) + proceed = True + for key in modules_to_not_convert: + if ( + (key in current_key_name_str) and (key + "." in current_key_name_str) + ) or key == current_key_name_str: + proceed = False + break + if proceed: + # Load bnb module with empty weight and replace ``nn.Linear` module + if bnb_quantization_config.load_in_8bit: + bnb_module = bnb.nn.Linear8bitLt( + module.in_features, + module.out_features, + module.bias is not None, + has_fp16_weights=False, + threshold=bnb_quantization_config.llm_int8_threshold, + ) + elif bnb_quantization_config.load_in_4bit: + bnb_module = bnb.nn.Linear4bit( + module.in_features, + module.out_features, + module.bias is not None, + bnb_quantization_config.bnb_4bit_compute_dtype, + compress_statistics=bnb_quantization_config.bnb_4bit_use_double_quant, + quant_type=bnb_quantization_config.bnb_4bit_quant_type, + ) + else: + raise ValueError("load_in_8bit and load_in_4bit can't be both False") + bnb_module.weight.data = module.weight.data + bnb_module.weight.skip_zero_check = True + if module.bias is not None: + bnb_module.bias.data = module.bias.data + bnb_module.bias.skip_zero_check = True + bnb_module.requires_grad_(False) + setattr(model, name, bnb_module) + has_been_replaced = True + if len(list(module.children())) > 0: + _, _has_been_replaced = _replace_with_bnb_layers( + module, bnb_quantization_config, modules_to_not_convert, current_key_name + ) + has_been_replaced = has_been_replaced | _has_been_replaced + # Remove the last key for recursion + current_key_name.pop(-1) + return model, has_been_replaced + + +def get_keys_to_not_convert(model): + r""" + An utility function to get the key of the module to keep in full precision if any For example for CausalLM modules + we may want to keep the lm_head in full precision for numerical stability reasons. For other architectures, we want + to keep the tied weights of the model. The function will return a list of the keys of the modules to not convert in + int8. + + Parameters: + model (`torch.nn.Module`): + Input model + """ + # Create a copy of the model + # with init_empty_weights(): + # tied_model = deepcopy(model) # this has 0 cost since it is done inside `init_empty_weights` context manager` + tied_model = model + + tied_params = find_tied_parameters(tied_model) + # For compatibility with Accelerate < 0.18 + if isinstance(tied_params, dict): + tied_keys = sum(list(tied_params.values()), []) + list(tied_params.keys()) + else: + tied_keys = sum(tied_params, []) + has_tied_params = len(tied_keys) > 0 + + # Check if it is a base model + is_base_model = False + if hasattr(model, "base_model_prefix"): + is_base_model = not hasattr(model, model.base_model_prefix) + + # Ignore this for base models (BertModel, GPT2Model, etc.) + if (not has_tied_params) and is_base_model: + return [] + + # otherwise they have an attached head + list_modules = list(model.named_children()) + list_last_module = [list_modules[-1][0]] + + # add last module together with tied weights + intersection = set(list_last_module) - set(tied_keys) + list_untouched = list(set(tied_keys)) + list(intersection) + + # remove ".weight" from the keys + names_to_remove = [".weight", ".bias"] + filtered_module_names = [] + for name in list_untouched: + for name_to_remove in names_to_remove: + if name_to_remove in name: + name = name.replace(name_to_remove, "") + filtered_module_names.append(name) + + return filtered_module_names + + +def find_tied_parameters(model: nn.Module, **kwargs): + """ + Find the tied parameters in a given model. + + + + The signature accepts keyword arguments, but they are for the recursive part of this function and you should ignore + them. + + + + Args: + model (`torch.nn.Module`): The model to inspect. + + Returns: + List[List[str]]: A list of lists of parameter names being all tied together. + + Example: + + ```py + >>> from collections import OrderedDict + >>> import torch.nn as nn + + >>> model = nn.Sequential(OrderedDict([("linear1", nn.Linear(4, 4)), ("linear2", nn.Linear(4, 4))])) + >>> model.linear2.weight = model.linear1.weight + >>> find_tied_parameters(model) + [['linear1.weight', 'linear2.weight']] + ``` + """ + # Initialize result and named_parameters before recursing. + named_parameters = kwargs.get("named_parameters", None) + prefix = kwargs.get("prefix", "") + result = kwargs.get("result", {}) + + if named_parameters is None: + named_parameters = {n: p for n, p in model.named_parameters()} + else: + # A tied parameter will not be in the full `named_parameters` seen above but will be in the `named_parameters` + # of the submodule it belongs to. So while recursing we track the names that are not in the initial + # `named_parameters`. + for name, parameter in model.named_parameters(): + full_name = name if prefix == "" else f"{prefix}.{name}" + if full_name not in named_parameters: + # When we find one, it has to be one of the existing parameters. + for new_name, new_param in named_parameters.items(): + if new_param is parameter: + if new_name not in result: + result[new_name] = [] + result[new_name].append(full_name) + + # Once we have treated direct parameters, we move to the child modules. + for name, child in model.named_children(): + child_name = name if prefix == "" else f"{prefix}.{name}" + find_tied_parameters(child, named_parameters=named_parameters, prefix=child_name, result=result) + + return FindTiedParametersResult([sorted([weight] + list(set(tied))) for weight, tied in result.items()]) + + +class FindTiedParametersResult(list): + """ + This is a subclass of a list to handle backward compatibility for Transformers. Do not rely on the fact this is not + a list or on the `values` method as in the future this will be removed. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def values(self): + return sum([x[1:] for x in self], []) diff --git a/colossalai/quantization/bnb_config.py b/colossalai/quantization/bnb_config.py new file mode 100644 index 000000000..98a30211b --- /dev/null +++ b/colossalai/quantization/bnb_config.py @@ -0,0 +1,113 @@ +# adapted from Hugging Face accelerate/utils/dataclasses.py + +import warnings +from dataclasses import dataclass, field +from typing import List + +import torch + + +@dataclass +class BnbQuantizationConfig: + """ + A plugin to enable BitsAndBytes 4bit and 8bit quantization + """ + + load_in_8bit: bool = field(default=False, metadata={"help": "enable 8bit quantization."}) + + llm_int8_threshold: float = field( + default=6.0, metadata={"help": "value of the outliner threshold. only relevant when load_in_8bit=True"} + ) + + load_in_4bit: bool = field(default=False, metadata={"help": "enable 4bit quantization."}) + + bnb_4bit_quant_type: str = field( + default="fp4", + metadata={ + "help": "set the quantization data type in the `bnb.nn.Linear4Bit` layers. Options are {'fp4','np4'}." + }, + ) + + bnb_4bit_use_double_quant: bool = field( + default=False, + metadata={ + "help": "enable nested quantization where the quantization constants from the first quantization are quantized again." + }, + ) + + bnb_4bit_compute_dtype: bool = field( + default="fp16", + metadata={ + "help": "This sets the computational type which might be different than the input time. For example, inputs might be " + "fp32, but computation can be set to bf16 for speedups. Options are {'fp32','fp16','bf16'}." + }, + ) + + torch_dtype: torch.dtype = field( + default=None, + metadata={ + "help": "this sets the dtype of the remaining non quantized layers. `bitsandbytes` library suggests to set the value" + "to `torch.float16` for 8 bit model and use the same dtype as the compute dtype for 4 bit model " + }, + ) + + skip_modules: List[str] = field( + default=None, + metadata={ + "help": "an explicit list of the modules that we don't quantize. The dtype of these modules will be `torch_dtype`." + }, + ) + + keep_in_fp32_modules: List[str] = field( + default=None, + metadata={"help": "an explicit list of the modules that we don't quantize. We keep them in `torch.float32`."}, + ) + + def __post_init__(self): + if isinstance(self.bnb_4bit_compute_dtype, str): + if self.bnb_4bit_compute_dtype == "fp32": + self.bnb_4bit_compute_dtype = torch.float32 + elif self.bnb_4bit_compute_dtype == "fp16": + self.bnb_4bit_compute_dtype = torch.float16 + elif self.bnb_4bit_compute_dtype == "bf16": + self.bnb_4bit_compute_dtype = torch.bfloat16 + else: + raise ValueError( + f"bnb_4bit_compute_dtype must be in ['fp32','fp16','bf16'] but found {self.bnb_4bit_compute_dtype}" + ) + elif not isinstance(self.bnb_4bit_compute_dtype, torch.dtype): + raise ValueError("bnb_4bit_compute_dtype must be a string or a torch.dtype") + + if self.skip_modules is not None and not isinstance(self.skip_modules, list): + raise ValueError("skip_modules must be a list of strings") + + if self.keep_in_fp32_modules is not None and not isinstance(self.keep_in_fp32_modules, list): + raise ValueError("keep_in_fp_32_modules must be a list of strings") + + if self.load_in_4bit: + self.target_dtype = "int4" + + if self.load_in_8bit: + self.target_dtype = torch.int8 + + if self.load_in_4bit and self.llm_int8_threshold != 6.0: + warnings.warn("llm_int8_threshold can only be used for model loaded in 8bit") + + if isinstance(self.torch_dtype, str): + if self.torch_dtype == "fp32": + self.torch_dtype = torch.float32 + elif self.torch_dtype == "fp16": + self.torch_dtype = torch.float16 + elif self.torch_dtype == "bf16": + self.torch_dtype = torch.bfloat16 + else: + raise ValueError(f"torch_dtype must be in ['fp32','fp16','bf16'] but found {self.torch_dtype}") + + if self.load_in_8bit and self.torch_dtype is None: + self.torch_dtype = torch.float16 + + if self.load_in_4bit and self.torch_dtype is None: + self.torch_dtype = self.bnb_4bit_compute_dtype + + if not isinstance(self.torch_dtype, torch.dtype): + raise ValueError("torch_dtype must be a torch.dtype") diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py index cbcf72697..345dfde73 100644 --- a/colossalai/zero/low_level/low_level_optim.py +++ b/colossalai/zero/low_level/low_level_optim.py @@ -235,9 +235,10 @@ class LowLevelZeroOptimizer(OptimizerWrapper): for param_group in self.optim.param_groups: group_params = param_group["params"] for param in group_params: - assert ( - param.dtype == self._dtype - ), f"Parameters are expected to have the same dtype `{self._dtype}`, but got `{param.dtype}`" + if not hasattr(param, "skip_zero_check") or param.skip_zero_check is False: + assert ( + param.dtype == self._dtype + ), f"Parameters are expected to have the same dtype `{self._dtype}`, but got `{param.dtype}`" def _create_master_param_current_rank(self, param_list): # split each param evenly by world size diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 815b23fc7..8ab13c0ad 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -18,3 +18,4 @@ google protobuf transformers==4.36.2 peft>=0.7.1 +bitsandbytes>=0.39.0 diff --git a/tests/test_lora/test_lora.py b/tests/test_lora/test_lora.py new file mode 100644 index 000000000..69febff38 --- /dev/null +++ b/tests/test_lora/test_lora.py @@ -0,0 +1,106 @@ +import copy +import os +from itertools import product + +import torch +from peft import LoraConfig +from torch import distributed as dist +from torch.optim import AdamW + +import colossalai +from colossalai.booster import Booster +from colossalai.booster.plugin import LowLevelZeroPlugin, TorchDDPPlugin +from colossalai.testing import check_state_dict_equal, clear_cache_before_run, rerun_if_address_is_in_use, spawn +from tests.kit.model_zoo import model_zoo +from tests.test_checkpoint_io.utils import shared_tempdir + + +@clear_cache_before_run() +def check_fwd_bwd(model_fn, data_gen_fn, output_transform_fn, loss_fn, task_type): + model = model_fn() + lora_config = LoraConfig(task_type=task_type, r=8, lora_alpha=32, lora_dropout=0.1) + + test_plugins = [TorchDDPPlugin(), LowLevelZeroPlugin()] + test_configs = [ + { + "lora_config": lora_config, + "quantize": False, + }, + { + "lora_config": lora_config, + "quantize": True, + }, + ] + for plugin, test_config in product(test_plugins, test_configs): + # checkpoint loaded model + model_save = model_fn() + model_load = copy.deepcopy(model_save) + + optimizer = AdamW(model.parameters(), lr=0.001) + criterion = loss_fn + + booster = Booster(plugin=plugin) + model_save = booster.enable_lora(model_save, **test_config) + model_save, optimizer, criterion, _, _ = booster.boost(model_save, optimizer, criterion) + + with shared_tempdir() as tempdir: + lora_ckpt_path = os.path.join(tempdir, "ckpt") + booster.save_lora_as_pretrained(model_save, lora_ckpt_path) + dist.barrier() + + # The Lora checkpoint should be small in size + checkpoint_size_mb = os.path.getsize(os.path.join(lora_ckpt_path, "adapter_model.bin")) / (1024 * 1024) + assert checkpoint_size_mb < 1 + + model_load = booster.enable_lora(model_load, pretrained_dir=lora_ckpt_path, **test_config) + model_load, _, _, _, _ = booster.boost(model_load) + + check_state_dict_equal(model_save.state_dict(), model_load.state_dict()) + + # test fwd bwd correctness + test_model = model_load + model_copy = copy.deepcopy(model_load) + + data = data_gen_fn() + data = { + k: v.to("cuda") if torch.is_tensor(v) or "Tensor" in v.__class__.__name__ else v for k, v in data.items() + } + + output = test_model(**data) + output = output_transform_fn(output) + loss = criterion(output) + + booster.backward(loss, optimizer) + optimizer.clip_grad_by_norm(1.0) + optimizer.step() + + for (n1, p1), (n2, p2) in zip(test_model.named_parameters(), model_copy.named_parameters()): + if "lora_" in n1: + # lora modules require gradients, thus updated + assert p1.requires_grad + assert not torch.testing.assert_close(p1.to(p2.device).to(p2.dtype), p2, atol=5e-3, rtol=5e-3) + else: + if not p1.requires_grad: + torch.testing.assert_close(p1.to(p2.device).to(p2.dtype), p2, atol=5e-3, rtol=5e-3) + + +def run_lora_test(): + sub_model_zoo = model_zoo.get_sub_registry("transformers_llama") + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + task_type = None + if name == "transformers_llama_for_casual_lm": + task_type = "CAUSAL_LM" + if name == "transformers_llama_for_sequence_classification": + task_type = "SEQ_CLS" + check_fwd_bwd(model_fn, data_gen_fn, output_transform_fn, loss_fn, task_type) + + +def run_dist(rank, world_size, port): + config = {} + colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + run_lora_test() + + +@rerun_if_address_is_in_use() +def test_torch_ddp_lora(): + spawn(run_dist, 2) diff --git a/tests/test_lora/test_torch_ddp_lora.py b/tests/test_lora/test_torch_ddp_lora.py deleted file mode 100644 index b3169bf86..000000000 --- a/tests/test_lora/test_torch_ddp_lora.py +++ /dev/null @@ -1,108 +0,0 @@ -import copy -import os - -import torch -from peft import LoraConfig -from torch import distributed as dist -from torch.optim import AdamW - -import colossalai -from colossalai.booster import Booster -from colossalai.booster.plugin import TorchDDPPlugin -from colossalai.testing import ( - assert_equal, - assert_not_equal, - check_state_dict_equal, - clear_cache_before_run, - rerun_if_address_is_in_use, - spawn, -) -from tests.kit.model_zoo import model_zoo -from tests.test_checkpoint_io.utils import shared_tempdir - - -@clear_cache_before_run() -def check_fwd_bwd(model_fn, data_gen_fn, output_transform_fn, loss_fn, task_type): - model = model_fn() - lora_config = LoraConfig(task_type=task_type, r=8, lora_alpha=32, lora_dropout=0.1) - - plugin = TorchDDPPlugin() - booster = Booster(plugin=plugin) - - model = booster.enable_lora(model, lora_config=lora_config) - model_copy = copy.deepcopy(model) - - optimizer = AdamW(model.parameters(), lr=0.001) - criterion = loss_fn - - model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion) - - data = data_gen_fn() - data = {k: v.to("cuda") if torch.is_tensor(v) or "Tensor" in v.__class__.__name__ else v for k, v in data.items()} - - output = model(**data) - output = output_transform_fn(output) - loss = criterion(output) - - booster.backward(loss, optimizer) - optimizer.clip_grad_by_norm(1.0) - optimizer.step() - - for (n1, p1), (n2, p2) in zip(model.named_parameters(), model_copy.named_parameters()): - if "lora_" in n1: - # lora modules require gradients, thus updated - assert p1.requires_grad - assert_not_equal(p1.to(p2.device), p2) - else: - if not p1.requires_grad: - assert_equal(p1.to(p2.device), p2) - - -@clear_cache_before_run() -def check_checkpoint(model_fn, data_gen_fn, output_transform_fn, loss_fn, task_type): - plugin = TorchDDPPlugin() - lora_config = LoraConfig(task_type=task_type, r=8, lora_alpha=32, lora_dropout=0.1) - - model_save = model_fn() - model_load = copy.deepcopy(model_save) - - booster = Booster(plugin=plugin) - model_save = booster.enable_lora(model_save, lora_config=lora_config) - model_save, _, _, _, _ = booster.boost(model_save) - - with shared_tempdir() as tempdir: - lora_ckpt_path = os.path.join(tempdir, "ckpt") - booster.save_lora_as_pretrained(model_save, lora_ckpt_path) - dist.barrier() - - # The Lora checkpoint should be small in size - checkpoint_size_mb = os.path.getsize(os.path.join(lora_ckpt_path, "adapter_model.bin")) / (1024 * 1024) - assert checkpoint_size_mb < 1 - - model_load = booster.enable_lora(model_load, pretrained_dir=lora_ckpt_path) - model_load, _, _, _, _ = booster.boost(model_load) - - check_state_dict_equal(model_save.state_dict(), model_load.state_dict()) - - -def run_lora_test(): - sub_model_zoo = model_zoo.get_sub_registry("transformers_llama") - for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): - task_type = None - if name == "transformers_llama_for_casual_lm": - task_type = "CAUSAL_LM" - if name == "transformers_llama_for_sequence_classification": - task_type = "SEQ_CLS" - check_fwd_bwd(model_fn, data_gen_fn, output_transform_fn, loss_fn, task_type) - check_checkpoint(model_fn, data_gen_fn, output_transform_fn, loss_fn, task_type) - - -def run_dist(rank, world_size, port): - config = {} - colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") - run_lora_test() - - -@rerun_if_address_is_in_use() -def test_torch_ddp_lora(): - spawn(run_dist, 2) From 7f8b16635b42013b73e1cb1ffdebc07b4d71ac93 Mon Sep 17 00:00:00 2001 From: Hongxin Liu Date: Mon, 29 Apr 2024 10:40:11 +0800 Subject: [PATCH 25/28] [misc] refactor launch API and tensor constructor (#5666) * [misc] remove config arg from initialize * [misc] remove old tensor contrusctor * [plugin] add npu support for ddp * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [devops] fix doc test ci * [test] fix test launch * [doc] update launch doc --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .github/workflows/doc_test_on_pr.yml | 2 +- applications/Colossal-LLaMA/train.py | 2 +- .../ColossalChat/benchmarks/benchmark_ppo.py | 2 +- .../examples/training_scripts/train_dpo.py | 2 +- .../examples/training_scripts/train_ppo.py | 2 +- .../examples/training_scripts/train_rm.py | 2 +- .../examples/training_scripts/train_sft.py | 2 +- .../examples/dataset_evaluation/inference.py | 2 +- .../examples/gpt_evaluation/inference.py | 2 +- applications/ColossalMoE/infer.py | 8 ++- .../ColossalMoE/tests/test_mixtral_layer.py | 2 +- .../ColossalMoE/tests/test_moe_checkpoint.py | 2 +- applications/ColossalMoE/train.py | 8 +-- .../auto_parallel/offload/amp_optimizer.py | 2 +- .../offload/base_offload_module.py | 4 +- colossalai/booster/plugin/torch_ddp_plugin.py | 5 +- colossalai/inference/README.md | 2 +- colossalai/initialize.py | 16 +----- .../dynamic_batching/ray_dist_init.py | 2 +- .../legacy/inference/hybridengine/engine.py | 2 +- .../legacy/inference/pipeline/README.md | 34 ++++++------ .../inference/pipeline/benchmark/benchmark.py | 2 +- .../ray_serve/Colossal_Inference_rayserve.py | 2 +- .../torch_serve/Colossal_Inference_Handler.py | 2 +- colossalai/legacy/pipeline/rpc/utils.py | 2 +- colossalai/nn/optimizer/fused_adam.py | 4 +- colossalai/nn/optimizer/hybrid_adam.py | 4 +- colossalai/shardformer/README.md | 2 +- .../examples/convergence_benchmark.py | 2 +- .../examples/performance_benchmark.py | 3 +- colossalai/shardformer/shard/shardformer.py | 2 +- colossalai/tensor/d_tensor/README.md | 2 +- .../train_gpt_using_hybrid_parallelism.md | 2 +- .../train_vit_with_hybrid_parallelism.md | 2 +- docs/source/en/basics/booster_api.md | 2 +- docs/source/en/basics/launch_colossalai.md | 18 ++----- .../gradient_accumulation_with_booster.md | 2 +- .../gradient_clipping_with_booster.md | 2 +- docs/source/en/features/lazy_init.md | 2 +- .../mixed_precision_training_with_booster.md | 10 ++-- docs/source/en/features/nvme_offload.md | 2 +- docs/source/en/features/zero_with_chunk.md | 2 +- .../train_gpt_using_hybrid_parallelism.md | 2 +- .../train_vit_with_hybrid_parallelism.md | 2 +- docs/source/zh-Hans/basics/booster_api.md | 2 +- .../zh-Hans/basics/launch_colossalai.md | 18 ++----- .../gradient_accumulation_with_booster.md | 2 +- .../gradient_clipping_with_booster.md | 2 +- docs/source/zh-Hans/features/lazy_init.md | 2 +- .../mixed_precision_training_with_booster.md | 12 ++--- docs/source/zh-Hans/features/nvme_offload.md | 2 +- .../zh-Hans/features/zero_with_chunk.md | 2 +- .../roberta/pretraining/run_pretraining.py | 4 +- examples/images/dreambooth/debug.py | 2 +- .../dreambooth/train_dreambooth_colossalai.py | 4 +- .../train_dreambooth_colossalai_lora.py | 4 +- examples/images/resnet/train.py | 2 +- examples/images/vit/vit_benchmark.py | 2 +- examples/images/vit/vit_train_demo.py | 2 +- examples/inference/benchmark_llama.py | 2 +- examples/inference/run_llama_inference.py | 2 +- examples/language/bert/benchmark.py | 2 +- examples/language/bert/finetune.py | 2 +- .../auto_offload/train_gpt_offload.py | 3 +- .../auto_parallel/auto_parallel_with_gpt.py | 2 +- .../language/gpt/gemini/train_gpt_demo.py | 2 +- .../gpt/hybridparallelism/benchmark.py | 2 +- .../gpt/hybridparallelism/finetune.py | 2 +- examples/language/gpt/titans/train_gpt.py | 4 +- examples/language/grok-1/inference_tp.py | 2 +- examples/language/llama/benchmark.py | 2 +- .../openmoe/benchmark/benchmark_cai.py | 2 +- examples/language/openmoe/train.py | 2 +- examples/language/opt/opt_benchmark.py | 2 +- examples/language/opt/opt_train_demo.py | 2 +- examples/language/palm/train.py | 2 +- .../auto_parallel/auto_ckpt_batchsize_test.py | 2 +- .../auto_parallel/auto_ckpt_solver_test.py | 2 +- .../tutorial/new_api/cifar_resnet/train.py | 2 +- examples/tutorial/new_api/cifar_vit/train.py | 2 +- .../tutorial/new_api/glue_bert/finetune.py | 2 +- examples/tutorial/opt/opt/run_clm.py | 2 +- .../test_C_solver_consistency.py | 2 +- .../test_ckpt_torchvision.py | 4 +- .../test_offload/test_perf.py | 3 +- .../test_bias_addition_forward.py | 4 +- .../test_tensor_shard/test_checkpoint.py | 2 +- .../test_compatibility_with_ddp.py | 2 +- .../test_compatibility_with_gemini.py | 2 +- .../test_gpt/test_runtime_with_gpt_modules.py | 2 +- .../test_binary_elementwise_metainfo.py | 2 +- .../test_metainfo/test_conv_metainfo.py | 4 +- .../test_metainfo/test_linear_metainfo.py | 4 +- .../test_metainfo/test_norm_metainfo.py | 2 +- .../test_metainfo/test_pooling_metainfo.py | 4 +- .../test_node_handler/test_addbmm_handler.py | 4 +- .../test_node_handler/test_addmm_handler.py | 2 +- .../test_batch_norm_handler.py | 2 +- .../test_bias_linear_function_node.py | 2 +- .../test_bias_linear_module_node.py | 2 +- .../test_binary_elementwise_handler.py | 4 +- .../test_node_handler/test_bmm_handler.py | 4 +- .../test_node_handler/test_conv_handler.py | 4 +- .../test_embedding_handler.py | 4 +- .../test_node_handler/test_getitem_handler.py | 2 +- .../test_layer_norm_handler.py | 2 +- .../test_node_handler/test_linear_handler.py | 4 +- .../test_permute_and_transpose_handler.py | 2 +- .../test_node_handler/test_softmax_handler.py | 2 +- .../test_node_handler/test_split_handler.py | 2 +- .../test_node_handler/test_sum_handler.py | 2 +- .../test_node_handler/test_view_handler.py | 2 +- .../test_mixed_precision/test_fp16_torch.py | 2 +- .../test_plugin/test_3d_plugin.py | 2 +- .../test_plugin/test_dp_plugin_base.py | 2 +- .../test_plugin/test_gemini_plugin.py | 2 +- .../test_plugin/test_low_level_zero_plugin.py | 2 +- .../test_plugin/test_torch_ddp_plugin.py | 2 +- .../test_plugin/test_torch_fsdp_plugin.py | 2 +- .../test_gemini_checkpoint_io.py | 3 +- .../test_gemini_torch_compability.py | 3 +- ...st_hybrid_parallel_plugin_checkpoint_io.py | 3 +- .../test_low_level_zero_checkpoint_io.py | 2 +- .../test_plugins_huggingface_compatibility.py | 3 +- .../test_torch_ddp_checkpoint_io.py | 2 +- .../test_torch_fsdp_checkpoint_io.py | 2 +- .../test_cluster/test_device_mesh_manager.py | 2 +- tests/test_cluster/test_process_group_mesh.py | 54 ------------------- tests/test_device/test_alpha_beta.py | 2 +- tests/test_device/test_device_mesh.py | 2 +- tests/test_device/test_extract_alpha_beta.py | 2 +- tests/test_device/test_init_logical_pg.py | 2 +- .../test_search_logical_device_mesh.py | 2 +- .../test_activation_checkpoint_codegen.py | 4 +- ...st_nested_activation_checkpoint_codegen.py | 4 +- .../test_codegen/test_offload_codegen.py | 4 +- tests/test_fx/test_parallel_1d.py | 2 +- tests/test_infer/test_hybrid_bloom.py | 6 +-- tests/test_infer/test_hybrid_chatglm2.py | 6 +-- tests/test_infer/test_hybrid_llama.py | 6 +-- tests/test_legacy/test_amp/test_naive_fp16.py | 2 +- tests/test_legacy/test_amp/test_torch_fp16.py | 2 +- .../test_comm/test_boardcast_send_recv_v2.py | 2 +- tests/test_legacy/test_comm/test_comm.py | 2 +- .../test_comm/test_object_list_p2p.py | 2 +- .../test_comm/test_object_list_p2p_v2.py | 2 +- .../test_layers/test_1d/test_1d.py | 2 +- .../test_layers/test_2d/test_2d.py | 2 +- .../test_layers/test_2p5d/test_2p5d.py | 2 +- .../test_layers/test_3d/test_3d.py | 2 +- .../test_layers/test_cache_embedding.py | 2 +- .../test_tensor/core/test_dist_spec_mgr.py | 2 +- .../test_legacy/test_tensor/test_parameter.py | 2 +- .../test_trainer/test_pipeline/test_p2p.py | 2 +- .../test_pipeline/test_pipeline_schedule.py | 2 +- .../test_checkpoint/test_checkpoint_1d.py | 2 +- .../test_checkpoint/test_checkpoint_2d.py | 2 +- .../test_checkpoint/test_checkpoint_2p5d.py | 2 +- .../test_checkpoint/test_checkpoint_3d.py | 2 +- tests/test_legacy/test_utils/test_memory.py | 2 +- .../test_utils/test_norm_gradient_clipping.py | 2 +- tests/test_legacy/test_zero/test_commons.py | 2 +- tests/test_lora/test_lora.py | 3 +- tests/test_moe/test_grad_handler.py | 1 - tests/test_moe/test_kernel.py | 2 +- tests/test_moe/test_moe_ep_tp.py | 2 +- tests/test_moe/test_moe_group.py | 1 - tests/test_moe/test_moe_hybrid_zero.py | 2 +- tests/test_moe/test_moe_load_balance.py | 1 - tests/test_moe/test_moe_zero_fwd_bwd.py | 2 +- tests/test_moe/test_moe_zero_optim.py | 2 +- tests/test_optimizer/test_adam_kernel.py | 2 +- tests/test_pipeline/test_p2p_communication.py | 2 +- .../test_schedule/test_interleaved.py | 2 +- .../test_schedule/test_oneF_oneB.py | 2 +- tests/test_pipeline/test_stage_manager.py | 2 +- .../test_amp_optimizer.py | 4 +- .../test_naive_optimizer.py | 4 +- .../test_zero_optimizer.py | 4 +- .../test_layer/test_dist_crossentropy.py | 2 +- .../test_layer/test_dropout.py | 2 +- .../test_layer/test_embedding.py | 2 +- .../test_gpt2_qkv_fused_linear_1d.py | 2 +- .../test_layer/test_layernorm.py | 2 +- .../test_layer/test_linear_1d.py | 2 +- .../test_layer/test_qkv_fused_linear_1d.py | 2 +- .../test_layer/test_sequence_parallel.py | 2 +- .../test_vocab_parallel_embedding_1d.py | 2 +- .../test_model/test_shard_bert.py | 4 +- .../test_model/test_shard_blip2.py | 1 - .../test_model/test_shard_bloom.py | 4 +- .../test_model/test_shard_chatglm2.py | 2 - .../test_model/test_shard_falcon.py | 4 +- .../test_model/test_shard_gpt2.py | 2 - .../test_model/test_shard_llama.py | 4 +- .../test_model/test_shard_mistral.py | 2 +- .../test_model/test_shard_opt.py | 2 - .../test_model/test_shard_sam.py | 2 +- .../test_model/test_shard_t5.py | 2 - .../test_model/test_shard_vit.py | 4 +- .../test_model/test_shard_whisper.py | 4 +- tests/test_shardformer/test_with_torch_ddp.py | 2 +- tests/test_tensor/test_comm_spec_apply.py | 2 +- .../test_dtensor/test_comm_spec.py | 2 +- .../test_tensor/test_dtensor/test_dtensor.py | 2 +- .../test_dtensor/test_layout_converter.py | 6 +-- tests/test_tensor/test_mix_gather.py | 2 +- tests/test_tensor/test_padded_tensor.py | 2 +- .../test_shape_consistency_apply.py | 2 +- .../test_zero/test_gemini/test_chunk_mgrv2.py | 2 +- tests/test_zero/test_gemini/test_chunkv2.py | 2 +- tests/test_zero/test_gemini/test_fwd_bwd.py | 3 +- .../test_gemini/test_gemini_use_rmt.py | 3 +- .../test_zero/test_gemini/test_grad_accum.py | 3 +- tests/test_zero/test_gemini/test_grad_clip.py | 3 +- tests/test_zero/test_gemini/test_inference.py | 3 +- tests/test_zero/test_gemini/test_optim.py | 3 +- tests/test_zero/test_gemini/test_search.py | 2 +- .../test_gemini/test_zeroddp_state_dict.py | 3 +- .../test_gemini/test_zerooptim_state_dict.py | 3 +- .../test_zero/test_low_level/test_grad_acc.py | 2 +- .../test_zero/test_low_level/test_zero1_2.py | 2 +- .../test_low_level/test_zero_ckpt.py | 2 +- 223 files changed, 294 insertions(+), 403 deletions(-) diff --git a/.github/workflows/doc_test_on_pr.yml b/.github/workflows/doc_test_on_pr.yml index 8afc46b87..27f7e76af 100644 --- a/.github/workflows/doc_test_on_pr.yml +++ b/.github/workflows/doc_test_on_pr.yml @@ -56,7 +56,7 @@ jobs: needs: detect-changed-doc runs-on: [self-hosted, gpu] container: - image: hpcaitech/pytorch-cuda:2.0.0-11.7.0 + image: hpcaitech/pytorch-cuda:2.1.0-12.1.0 options: --gpus all --rm timeout-minutes: 20 defaults: diff --git a/applications/Colossal-LLaMA/train.py b/applications/Colossal-LLaMA/train.py index 37e4fcc80..43a360a9a 100644 --- a/applications/Colossal-LLaMA/train.py +++ b/applications/Colossal-LLaMA/train.py @@ -136,7 +136,7 @@ def main() -> None: # ============================== # Initialize Distributed Training # ============================== - colossalai.launch_from_torch({}) + colossalai.launch_from_torch() accelerator = get_accelerator() coordinator = DistCoordinator() diff --git a/applications/ColossalChat/benchmarks/benchmark_ppo.py b/applications/ColossalChat/benchmarks/benchmark_ppo.py index e1b7a313f..00edf0534 100644 --- a/applications/ColossalChat/benchmarks/benchmark_ppo.py +++ b/applications/ColossalChat/benchmarks/benchmark_ppo.py @@ -66,7 +66,7 @@ def benchmark_train(args): # ============================== # Initialize Distributed Training # ============================== - colossalai.launch_from_torch({}) + colossalai.launch_from_torch() coordinator = DistCoordinator() # ====================================================== diff --git a/applications/ColossalChat/examples/training_scripts/train_dpo.py b/applications/ColossalChat/examples/training_scripts/train_dpo.py index b9287eb1a..f06c23a9f 100755 --- a/applications/ColossalChat/examples/training_scripts/train_dpo.py +++ b/applications/ColossalChat/examples/training_scripts/train_dpo.py @@ -37,7 +37,7 @@ def train(args): # ============================== # Initialize Distributed Training # ============================== - colossalai.launch_from_torch({}) + colossalai.launch_from_torch() coordinator = DistCoordinator() # ============================== diff --git a/applications/ColossalChat/examples/training_scripts/train_ppo.py b/applications/ColossalChat/examples/training_scripts/train_ppo.py index 7c91fa347..727cff7ca 100755 --- a/applications/ColossalChat/examples/training_scripts/train_ppo.py +++ b/applications/ColossalChat/examples/training_scripts/train_ppo.py @@ -39,7 +39,7 @@ def train(args): # ============================== # Initialize Distributed Training # ============================== - colossalai.launch_from_torch({}) + colossalai.launch_from_torch() coordinator = DistCoordinator() # ====================================================== diff --git a/applications/ColossalChat/examples/training_scripts/train_rm.py b/applications/ColossalChat/examples/training_scripts/train_rm.py index a0c710f2b..364198c1d 100755 --- a/applications/ColossalChat/examples/training_scripts/train_rm.py +++ b/applications/ColossalChat/examples/training_scripts/train_rm.py @@ -34,7 +34,7 @@ def train(args): # ============================== # Initialize Distributed Training # ============================== - colossalai.launch_from_torch({}) + colossalai.launch_from_torch() coordinator = DistCoordinator() # ====================================================== diff --git a/applications/ColossalChat/examples/training_scripts/train_sft.py b/applications/ColossalChat/examples/training_scripts/train_sft.py index fcd1a429c..ae20f2abc 100755 --- a/applications/ColossalChat/examples/training_scripts/train_sft.py +++ b/applications/ColossalChat/examples/training_scripts/train_sft.py @@ -29,7 +29,7 @@ def train(args): # ============================== # Initialize Distributed Training # ============================== - colossalai.launch_from_torch({}) + colossalai.launch_from_torch() coordinator = DistCoordinator() # ============================== diff --git a/applications/ColossalEval/examples/dataset_evaluation/inference.py b/applications/ColossalEval/examples/dataset_evaluation/inference.py index 13bbb12b6..a7307635d 100644 --- a/applications/ColossalEval/examples/dataset_evaluation/inference.py +++ b/applications/ColossalEval/examples/dataset_evaluation/inference.py @@ -81,7 +81,7 @@ def rm_and_merge( def main(args): - colossalai.launch_from_torch(config={}, seed=42) + colossalai.launch_from_torch(seed=42) accelerator = get_accelerator() world_size = dist.get_world_size() diff --git a/applications/ColossalEval/examples/gpt_evaluation/inference.py b/applications/ColossalEval/examples/gpt_evaluation/inference.py index 5b09f9de8..408ba3e7b 100644 --- a/applications/ColossalEval/examples/gpt_evaluation/inference.py +++ b/applications/ColossalEval/examples/gpt_evaluation/inference.py @@ -81,7 +81,7 @@ def rm_and_merge( def main(args): - colossalai.launch_from_torch(config={}, seed=42) + colossalai.launch_from_torch(seed=42) world_size = dist.get_world_size() rank = dist.get_rank() diff --git a/applications/ColossalMoE/infer.py b/applications/ColossalMoE/infer.py index c175fe9e3..543c434d2 100644 --- a/applications/ColossalMoE/infer.py +++ b/applications/ColossalMoE/infer.py @@ -57,7 +57,7 @@ def main(): args = parse_args() # Launch ColossalAI - colossalai.launch_from_torch(config={}, seed=args.seed) + colossalai.launch_from_torch(seed=args.seed) coordinator = DistCoordinator() config = MixtralConfig.from_pretrained(args.model_name) @@ -96,7 +96,11 @@ def main(): if coordinator.rank == 0: text = ["Hello my name is"] else: - text = ["What's the largest country in the world?", "How many people live in China?", "帮我续写这首诗:离离原上草"] + text = [ + "What's the largest country in the world?", + "How many people live in China?", + "帮我续写这首诗:离离原上草", + ] tokenizer.pad_token = tokenizer.unk_token inputs = tokenizer(text, return_tensors="pt", padding=True).to(torch.cuda.current_device()) diff --git a/applications/ColossalMoE/tests/test_mixtral_layer.py b/applications/ColossalMoE/tests/test_mixtral_layer.py index 57589ab20..cbb70f195 100644 --- a/applications/ColossalMoE/tests/test_mixtral_layer.py +++ b/applications/ColossalMoE/tests/test_mixtral_layer.py @@ -50,7 +50,7 @@ def check_mixtral_moe_layer(): def run_dist(rank: int, world_size: int, port: int): - colossalai.launch({}, rank, world_size, "localhost", port) + colossalai.launch(rank, world_size, "localhost", port) check_mixtral_moe_layer() diff --git a/applications/ColossalMoE/tests/test_moe_checkpoint.py b/applications/ColossalMoE/tests/test_moe_checkpoint.py index 822e7410f..074dbf835 100644 --- a/applications/ColossalMoE/tests/test_moe_checkpoint.py +++ b/applications/ColossalMoE/tests/test_moe_checkpoint.py @@ -133,7 +133,7 @@ def check_mixtral_moe_layer(): def run_dist(rank: int, world_size: int, port: int): - colossalai.launch({}, rank, world_size, "localhost", port) + colossalai.launch(rank, world_size, "localhost", port) check_mixtral_moe_layer() diff --git a/applications/ColossalMoE/train.py b/applications/ColossalMoE/train.py index 850236726..d2789d644 100644 --- a/applications/ColossalMoE/train.py +++ b/applications/ColossalMoE/train.py @@ -145,7 +145,7 @@ def main(): args = parse_args() # Launch ColossalAI - colossalai.launch_from_torch(config={}, seed=args.seed) + colossalai.launch_from_torch(seed=args.seed) coordinator = DistCoordinator() # Set plugin @@ -195,9 +195,9 @@ def main(): lr_scheduler = CosineAnnealingWarmupLR( optimizer=optimizer, total_steps=args.num_epochs * len(dataloader), - warmup_steps=args.warmup_steps - if args.warmup_steps is not None - else int(args.num_epochs * len(dataloader) * 0.025), + warmup_steps=( + args.warmup_steps if args.warmup_steps is not None else int(args.num_epochs * len(dataloader) * 0.025) + ), eta_min=0.1 * args.lr, ) diff --git a/colossalai/auto_parallel/offload/amp_optimizer.py b/colossalai/auto_parallel/offload/amp_optimizer.py index fe8439269..ab02de7ce 100644 --- a/colossalai/auto_parallel/offload/amp_optimizer.py +++ b/colossalai/auto_parallel/offload/amp_optimizer.py @@ -126,7 +126,7 @@ class AMPOptimizer(OptimizerWrapper): return self.grad_scaler.scale.item() def zero_grad(self, *args, **kwargs): - self.module.overflow_counter = torch.cuda.IntTensor([0]) + self.module.overflow_counter = torch.tensor([0], dtype=torch.int, device=get_accelerator().get_current_device()) return self.optim.zero_grad(set_to_none=True) def step(self, *args, **kwargs): diff --git a/colossalai/auto_parallel/offload/base_offload_module.py b/colossalai/auto_parallel/offload/base_offload_module.py index 60de7743a..8afd29e43 100644 --- a/colossalai/auto_parallel/offload/base_offload_module.py +++ b/colossalai/auto_parallel/offload/base_offload_module.py @@ -4,7 +4,7 @@ from typing import Optional, Set import torch import torch.nn as nn -from colossalai.utils import _cast_float +from colossalai.utils import _cast_float, get_current_device from colossalai.utils.common import free_storage from .region_manager import RegionManager @@ -25,7 +25,7 @@ class BaseOffloadModule: self.model = model self.region_manager = region_manager self.grad_hook_list = [] - self.overflow_counter = torch.cuda.IntTensor([0]) + self.overflow_counter = torch.tensor([0], dtype=torch.int, device=get_current_device()) self.grad_offload_stream = torch.cuda.current_stream() if is_sync else GlobalRuntimeInfo.d2h_stream diff --git a/colossalai/booster/plugin/torch_ddp_plugin.py b/colossalai/booster/plugin/torch_ddp_plugin.py index 482cc4e98..5116446a4 100644 --- a/colossalai/booster/plugin/torch_ddp_plugin.py +++ b/colossalai/booster/plugin/torch_ddp_plugin.py @@ -10,6 +10,7 @@ from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO from colossalai.cluster import DistCoordinator from colossalai.interface import ModelWrapper, OptimizerWrapper from colossalai.quantization import BnbQuantizationConfig, quantize_model +from colossalai.utils import get_current_device from .dp_plugin_base import DPPluginBase @@ -203,7 +204,7 @@ class TorchDDPPlugin(DPPluginBase): return True def supported_devices(self) -> List[str]: - return ["cuda"] + return ["cuda", "npu"] def configure( self, @@ -214,7 +215,7 @@ class TorchDDPPlugin(DPPluginBase): lr_scheduler: Optional[LRScheduler] = None, ) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]: # cast model to cuda - model = model.cuda() + model = model.to(get_current_device()) # convert model to sync bn model = nn.SyncBatchNorm.convert_sync_batchnorm(model, None) diff --git a/colossalai/inference/README.md b/colossalai/inference/README.md index c2b808155..0bdaf347d 100644 --- a/colossalai/inference/README.md +++ b/colossalai/inference/README.md @@ -114,7 +114,7 @@ import colossalai from transformers import LlamaForCausalLM, LlamaTokenizer #launch distributed environment -colossalai.launch_from_torch(config={}) +colossalai.launch_from_torch() # load original model and tokenizer model = LlamaForCausalLM.from_pretrained("/path/to/model") diff --git a/colossalai/initialize.py b/colossalai/initialize.py index aaeaad382..934555e19 100644 --- a/colossalai/initialize.py +++ b/colossalai/initialize.py @@ -2,20 +2,15 @@ # -*- encoding: utf-8 -*- import os -import warnings -from pathlib import Path -from typing import Dict, Union import torch.distributed as dist from colossalai.accelerator import get_accelerator -from colossalai.context import Config from colossalai.logging import get_dist_logger from colossalai.utils import set_seed def launch( - config: Union[str, Path, Config, Dict], rank: int, world_size: int, host: str, @@ -44,8 +39,6 @@ def launch( Raises: Exception: Raise exception when config type is wrong """ - if rank == 0: - warnings.warn("`config` is deprecated and will be removed soon.") cur_accelerator = get_accelerator() @@ -68,7 +61,6 @@ def launch( def launch_from_slurm( - config: Union[str, Path, Config, Dict], host: str, port: int, backend: str = "nccl", @@ -95,7 +87,6 @@ def launch_from_slurm( ) launch( - config=config, rank=rank, world_size=world_size, host=host, @@ -107,7 +98,6 @@ def launch_from_slurm( def launch_from_openmpi( - config: Union[str, Path, Config, Dict], host: str, port: int, backend: str = "nccl", @@ -135,7 +125,6 @@ def launch_from_openmpi( ) launch( - config=config, local_rank=local_rank, rank=rank, world_size=world_size, @@ -147,9 +136,7 @@ def launch_from_openmpi( ) -def launch_from_torch( - config: Union[str, Path, Config, Dict], backend: str = "nccl", seed: int = 1024, verbose: bool = True -): +def launch_from_torch(backend: str = "nccl", seed: int = 1024, verbose: bool = True): """A wrapper for colossalai.launch for torchrun or torch.distributed.launch by reading rank and world size from the environment variables set by PyTorch @@ -171,7 +158,6 @@ def launch_from_torch( ) launch( - config=config, local_rank=local_rank, rank=rank, world_size=world_size, diff --git a/colossalai/legacy/inference/dynamic_batching/ray_dist_init.py b/colossalai/legacy/inference/dynamic_batching/ray_dist_init.py index 3e40bb0ee..7a74fb949 100644 --- a/colossalai/legacy/inference/dynamic_batching/ray_dist_init.py +++ b/colossalai/legacy/inference/dynamic_batching/ray_dist_init.py @@ -56,7 +56,7 @@ class Worker: # initialize a ray collective group, otherwise colossalai distributed env won't be built successfully collective.init_collective_group(world_size, rank, "nccl", "default") # initialize and set distributed environment - colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") ray_serve_logger.info(f"Worker with rank {rank} (world size {world_size}) setting up..") log_cuda_info("Worker.setup") diff --git a/colossalai/legacy/inference/hybridengine/engine.py b/colossalai/legacy/inference/hybridengine/engine.py index bc4e4fd19..019a678ce 100644 --- a/colossalai/legacy/inference/hybridengine/engine.py +++ b/colossalai/legacy/inference/hybridengine/engine.py @@ -42,7 +42,7 @@ class CaiInferEngine: import colossalai from transformers import LlamaForCausalLM, LlamaTokenizer - colossalai.launch_from_torch(config={}) + colossalai.launch_from_torch() model = LlamaForCausalLM.from_pretrained("your_path_to_model") tokenizer = LlamaTokenizer.from_pretrained("/home/lczyh/share/models/llama-7b-hf") diff --git a/colossalai/legacy/inference/pipeline/README.md b/colossalai/legacy/inference/pipeline/README.md index f9bb35cc4..cbe96fff0 100644 --- a/colossalai/legacy/inference/pipeline/README.md +++ b/colossalai/legacy/inference/pipeline/README.md @@ -36,7 +36,7 @@ from colossalai.inference.pipeline.policies import LlamaModelInferPolicy import colossalai from transformers import LlamaForCausalLM, LlamaTokenizer -colossalai.launch_from_torch(config={}) +colossalai.launch_from_torch() model = LlamaForCausalLM.from_pretrained("/path/to/model") tokenizer = LlamaTokenizer.from_pretrained("/path/to/model") @@ -57,27 +57,27 @@ We conducted multiple benchmark tests to evaluate the performance. We compared t ### Llama Throughput (tokens/s) | input length=1024, output length=128 #### A10 7b, fp16 -| batch_size(micro_batch size)| 2(1) | 4(2) | 8(4) | 16(8) | 32(8) | 32(16)| -| :---: | :---: | :---: | :---: | :---: | :---: | :---:| -| Pipeline Inference | 40.35 | 77.1 | 139.03 | 232.7 | 257.81 | OOM | -| Hugging Face | 41.43 | 65.30 | 91.93 | 114.62 | OOM| OOM | +| batch_size(micro_batch size) | 2(1) | 4(2) | 8(4) | 16(8) | 32(8) | 32(16) | +|:----------------------------:|:-----:|:-----:|:------:|:------:|:------:|:------:| +| Pipeline Inference | 40.35 | 77.1 | 139.03 | 232.7 | 257.81 | OOM | +| Hugging Face | 41.43 | 65.30 | 91.93 | 114.62 | OOM | OOM | #### A10 13b, fp16 -| batch_size(micro_batch size)| 2(1) | 4(2) | 8(4) | 16(4) | -| :---: | :---: | :---: | :---: | :---: | -| Pipeline Inference | 25.39 | 47.09 | 83.7 | 89.46 | -| Hugging Face | 23.48 | 37.59 | 53.44 | OOM | +| batch_size(micro_batch size) | 2(1) | 4(2) | 8(4) | 16(4) | +|:----------------------------:|:-----:|:-----:|:-----:|:-----:| +| Pipeline Inference | 25.39 | 47.09 | 83.7 | 89.46 | +| Hugging Face | 23.48 | 37.59 | 53.44 | OOM | #### A800 7b, fp16 -| batch_size(micro_batch size) | 2(1) | 4(2) | 8(4) | 16(8) | 32(16) | -| :---: | :---: | :---: | :---: | :---: | :---: | -| Pipeline Inference| 57.97 | 110.13 | 213.33 | 389.86 | 670.12 | -| Hugging Face | 42.44 | 76.5 | 151.97 | 212.88 | 256.13 | +| batch_size(micro_batch size) | 2(1) | 4(2) | 8(4) | 16(8) | 32(16) | +|:----------------------------:|:-----:|:------:|:------:|:------:|:------:| +| Pipeline Inference | 57.97 | 110.13 | 213.33 | 389.86 | 670.12 | +| Hugging Face | 42.44 | 76.5 | 151.97 | 212.88 | 256.13 | #### A800 13b, fp16 -| batch_size(micro_batch size) | 2(1) | 4(2) | 8(4) | 16(8) | 32(16) | -| :---: | :---: | :---: | :---: | :---: | :---: | -| Pipeline Inference | 41.78 | 94.18 | 172.67| 310.75| 470.15 | -| Hugging Face | 36.57 | 68.4 | 105.81 | 139.51 | 166.34 | +| batch_size(micro_batch size) | 2(1) | 4(2) | 8(4) | 16(8) | 32(16) | +|:----------------------------:|:-----:|:-----:|:------:|:------:|:------:| +| Pipeline Inference | 41.78 | 94.18 | 172.67 | 310.75 | 470.15 | +| Hugging Face | 36.57 | 68.4 | 105.81 | 139.51 | 166.34 | diff --git a/colossalai/legacy/inference/pipeline/benchmark/benchmark.py b/colossalai/legacy/inference/pipeline/benchmark/benchmark.py index 8392d0a1e..7bb89f4f4 100644 --- a/colossalai/legacy/inference/pipeline/benchmark/benchmark.py +++ b/colossalai/legacy/inference/pipeline/benchmark/benchmark.py @@ -12,7 +12,7 @@ from colossalai.inference.pipeline.policies import LlamaModelInferPolicy GIGABYTE = 1024**3 MEGABYTE = 1024 * 1024 -colossalai.launch_from_torch(config={}) +colossalai.launch_from_torch() def data_gen(batch_size: int = 4, seq_len: int = 512): diff --git a/colossalai/legacy/inference/serving/ray_serve/Colossal_Inference_rayserve.py b/colossalai/legacy/inference/serving/ray_serve/Colossal_Inference_rayserve.py index d758b467c..37e7bae41 100644 --- a/colossalai/legacy/inference/serving/ray_serve/Colossal_Inference_rayserve.py +++ b/colossalai/legacy/inference/serving/ray_serve/Colossal_Inference_rayserve.py @@ -56,7 +56,7 @@ class Worker: # initialize a ray collective group, otherwise colossalai distributed env won't be built successfully collective.init_collective_group(world_size, rank, "nccl", "default") # initialize and set distributed environment - colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") ray_serve_logger.info(f"Worker with rank {rank} (world size {world_size}) setting up..") log_cuda_info("Worker.setup") diff --git a/colossalai/legacy/inference/serving/torch_serve/Colossal_Inference_Handler.py b/colossalai/legacy/inference/serving/torch_serve/Colossal_Inference_Handler.py index e07494b8a..bcbdee951 100644 --- a/colossalai/legacy/inference/serving/torch_serve/Colossal_Inference_Handler.py +++ b/colossalai/legacy/inference/serving/torch_serve/Colossal_Inference_Handler.py @@ -98,7 +98,7 @@ class ColossalInferenceHandler(BaseHandler, ABC): self.model.cuda() self.model.eval() - colossalai.launch(config={}, rank=rank, world_size=world_size, host=host, port=port, backend="nccl") + colossalai.launch(rank=rank, world_size=world_size, host=host, port=port, backend="nccl") logger.info("Initializing TPInferEngine ...") shard_config = ShardConfig( enable_tensor_parallelism=True if self.tp_size > 1 else False, extra_kwargs={"inference_only": True} diff --git a/colossalai/legacy/pipeline/rpc/utils.py b/colossalai/legacy/pipeline/rpc/utils.py index 808de301a..87060ab8a 100644 --- a/colossalai/legacy/pipeline/rpc/utils.py +++ b/colossalai/legacy/pipeline/rpc/utils.py @@ -114,7 +114,7 @@ def run_worker(rank, args, master_func): port = args.master_port backend = "nccl" if device == "cuda" else "gloo" - launch(dict(), rank, world_size, host, int(port), backend, verbose=False) + launch(rank, world_size, host, int(port), backend, verbose=False) ppg.set_global_info( rank=rank, world_size=world_size, diff --git a/colossalai/nn/optimizer/fused_adam.py b/colossalai/nn/optimizer/fused_adam.py index aeb5cc91b..c12551657 100644 --- a/colossalai/nn/optimizer/fused_adam.py +++ b/colossalai/nn/optimizer/fused_adam.py @@ -8,7 +8,7 @@ Licensed under the MIT License. """ import torch -from colossalai.utils import multi_tensor_applier +from colossalai.utils import get_current_device, multi_tensor_applier class FusedAdam(torch.optim.Optimizer): @@ -75,7 +75,7 @@ class FusedAdam(torch.optim.Optimizer): fused_optim = FusedOptimizerLoader().load() # Skip buffer - self._dummy_overflow_buf = torch.cuda.IntTensor([0]) + self._dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device=get_current_device()) self.multi_tensor_adam = fused_optim.multi_tensor_adam else: raise RuntimeError("FusedAdam requires cuda extensions") diff --git a/colossalai/nn/optimizer/hybrid_adam.py b/colossalai/nn/optimizer/hybrid_adam.py index c9c1f81bf..417881a0b 100644 --- a/colossalai/nn/optimizer/hybrid_adam.py +++ b/colossalai/nn/optimizer/hybrid_adam.py @@ -3,7 +3,7 @@ from typing import Any, Optional import torch from colossalai.kernel.kernel_loader import FusedOptimizerLoader -from colossalai.utils import multi_tensor_applier +from colossalai.utils import get_current_device, multi_tensor_applier from .cpu_adam import CPUAdam @@ -87,7 +87,7 @@ class HybridAdam(CPUAdam): if torch.cuda.is_available(): fused_optim = FusedOptimizerLoader().load() self.gpu_adam_op = fused_optim.multi_tensor_adam - self._dummy_overflow_buf = torch.cuda.IntTensor([0]) + self._dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device=get_current_device()) @torch.no_grad() def step(self, closure=None, div_scale: float = -1): diff --git a/colossalai/shardformer/README.md b/colossalai/shardformer/README.md index d45421868..47ef98ccf 100644 --- a/colossalai/shardformer/README.md +++ b/colossalai/shardformer/README.md @@ -38,7 +38,7 @@ from transformers import BertForMaskedLM import colossalai # launch colossalai -colossalai.launch_from_torch(config={}) +colossalai.launch_from_torch() # create model config = BertConfig.from_pretrained('bert-base-uncased') diff --git a/colossalai/shardformer/examples/convergence_benchmark.py b/colossalai/shardformer/examples/convergence_benchmark.py index b03e6201d..4caf61eb4 100644 --- a/colossalai/shardformer/examples/convergence_benchmark.py +++ b/colossalai/shardformer/examples/convergence_benchmark.py @@ -28,7 +28,7 @@ def to_device(x: Any, device: torch.device) -> Any: def train(args): - colossalai.launch_from_torch(config={}, seed=42) + colossalai.launch_from_torch(seed=42) coordinator = DistCoordinator() # prepare for data and dataset diff --git a/colossalai/shardformer/examples/performance_benchmark.py b/colossalai/shardformer/examples/performance_benchmark.py index 81215dcdf..cce8b6f3a 100644 --- a/colossalai/shardformer/examples/performance_benchmark.py +++ b/colossalai/shardformer/examples/performance_benchmark.py @@ -1,6 +1,7 @@ """ Shardformer Benchmark """ + import torch import torch.distributed as dist import transformers @@ -84,5 +85,5 @@ def bench_shardformer(BATCH, N_CTX, provider, model_func, dtype=torch.float32, d # start benchmark, command: # torchrun --standalone --nproc_per_node=2 performance_benchmark.py if __name__ == "__main__": - colossalai.launch_from_torch({}) + colossalai.launch_from_torch() bench_shardformer.run(save_path=".", print_data=dist.get_rank() == 0) diff --git a/colossalai/shardformer/shard/shardformer.py b/colossalai/shardformer/shard/shardformer.py index b132f47fd..b3991c4f0 100644 --- a/colossalai/shardformer/shard/shardformer.py +++ b/colossalai/shardformer/shard/shardformer.py @@ -26,7 +26,7 @@ class ShardFormer: import colossalai import torch - colossalai.launch_from_torch(config={}) + colossalai.launch_from_torch() org_model = BertForMaskedLM.from_pretrained('bert-base-uncased') shard_config = ShardConfig() diff --git a/colossalai/tensor/d_tensor/README.md b/colossalai/tensor/d_tensor/README.md index 3d862dddb..367db5ccd 100644 --- a/colossalai/tensor/d_tensor/README.md +++ b/colossalai/tensor/d_tensor/README.md @@ -69,7 +69,7 @@ import colossalai from colossalai.device.device_mesh import DeviceMesh from colossalai.tensor.d_tensor import DTensor, ShardingSpec -colossalai.launch_from_torch(config={}) +colossalai.launch_from_torch() # define your device mesh # assume you have 4 GPUs diff --git a/docs/source/en/advanced_tutorials/train_gpt_using_hybrid_parallelism.md b/docs/source/en/advanced_tutorials/train_gpt_using_hybrid_parallelism.md index 0133dfd86..b27f9c811 100644 --- a/docs/source/en/advanced_tutorials/train_gpt_using_hybrid_parallelism.md +++ b/docs/source/en/advanced_tutorials/train_gpt_using_hybrid_parallelism.md @@ -75,7 +75,7 @@ WARMUP_FRACTION = 0.1 we create a distributed environment. ```python # Launch ColossalAI -colossalai.launch_from_torch(config={}, seed=42) +colossalai.launch_from_torch( seed=42) coordinator = DistCoordinator() ``` prepare the dataset. You can use `plugin.prepare_dataloader` to generate a dataloader or customize your own dataloader. diff --git a/docs/source/en/advanced_tutorials/train_vit_with_hybrid_parallelism.md b/docs/source/en/advanced_tutorials/train_vit_with_hybrid_parallelism.md index dfc2cd596..ac4169344 100644 --- a/docs/source/en/advanced_tutorials/train_vit_with_hybrid_parallelism.md +++ b/docs/source/en/advanced_tutorials/train_vit_with_hybrid_parallelism.md @@ -71,7 +71,7 @@ PP_SIZE = 2 Create a distributed environment. ```python # Launch ColossalAI -colossalai.launch_from_torch(config={}, seed=SEEDå) +colossalai.launch_from_torch( seed=SEEDå) coordinator = DistCoordinator() world_size = coordinator.world_size ``` diff --git a/docs/source/en/basics/booster_api.md b/docs/source/en/basics/booster_api.md index 2c75dd9ac..a33be3b49 100644 --- a/docs/source/en/basics/booster_api.md +++ b/docs/source/en/basics/booster_api.md @@ -55,7 +55,7 @@ from colossalai.booster.plugin import TorchDDPPlugin def train(): # launch colossalai - colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost') + colossalai.launch(rank=rank, world_size=world_size, port=port, host='localhost') # create plugin and objects for training plugin = TorchDDPPlugin() diff --git a/docs/source/en/basics/launch_colossalai.md b/docs/source/en/basics/launch_colossalai.md index 334757ea7..8a6028d6c 100644 --- a/docs/source/en/basics/launch_colossalai.md +++ b/docs/source/en/basics/launch_colossalai.md @@ -87,8 +87,7 @@ import colossalai args = colossalai.get_default_parser().parse_args() # launch distributed environment -colossalai.launch(config=args.config, - rank=args.rank, +colossalai.launch(rank=args.rank, world_size=args.world_size, host=args.host, port=args.port, @@ -106,20 +105,11 @@ First, we need to set the launch method in our code. As this is a wrapper of the use `colossalai.launch_from_torch`. The arguments required for distributed environment such as rank, world size, host and port are all set by the PyTorch launcher and can be read from the environment variable directly. -config.py -```python -BATCH_SIZE = 512 -LEARNING_RATE = 3e-3 -WEIGHT_DECAY = 0.3 -NUM_EPOCHS = 2 -``` train.py ```python import colossalai -colossalai.launch_from_torch( - config="./config.py", -) +colossalai.launch_from_torch() ... ``` @@ -203,7 +193,6 @@ Do this in your training script: import colossalai colossalai.launch_from_slurm( - config=, host=args.host, port=args.port ) @@ -224,7 +213,6 @@ use them to start the distributed backend. Do this in your train.py: ```python colossalai.launch_from_openmpi( - config=, host=args.host, port=args.port ) @@ -238,3 +226,5 @@ mpirun --hostfile -np python train.py --host diff --git a/docs/source/en/features/gradient_accumulation_with_booster.md b/docs/source/en/features/gradient_accumulation_with_booster.md index ea97dd92e..f1e47e9bb 100644 --- a/docs/source/en/features/gradient_accumulation_with_booster.md +++ b/docs/source/en/features/gradient_accumulation_with_booster.md @@ -45,7 +45,7 @@ We then need to initialize distributed environment. For demo purpose, we uses `l parser = colossalai.get_default_parser() args = parser.parse_args() # launch from torch -colossalai.launch_from_torch(config=dict()) +colossalai.launch_from_torch() ``` ### Step 3. Create training components diff --git a/docs/source/en/features/gradient_clipping_with_booster.md b/docs/source/en/features/gradient_clipping_with_booster.md index 14eee67bc..9f9074e1d 100644 --- a/docs/source/en/features/gradient_clipping_with_booster.md +++ b/docs/source/en/features/gradient_clipping_with_booster.md @@ -61,7 +61,7 @@ We then need to initialize distributed environment. For demo purpose, we uses `l for other initialization methods. ```python -colossalai.launch_from_torch(config=dict()) +colossalai.launch_from_torch() logger = get_dist_logger() ``` diff --git a/docs/source/en/features/lazy_init.md b/docs/source/en/features/lazy_init.md index 160f68767..30b33b52f 100644 --- a/docs/source/en/features/lazy_init.md +++ b/docs/source/en/features/lazy_init.md @@ -29,7 +29,7 @@ from colossalai.booster.plugin import GeminiPlugin from transformers import LlamaForCausalLM, LlamaConfig, BertForPreTraining -colossalai.launch({}) +colossalai.launch() plugin = GeminiPlugin() booster = Booster(plugin) diff --git a/docs/source/en/features/mixed_precision_training_with_booster.md b/docs/source/en/features/mixed_precision_training_with_booster.md index 8e702a578..baaaacddd 100644 --- a/docs/source/en/features/mixed_precision_training_with_booster.md +++ b/docs/source/en/features/mixed_precision_training_with_booster.md @@ -20,10 +20,10 @@ In Colossal-AI, we have incorporated different implementations of mixed precisio 3. naive amp | Colossal-AI | support tensor parallel | support pipeline parallel | fp16 extent | -| -------------- | ----------------------- | ------------------------- | ---------------------------------------------------------------------------------------------------- | -| AMP_TYPE.TORCH | ✅ | ❌ | Model parameters, activation, gradients are downcast to fp16 during forward and backward propagation | -| AMP_TYPE.APEX | ❌ | ❌ | More fine-grained, we can choose opt_level O0, O1, O2, O3 | -| AMP_TYPE.NAIVE | ✅ | ✅ | Model parameters, forward and backward operations are all downcast to fp16 | +|----------------|-------------------------|---------------------------|------------------------------------------------------------------------------------------------------| +| AMP_TYPE.TORCH | ✅ | ❌ | Model parameters, activation, gradients are downcast to fp16 during forward and backward propagation | +| AMP_TYPE.APEX | ❌ | ❌ | More fine-grained, we can choose opt_level O0, O1, O2, O3 | +| AMP_TYPE.NAIVE | ✅ | ✅ | Model parameters, forward and backward operations are all downcast to fp16 | The first two rely on the original implementation of PyTorch (version 1.6 and above) and NVIDIA Apex. The last method is similar to Apex O2 level. @@ -164,7 +164,7 @@ parser = colossalai.get_default_parser() args = parser.parse_args() # launch from torch -colossalai.launch_from_torch(config=dict()) +colossalai.launch_from_torch() ``` diff --git a/docs/source/en/features/nvme_offload.md b/docs/source/en/features/nvme_offload.md index 6ed6f2dee..343a1f67e 100644 --- a/docs/source/en/features/nvme_offload.md +++ b/docs/source/en/features/nvme_offload.md @@ -185,7 +185,7 @@ Then we can train GPT model with Gemini. The placement policy of Gemini should b ```python def train_gemini_cpu(nvme_offload_fraction: float = 0.0): - colossalai.launch_from_torch({}) + colossalai.launch_from_torch() config = GPT2Config() with ColoInitContext(device=torch.cuda.current_device()): model = GPT2LMHeadModel(config) diff --git a/docs/source/en/features/zero_with_chunk.md b/docs/source/en/features/zero_with_chunk.md index 62be86488..f0c13830a 100644 --- a/docs/source/en/features/zero_with_chunk.md +++ b/docs/source/en/features/zero_with_chunk.md @@ -174,7 +174,7 @@ def main(): SEQ_LEN = 1024 VOCAB_SIZE = 50257 NUM_STEPS = 10 - colossalai.launch_from_torch(config={}) + colossalai.launch_from_torch() # build criterion criterion = GPTLMLoss() diff --git a/docs/source/zh-Hans/advanced_tutorials/train_gpt_using_hybrid_parallelism.md b/docs/source/zh-Hans/advanced_tutorials/train_gpt_using_hybrid_parallelism.md index cf7d19172..4d4ea8163 100644 --- a/docs/source/zh-Hans/advanced_tutorials/train_gpt_using_hybrid_parallelism.md +++ b/docs/source/zh-Hans/advanced_tutorials/train_gpt_using_hybrid_parallelism.md @@ -62,7 +62,7 @@ plugin = HybridParallelPlugin( ## 创建分布式环境. ```python # Launch ColossalAI -colossalai.launch_from_torch(config={}, seed=42) +colossalai.launch_from_torch(seed=42) coordinator = DistCoordinator() ``` ## 定义GPT-2模型的训练组件 diff --git a/docs/source/zh-Hans/advanced_tutorials/train_vit_with_hybrid_parallelism.md b/docs/source/zh-Hans/advanced_tutorials/train_vit_with_hybrid_parallelism.md index f32f6c367..c234a3c6e 100644 --- a/docs/source/zh-Hans/advanced_tutorials/train_vit_with_hybrid_parallelism.md +++ b/docs/source/zh-Hans/advanced_tutorials/train_vit_with_hybrid_parallelism.md @@ -70,7 +70,7 @@ PP_SIZE = 2 首先我们创建一个分布式环境 ```python # Launch ColossalAI -colossalai.launch_from_torch(config={}, seed=SEEDå) +colossalai.launch_from_torch(seed=SEEDå) coordinator = DistCoordinator() world_size = coordinator.world_size ``` diff --git a/docs/source/zh-Hans/basics/booster_api.md b/docs/source/zh-Hans/basics/booster_api.md index bb100964d..a9357617d 100644 --- a/docs/source/zh-Hans/basics/booster_api.md +++ b/docs/source/zh-Hans/basics/booster_api.md @@ -60,7 +60,7 @@ from colossalai.booster.plugin import TorchDDPPlugin def train(): # launch colossalai - colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost') + colossalai.launch(rank=rank, world_size=world_size, port=port, host='localhost') # create plugin and objects for training plugin = TorchDDPPlugin() diff --git a/docs/source/zh-Hans/basics/launch_colossalai.md b/docs/source/zh-Hans/basics/launch_colossalai.md index 39b09deae..a80d16717 100644 --- a/docs/source/zh-Hans/basics/launch_colossalai.md +++ b/docs/source/zh-Hans/basics/launch_colossalai.md @@ -74,8 +74,7 @@ import colossalai args = colossalai.get_default_parser().parse_args() # launch distributed environment -colossalai.launch(config=args.config, - rank=args.rank, +colossalai.launch(rank=args.rank, world_size=args.world_size, host=args.host, port=args.port, @@ -93,20 +92,11 @@ PyTorch自带的启动器需要在每个节点上都启动命令才能启动多 首先,我们需要在代码里指定我们的启动方式。由于这个启动器是PyTorch启动器的封装,那么我们自然而然应该使用`colossalai.launch_from_torch`。 分布式环境所需的参数,如 rank, world size, host 和 port 都是由 PyTorch 启动器设置的,可以直接从环境变量中读取。 -config.py -```python -BATCH_SIZE = 512 -LEARNING_RATE = 3e-3 -WEIGHT_DECAY = 0.3 -NUM_EPOCHS = 2 -``` train.py ```python import colossalai -colossalai.launch_from_torch( - config="./config.py", -) +colossalai.launch_from_torch() ... ``` @@ -186,7 +176,6 @@ colossalai run --nproc_per_node 4 --hostfile ./hostfile --master_addr host1 --e import colossalai colossalai.launch_from_slurm( - config=, host=args.host, port=args.port ) @@ -206,7 +195,6 @@ srun python train.py --host --port 29500 您可以在您的训练脚本中尝试以下操作。 ```python colossalai.launch_from_openmpi( - config=, host=args.host, port=args.port ) @@ -219,3 +207,5 @@ mpirun --hostfile -np python train.py --host diff --git a/docs/source/zh-Hans/features/gradient_accumulation_with_booster.md b/docs/source/zh-Hans/features/gradient_accumulation_with_booster.md index 824308f94..7ad8fb145 100644 --- a/docs/source/zh-Hans/features/gradient_accumulation_with_booster.md +++ b/docs/source/zh-Hans/features/gradient_accumulation_with_booster.md @@ -46,7 +46,7 @@ parser = colossalai.get_default_parser() args = parser.parse_args() # launch from torch -colossalai.launch_from_torch(config=dict()) +colossalai.launch_from_torch() ``` diff --git a/docs/source/zh-Hans/features/gradient_clipping_with_booster.md b/docs/source/zh-Hans/features/gradient_clipping_with_booster.md index fdec09bf1..b000d4585 100644 --- a/docs/source/zh-Hans/features/gradient_clipping_with_booster.md +++ b/docs/source/zh-Hans/features/gradient_clipping_with_booster.md @@ -61,7 +61,7 @@ from colossalai.nn.lr_scheduler import CosineAnnealingLR 我们需要初始化分布式环境. 为了快速演示,我们使用`launch_from_torch`. 您可以参考 [Launch Colossal-AI](../basics/launch_colossalai.md) ```python -colossalai.launch_from_torch(config=dict()) +colossalai.launch_from_torch() logger = get_dist_logger() ``` diff --git a/docs/source/zh-Hans/features/lazy_init.md b/docs/source/zh-Hans/features/lazy_init.md index 137719c69..c9cc0e4ba 100644 --- a/docs/source/zh-Hans/features/lazy_init.md +++ b/docs/source/zh-Hans/features/lazy_init.md @@ -29,7 +29,7 @@ from colossalai.booster.plugin import GeminiPlugin from transformers import LlamaForCausalLM, LlamaConfig, BertForPreTraining -colossalai.launch({}) +colossalai.launch() plugin = GeminiPlugin() booster = Booster(plugin) diff --git a/docs/source/zh-Hans/features/mixed_precision_training_with_booster.md b/docs/source/zh-Hans/features/mixed_precision_training_with_booster.md index 8e9f614a2..53d9013db 100644 --- a/docs/source/zh-Hans/features/mixed_precision_training_with_booster.md +++ b/docs/source/zh-Hans/features/mixed_precision_training_with_booster.md @@ -19,11 +19,11 @@ AMP 代表自动混合精度训练。 2. apex.amp 3. naive amp -| Colossal-AI | 支持张量并行 | 支持流水并行 | fp16 范围 | -| -------------- | ------------ | ------------ | --------------------------------------------------------- | -| AMP_TYPE.TORCH | ✅ | ❌ | 在前向和反向传播期间,模型参数、激活和梯度向下转换至 fp16 | -| AMP_TYPE.APEX | ❌ | ❌ | 更细粒度,我们可以选择 opt_level O0, O1, O2, O3 | -| AMP_TYPE.NAIVE | ✅ | ✅ | 模型参数、前向和反向操作,全都向下转换至 fp16 | +| Colossal-AI | 支持张量并行 | 支持流水并行 | fp16 范围 | +|----------------|--------------|--------------|-------------------------------------------------------| +| AMP_TYPE.TORCH | ✅ | ❌ | 在前向和反向传播期间,模型参数、激活和梯度向下转换至 fp16 | +| AMP_TYPE.APEX | ❌ | ❌ | 更细粒度,我们可以选择 opt_level O0, O1, O2, O3 | +| AMP_TYPE.NAIVE | ✅ | ✅ | 模型参数、前向和反向操作,全都向下转换至 fp16 | 前两个依赖于 PyTorch (1.6 及以上) 和 NVIDIA Apex 的原始实现。最后一种方法类似 Apex O2。在这些方法中,Apex-AMP 与张量并行不兼容。这是因为张量是以张量并行的方式在设备之间拆分的,因此,需要在不同的进程之间进行通信,以检查整个模型权重中是否出现 inf 或 nan。我们修改了 torch amp 实现,使其现在与张量并行兼容。 @@ -153,7 +153,7 @@ parser = colossalai.get_default_parser() args = parser.parse_args() # launch from torch -colossalai.launch_from_torch(config=dict()) +colossalai.launch_from_torch() ``` diff --git a/docs/source/zh-Hans/features/nvme_offload.md b/docs/source/zh-Hans/features/nvme_offload.md index 1feb9dde5..f013e755d 100644 --- a/docs/source/zh-Hans/features/nvme_offload.md +++ b/docs/source/zh-Hans/features/nvme_offload.md @@ -175,7 +175,7 @@ Mem usage: 4968.016 MB ```python def train_gemini_cpu(nvme_offload_fraction: float = 0.0): - colossalai.launch_from_torch({}) + colossalai.launch_from_torch() config = GPT2Config() with ColoInitContext(device=torch.cuda.current_device()): model = GPT2LMHeadModel(config) diff --git a/docs/source/zh-Hans/features/zero_with_chunk.md b/docs/source/zh-Hans/features/zero_with_chunk.md index c4f21c73c..4a4655d60 100644 --- a/docs/source/zh-Hans/features/zero_with_chunk.md +++ b/docs/source/zh-Hans/features/zero_with_chunk.md @@ -174,7 +174,7 @@ def main(): SEQ_LEN = 1024 VOCAB_SIZE = 50257 NUM_STEPS = 10 - colossalai.launch_from_torch(config={}) + colossalai.launch_from_torch() # build criterion criterion = GPTLMLoss() diff --git a/examples/community/roberta/pretraining/run_pretraining.py b/examples/community/roberta/pretraining/run_pretraining.py index 40b11d649..48cde8239 100644 --- a/examples/community/roberta/pretraining/run_pretraining.py +++ b/examples/community/roberta/pretraining/run_pretraining.py @@ -35,12 +35,12 @@ def main(): if args.vscode_debug: colossalai.launch( - config={}, rank=args.rank, world_size=args.world_size, host=args.host, port=args.port, backend=args.backend + rank=args.rank, world_size=args.world_size, host=args.host, port=args.port, backend=args.backend ) args.local_rank = -1 args.log_interval = 1 else: - colossalai.launch_from_torch(config={}) # args.colossal_config + colossalai.launch_from_torch() # args.colossal_config args.local_rank = int(os.environ["LOCAL_RANK"]) logger.info( f"launch_from_torch, world size: {torch.distributed.get_world_size()} | " diff --git a/examples/images/dreambooth/debug.py b/examples/images/dreambooth/debug.py index 8ce4dc3bb..64588e904 100644 --- a/examples/images/dreambooth/debug.py +++ b/examples/images/dreambooth/debug.py @@ -9,7 +9,7 @@ from colossalai.zero import ColoInitContext path = "/data/scratch/diffuser/stable-diffusion-v1-4" -colossalai.launch_from_torch(config={}) +colossalai.launch_from_torch() with ColoInitContext(device="cpu"): vae = AutoencoderKL.from_pretrained( path, diff --git a/examples/images/dreambooth/train_dreambooth_colossalai.py b/examples/images/dreambooth/train_dreambooth_colossalai.py index cc2b2ebc7..2bacb3a04 100644 --- a/examples/images/dreambooth/train_dreambooth_colossalai.py +++ b/examples/images/dreambooth/train_dreambooth_colossalai.py @@ -372,9 +372,9 @@ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: def main(args): if args.seed is None: - colossalai.launch_from_torch(config={}) + colossalai.launch_from_torch() else: - colossalai.launch_from_torch(config={}, seed=args.seed) + colossalai.launch_from_torch(seed=args.seed) local_rank = dist.get_rank() world_size = dist.get_world_size() diff --git a/examples/images/dreambooth/train_dreambooth_colossalai_lora.py b/examples/images/dreambooth/train_dreambooth_colossalai_lora.py index 227488abe..c4ef2a34e 100644 --- a/examples/images/dreambooth/train_dreambooth_colossalai_lora.py +++ b/examples/images/dreambooth/train_dreambooth_colossalai_lora.py @@ -371,9 +371,9 @@ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: def main(args): if args.seed is None: - colossalai.launch_from_torch(config={}) + colossalai.launch_from_torch() else: - colossalai.launch_from_torch(config={}, seed=args.seed) + colossalai.launch_from_torch(seed=args.seed) local_rank = gpc.get_local_rank(ParallelMode.DATA) world_size = gpc.get_world_size(ParallelMode.DATA) diff --git a/examples/images/resnet/train.py b/examples/images/resnet/train.py index 5871bbf87..a53a85180 100644 --- a/examples/images/resnet/train.py +++ b/examples/images/resnet/train.py @@ -128,7 +128,7 @@ def main(): # ============================== # Launch Distributed Environment # ============================== - colossalai.launch_from_torch(config={}) + colossalai.launch_from_torch() coordinator = DistCoordinator() # update the learning rate with linear scaling diff --git a/examples/images/vit/vit_benchmark.py b/examples/images/vit/vit_benchmark.py index fdae9ee01..790bb2b74 100644 --- a/examples/images/vit/vit_benchmark.py +++ b/examples/images/vit/vit_benchmark.py @@ -46,7 +46,7 @@ def main(): args = parse_benchmark_args() # Launch ColossalAI - colossalai.launch_from_torch(config={}, seed=args.seed) + colossalai.launch_from_torch(seed=args.seed) coordinator = DistCoordinator() world_size = coordinator.world_size diff --git a/examples/images/vit/vit_train_demo.py b/examples/images/vit/vit_train_demo.py index 81009b370..a65f89171 100644 --- a/examples/images/vit/vit_train_demo.py +++ b/examples/images/vit/vit_train_demo.py @@ -137,7 +137,7 @@ def main(): args = parse_demo_args() # Launch ColossalAI - colossalai.launch_from_torch(config={}, seed=args.seed) + colossalai.launch_from_torch(seed=args.seed) coordinator = DistCoordinator() world_size = coordinator.world_size diff --git a/examples/inference/benchmark_llama.py b/examples/inference/benchmark_llama.py index 26cac977a..a23ab500a 100644 --- a/examples/inference/benchmark_llama.py +++ b/examples/inference/benchmark_llama.py @@ -136,7 +136,7 @@ def benchmark_inference(args): def hybrid_inference(rank, world_size, port, args): - colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") benchmark_inference(args) diff --git a/examples/inference/run_llama_inference.py b/examples/inference/run_llama_inference.py index b5228c64e..a4e6fd0a1 100644 --- a/examples/inference/run_llama_inference.py +++ b/examples/inference/run_llama_inference.py @@ -68,7 +68,7 @@ def run_inference(args): def run_tp_pipeline_inference(rank, world_size, port, args): - colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") run_inference(args) diff --git a/examples/language/bert/benchmark.py b/examples/language/bert/benchmark.py index 10bd367fd..9270c1b0c 100644 --- a/examples/language/bert/benchmark.py +++ b/examples/language/bert/benchmark.py @@ -81,7 +81,7 @@ def main(): # ============================== # Launch Distributed Environment # ============================== - colossalai.launch_from_torch(config={}, seed=42) + colossalai.launch_from_torch(seed=42) coordinator = DistCoordinator() # local_batch_size = BATCH_SIZE // coordinator.world_size diff --git a/examples/language/bert/finetune.py b/examples/language/bert/finetune.py index bd6c393a7..7e8c07fdc 100644 --- a/examples/language/bert/finetune.py +++ b/examples/language/bert/finetune.py @@ -202,7 +202,7 @@ def main(): # ============================== # Launch Distributed Environment # ============================== - colossalai.launch_from_torch(config={}, seed=42) + colossalai.launch_from_torch(seed=42) coordinator = DistCoordinator() lr = LEARNING_RATE * coordinator.world_size diff --git a/examples/language/gpt/experiments/auto_offload/train_gpt_offload.py b/examples/language/gpt/experiments/auto_offload/train_gpt_offload.py index b35112498..fbb3a151a 100644 --- a/examples/language/gpt/experiments/auto_offload/train_gpt_offload.py +++ b/examples/language/gpt/experiments/auto_offload/train_gpt_offload.py @@ -94,8 +94,7 @@ def train_gpt(args): def run(rank, world_size, port, args): - config = {} - colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") train_gpt(args) diff --git a/examples/language/gpt/experiments/auto_parallel/auto_parallel_with_gpt.py b/examples/language/gpt/experiments/auto_parallel/auto_parallel_with_gpt.py index f3d35dd90..9a33c6598 100644 --- a/examples/language/gpt/experiments/auto_parallel/auto_parallel_with_gpt.py +++ b/examples/language/gpt/experiments/auto_parallel/auto_parallel_with_gpt.py @@ -47,7 +47,7 @@ def get_data(batch_size, seq_len, vocab_size): def main(): disable_existing_loggers() - launch_from_torch(config={}) + launch_from_torch() logger = get_dist_logger() config = transformers.GPT2Config(n_position=SEQ_LENGTH, n_layer=NUM_LAYERS, n_head=NUM_HEADS, n_embd=HIDDEN_DIM) if FP16: diff --git a/examples/language/gpt/gemini/train_gpt_demo.py b/examples/language/gpt/gemini/train_gpt_demo.py index 78d090ba2..4911ff124 100644 --- a/examples/language/gpt/gemini/train_gpt_demo.py +++ b/examples/language/gpt/gemini/train_gpt_demo.py @@ -132,7 +132,7 @@ def main(): PROF_FLAG = False # The flag of profiling, False by default disable_existing_loggers() - colossalai.launch_from_torch(config={}) + colossalai.launch_from_torch() logger = get_dist_logger() logger.info(f"{args.model_type}, {args.distplan}, batch size {BATCH_SIZE}", ranks=[0]) diff --git a/examples/language/gpt/hybridparallelism/benchmark.py b/examples/language/gpt/hybridparallelism/benchmark.py index 1315deae6..8c236b524 100644 --- a/examples/language/gpt/hybridparallelism/benchmark.py +++ b/examples/language/gpt/hybridparallelism/benchmark.py @@ -67,7 +67,7 @@ def main(): parser.add_argument("--cpu_offload", action="store_true", help="Use gradient checkpointing") args = parser.parse_args() - colossalai.launch_from_torch({}) + colossalai.launch_from_torch() coordinator = DistCoordinator() def empty_init(): diff --git a/examples/language/gpt/hybridparallelism/finetune.py b/examples/language/gpt/hybridparallelism/finetune.py index 888f47aaa..32b2dfcc0 100644 --- a/examples/language/gpt/hybridparallelism/finetune.py +++ b/examples/language/gpt/hybridparallelism/finetune.py @@ -196,7 +196,7 @@ def main(): # ============================== # Launch Distributed Environment # ============================== - colossalai.launch_from_torch(config={}, seed=42) + colossalai.launch_from_torch(seed=42) coordinator = DistCoordinator() # local_batch_size = BATCH_SIZE // coordinator.world_size diff --git a/examples/language/gpt/titans/train_gpt.py b/examples/language/gpt/titans/train_gpt.py index 565cf1e01..6b45bd33e 100644 --- a/examples/language/gpt/titans/train_gpt.py +++ b/examples/language/gpt/titans/train_gpt.py @@ -36,9 +36,9 @@ def main(): args = parser.parse_args() disable_existing_loggers() if args.from_torch: - colossalai.launch_from_torch(config=args.config) + colossalai.launch_from_torch() else: - colossalai.launch_from_slurm(config=args.config, host=args.host, port=29500, seed=42) + colossalai.launch_from_slurm(host=args.host, port=29500, seed=42) logger = get_dist_logger() data_path = None if args.use_dummy_dataset else os.environ["DATA"] diff --git a/examples/language/grok-1/inference_tp.py b/examples/language/grok-1/inference_tp.py index e10c4929c..f7d7cf864 100644 --- a/examples/language/grok-1/inference_tp.py +++ b/examples/language/grok-1/inference_tp.py @@ -16,7 +16,7 @@ if __name__ == "__main__": parser = get_default_parser() args = parser.parse_args() start = time.time() - colossalai.launch_from_torch({}) + colossalai.launch_from_torch() coordinator = DistCoordinator() plugin = HybridParallelPlugin( tp_size=coordinator.world_size, diff --git a/examples/language/llama/benchmark.py b/examples/language/llama/benchmark.py index f457c08cd..5cc602181 100644 --- a/examples/language/llama/benchmark.py +++ b/examples/language/llama/benchmark.py @@ -78,7 +78,7 @@ def main(): parser.add_argument("--custom-ckpt", action="store_true", help="Customize checkpoint", default=False) args = parser.parse_args() - colossalai.launch_from_torch({}) + colossalai.launch_from_torch() coordinator = DistCoordinator() def empty_init(): diff --git a/examples/language/openmoe/benchmark/benchmark_cai.py b/examples/language/openmoe/benchmark/benchmark_cai.py index a6d5f8bf2..22e0c790b 100644 --- a/examples/language/openmoe/benchmark/benchmark_cai.py +++ b/examples/language/openmoe/benchmark/benchmark_cai.py @@ -146,7 +146,7 @@ def main(): args = parse_args() # Launch ColossalAI - colossalai.launch_from_torch(config={}, seed=args.seed) + colossalai.launch_from_torch(seed=args.seed) coordinator = DistCoordinator() # Set plugin diff --git a/examples/language/openmoe/train.py b/examples/language/openmoe/train.py index 92f4e066a..40f072f13 100644 --- a/examples/language/openmoe/train.py +++ b/examples/language/openmoe/train.py @@ -207,7 +207,7 @@ def main(): args = parse_args() # Launch ColossalAI - colossalai.launch_from_torch(config={}, seed=args.seed) + colossalai.launch_from_torch(seed=args.seed) coordinator = DistCoordinator() test_mode = args.model_name == "test" diff --git a/examples/language/opt/opt_benchmark.py b/examples/language/opt/opt_benchmark.py index d16c9fdf9..c2883d96c 100755 --- a/examples/language/opt/opt_benchmark.py +++ b/examples/language/opt/opt_benchmark.py @@ -46,7 +46,7 @@ def main(): args = parse_benchmark_args() # Launch ColossalAI - colossalai.launch_from_torch(config={}, seed=args.seed) + colossalai.launch_from_torch(seed=args.seed) coordinator = DistCoordinator() world_size = coordinator.world_size diff --git a/examples/language/opt/opt_train_demo.py b/examples/language/opt/opt_train_demo.py index 05336bec4..b5b50305c 100644 --- a/examples/language/opt/opt_train_demo.py +++ b/examples/language/opt/opt_train_demo.py @@ -64,7 +64,7 @@ def main(): args = parse_demo_args() # Launch ColossalAI - colossalai.launch_from_torch(config={}, seed=args.seed) + colossalai.launch_from_torch(seed=args.seed) coordinator = DistCoordinator() world_size = coordinator.world_size diff --git a/examples/language/palm/train.py b/examples/language/palm/train.py index 4fac7b507..76a86600b 100644 --- a/examples/language/palm/train.py +++ b/examples/language/palm/train.py @@ -102,7 +102,7 @@ args = parse_args() if args.distplan not in ["colossalai", "pytorch"]: raise TypeError(f"{args.distplan} is error") disable_existing_loggers() -colossalai.launch_from_torch(config={}) +colossalai.launch_from_torch() logger = get_dist_logger() diff --git a/examples/tutorial/auto_parallel/auto_ckpt_batchsize_test.py b/examples/tutorial/auto_parallel/auto_ckpt_batchsize_test.py index 29101ce08..b7a3f4320 100644 --- a/examples/tutorial/auto_parallel/auto_ckpt_batchsize_test.py +++ b/examples/tutorial/auto_parallel/auto_ckpt_batchsize_test.py @@ -20,7 +20,7 @@ def _benchmark(rank, world_size, port): only result in minor performance drop. So at last we might be able to find better training batch size for our model (combine with large batch training optimizer such as LAMB). """ - colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") model = tm.resnet152() gm = symbolic_trace(model) raw_graph = deepcopy(gm.graph) diff --git a/examples/tutorial/auto_parallel/auto_ckpt_solver_test.py b/examples/tutorial/auto_parallel/auto_ckpt_solver_test.py index cd03a9179..81ef7ca03 100644 --- a/examples/tutorial/auto_parallel/auto_ckpt_solver_test.py +++ b/examples/tutorial/auto_parallel/auto_ckpt_solver_test.py @@ -17,7 +17,7 @@ def _benchmark(rank, world_size, port, args): The benchmark will sample in a range of memory budget for each model and output the benchmark summary and data visualization of peak memory vs. budget memory and relative step time vs. peak memory. """ - colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") if args.model == "resnet50": model = tm.resnet50() data_gen = partial(data_gen_resnet, batch_size=128, shape=(3, 224, 224)) diff --git a/examples/tutorial/new_api/cifar_resnet/train.py b/examples/tutorial/new_api/cifar_resnet/train.py index a4733126f..2b388fe36 100644 --- a/examples/tutorial/new_api/cifar_resnet/train.py +++ b/examples/tutorial/new_api/cifar_resnet/train.py @@ -128,7 +128,7 @@ def main(): # ============================== # Launch Distributed Environment # ============================== - colossalai.launch_from_torch(config={}) + colossalai.launch_from_torch() coordinator = DistCoordinator() # update the learning rate with linear scaling diff --git a/examples/tutorial/new_api/cifar_vit/train.py b/examples/tutorial/new_api/cifar_vit/train.py index ec6c852b5..84245d487 100644 --- a/examples/tutorial/new_api/cifar_vit/train.py +++ b/examples/tutorial/new_api/cifar_vit/train.py @@ -148,7 +148,7 @@ def main(): # ============================== # Launch Distributed Environment # ============================== - colossalai.launch_from_torch(config={}) + colossalai.launch_from_torch() coordinator = DistCoordinator() # update the learning rate with linear scaling diff --git a/examples/tutorial/new_api/glue_bert/finetune.py b/examples/tutorial/new_api/glue_bert/finetune.py index e97c9017f..624783a79 100644 --- a/examples/tutorial/new_api/glue_bert/finetune.py +++ b/examples/tutorial/new_api/glue_bert/finetune.py @@ -125,7 +125,7 @@ def main(): # ============================== # Launch Distributed Environment # ============================== - colossalai.launch_from_torch(config={}, seed=42) + colossalai.launch_from_torch(seed=42) coordinator = DistCoordinator() # local_batch_size = BATCH_SIZE // coordinator.world_size diff --git a/examples/tutorial/opt/opt/run_clm.py b/examples/tutorial/opt/opt/run_clm.py index ae8a0f4a0..cb62f77e1 100644 --- a/examples/tutorial/opt/opt/run_clm.py +++ b/examples/tutorial/opt/opt/run_clm.py @@ -289,7 +289,7 @@ class DummyDataloader: def main(): args = parse_args() disable_existing_loggers() - colossalai.legacy.launch_from_torch(config=dict()) + colossalai.legacy.launch_from_torch() logger = get_dist_logger() is_main_process = dist.get_rank() == 0 diff --git a/tests/test_auto_parallel/test_ckpt_solvers/test_C_solver_consistency.py b/tests/test_auto_parallel/test_ckpt_solvers/test_C_solver_consistency.py index 03bba8e64..14bc7aa57 100644 --- a/tests/test_auto_parallel/test_ckpt_solvers/test_C_solver_consistency.py +++ b/tests/test_auto_parallel/test_ckpt_solvers/test_C_solver_consistency.py @@ -27,7 +27,7 @@ except: def _run_C_solver_consistency_test(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") for M, mem_budget in [(tm.resnet50, 4000), (tm.densenet121, 8080)]: model = M() diff --git a/tests/test_auto_parallel/test_ckpt_solvers/test_ckpt_torchvision.py b/tests/test_auto_parallel/test_ckpt_solvers/test_ckpt_torchvision.py index c46f57f75..19d526524 100644 --- a/tests/test_auto_parallel/test_ckpt_solvers/test_ckpt_torchvision.py +++ b/tests/test_auto_parallel/test_ckpt_solvers/test_ckpt_torchvision.py @@ -75,7 +75,7 @@ def check_backward_consistency( def _run_ckpt_solver(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") MODEL_LIST = [tm.densenet121] torch.backends.cudnn.deterministic = True @@ -111,7 +111,7 @@ def test_ckpt_solver(): def _run_ckpt_solver_torch11(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") MODEL_LIST = [tm.densenet121] torch.backends.cudnn.deterministic = True diff --git a/tests/test_auto_parallel/test_offload/test_perf.py b/tests/test_auto_parallel/test_offload/test_perf.py index 373ba28b8..3db7a1925 100644 --- a/tests/test_auto_parallel/test_offload/test_perf.py +++ b/tests/test_auto_parallel/test_offload/test_perf.py @@ -141,8 +141,7 @@ def exam_fwd_bwd(model_name: str, memory_budget: float, solver_name: str): def run_dist(rank, world_size, port): - config = {} - colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") exam_fwd_bwd() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_bias_addition_forward.py b/tests/test_auto_parallel/test_tensor_shard/test_bias_addition_forward.py index c41c66745..f39f09d54 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_bias_addition_forward.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_bias_addition_forward.py @@ -42,7 +42,7 @@ class ConvModel(torch.nn.Module): def check_linear_module(rank, world_size, port): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") model = LinearModel(4, 8).cuda() input = torch.rand(4, 4).cuda() output_compare = model(input) @@ -59,7 +59,7 @@ def check_linear_module(rank, world_size, port): def check_conv_module(rank, world_size, port): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") model = ConvModel(3, 6, 2).cuda() input = torch.rand(4, 3, 64, 64).cuda() output_compare = model(input) diff --git a/tests/test_auto_parallel/test_tensor_shard/test_checkpoint.py b/tests/test_auto_parallel/test_tensor_shard/test_checkpoint.py index c800f54da..f2b966b10 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_checkpoint.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_checkpoint.py @@ -39,7 +39,7 @@ class GPT2MLPWithCkpt(nn.Module): def check_act_ckpt(rank, world_size, port): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") model = GPT2MLPWithCkpt(intermediate_size=4 * HIDDEN_SIZE, hidden_size=HIDDEN_SIZE) torch.rand(1, 64, HIDDEN_SIZE) input_sample = { diff --git a/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_ddp.py b/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_ddp.py index e8f175326..202f3e3bf 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_ddp.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_ddp.py @@ -32,7 +32,7 @@ class MLP(torch.nn.Module): def check_compatibility_with_ddp(rank, world_size, port): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") model = MLP(4).cuda() if rank in [0, 1]: input = torch.arange(0, 16, dtype=torch.float).reshape(4, 4).cuda() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_gemini.py b/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_gemini.py index d57717326..18de92e2a 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_gemini.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_gemini.py @@ -34,7 +34,7 @@ class MLP(torch.nn.Module): def check_auto_parallel_with_gemini(rank, world_size, port): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") model = MLP(4).half().cuda() if rank in [0, 1]: input = torch.arange(0, 16).reshape(4, 4).half().cuda() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_gpt/test_runtime_with_gpt_modules.py b/tests/test_auto_parallel/test_tensor_shard/test_gpt/test_runtime_with_gpt_modules.py index 24968e670..25c5d4ef1 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_gpt/test_runtime_with_gpt_modules.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_gpt/test_runtime_with_gpt_modules.py @@ -73,7 +73,7 @@ def _check_module_grad( def check_attention_layer(rank, model_cls, world_size, port): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") config = transformers.GPT2Config(n_position=64, n_layer=2, n_head=16, n_embd=HIDDEN_DIM) diff --git a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_binary_elementwise_metainfo.py b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_binary_elementwise_metainfo.py index ba9e28214..d2f3e3724 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_binary_elementwise_metainfo.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_binary_elementwise_metainfo.py @@ -31,7 +31,7 @@ def _binary_elementwise_mem_test(rank, world_size, port): port: port for initializing process group """ disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") model = BinaryElementwiseOpModule(token=torch.add, shape=1024).cuda() input = torch.rand(32, 1024).cuda() input.requires_grad = True diff --git a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_conv_metainfo.py b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_conv_metainfo.py index 455581545..5495282bc 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_conv_metainfo.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_conv_metainfo.py @@ -31,7 +31,7 @@ def _conv_module_mem_test(rank, world_size, port, bias): port: port for initializing process group """ disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") model = nn.Sequential(nn.Conv2d(4, 64, 3, padding=1, bias=bias)).cuda() input = torch.rand(4, 4, 64, 64).cuda() input.requires_grad = True @@ -72,7 +72,7 @@ def _conv_function_mem_test(rank, world_size, port): port: port for initializing process group """ disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") model = ConvFunctionModule().cuda() input = torch.rand(4, 4, 64, 64).cuda() input.requires_grad = True diff --git a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_linear_metainfo.py b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_linear_metainfo.py index 639870c89..4958bad6b 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_linear_metainfo.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_linear_metainfo.py @@ -30,7 +30,7 @@ def _linear_module_mem_test(rank, world_size, port): port: port for initializing process group """ disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") model = nn.Sequential(nn.Linear(64, 128, bias=False)).cuda() input = torch.rand(8, 8, 16, 64).cuda() input.requires_grad = True @@ -68,7 +68,7 @@ def _linear_function_mem_test(rank, world_size, port): port: port for initializing process group """ disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") model = MyModule().cuda() input = torch.rand(8, 8, 16, 64).cuda() input.requires_grad = True diff --git a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_norm_metainfo.py b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_norm_metainfo.py index ed809a758..a0b81edab 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_norm_metainfo.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_norm_metainfo.py @@ -25,7 +25,7 @@ def _batchnorm_module_mem_test(rank, world_size, port): port: port for initializing process group """ disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") model = nn.Sequential(nn.BatchNorm2d(128)).cuda() input = torch.rand(4, 128, 64, 64).cuda() input.requires_grad = True diff --git a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_pooling_metainfo.py b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_pooling_metainfo.py index bd1deb40c..92d91383e 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_pooling_metainfo.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_pooling_metainfo.py @@ -21,7 +21,7 @@ def _adaptiveavgpool_module_mem_test(rank, world_size, port): port: port for initializing process group """ disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") model = nn.Sequential(nn.AdaptiveAvgPool2d((16, 16))).cuda() input = torch.rand(4, 128, 64, 64).cuda() input.requires_grad = True @@ -62,7 +62,7 @@ def _maxpool_module_mem_test(rank, world_size, port): port: port for initializing process group """ disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") model = nn.Sequential(nn.MaxPool2d((16, 16))).cuda() input = torch.rand(4, 128, 64, 64).cuda() input.requires_grad = True diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_addbmm_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_addbmm_handler.py index 73a15f3ba..a8d2fbdfb 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_addbmm_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_addbmm_handler.py @@ -40,7 +40,7 @@ class AddBMMTorchFunctionModule(nn.Module): def check_2d_device_mesh(rank, world_size, port, module, bias_shape, using_kwargs): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") model = module(using_kwargs).cuda() physical_mesh_id = torch.arange(0, 4) mesh_shape = (2, 2) @@ -150,7 +150,7 @@ def check_2d_device_mesh(rank, world_size, port, module, bias_shape, using_kwarg def check_1d_device_mesh(rank, module, bias_shape, using_kwargs, world_size, port): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") physical_mesh_id = torch.arange(0, 4) mesh_shape = (1, 4) device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_addmm_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_addmm_handler.py index 26f9c4ab1..60eadeff9 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_addmm_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_addmm_handler.py @@ -40,7 +40,7 @@ class AddmmModel_with_param(nn.Module): def check_addmm_function_handler(rank, world_size, port, input_shape, model_cls): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") if model_cls == AddmmModel: model = AddmmModel().cuda() else: diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_batch_norm_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_batch_norm_handler.py index 86df7237a..e52cf28ab 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_batch_norm_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_batch_norm_handler.py @@ -16,7 +16,7 @@ from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import n def check_bn_module_handler(rank, world_size, port): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") model = nn.Sequential(nn.BatchNorm2d(16)).cuda() physical_mesh_id = torch.arange(0, 4) diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bias_linear_function_node.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bias_linear_function_node.py index e06625e1c..5982227b6 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bias_linear_function_node.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bias_linear_function_node.py @@ -34,7 +34,7 @@ class LinearModule(torch.nn.Module): def check_linear_module_handler(rank, world_size, port): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") model = LinearModule(weight_shape=WEIGHT_SHAPE).cuda() physical_mesh_id = torch.arange(0, 4) diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bias_linear_module_node.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bias_linear_module_node.py index 690f0c123..c45e3e014 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bias_linear_module_node.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bias_linear_module_node.py @@ -30,7 +30,7 @@ class LinearModule(torch.nn.Module): def check_linear_module_handler(rank, world_size, port, bias): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") model = LinearModule(16, 32, bias=bias).cuda() physical_mesh_id = torch.arange(0, 4) diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_binary_elementwise_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_binary_elementwise_handler.py index 5b2e2ab49..ad0d6d18c 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_binary_elementwise_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_binary_elementwise_handler.py @@ -16,7 +16,7 @@ from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import n def check_binary_elementwise_handler_with_tensor(rank, world_size, port, op, other_dim): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") class BinaryElementwiseOpModel(nn.Module): def __init__(self, op): @@ -145,7 +145,7 @@ class BEOpModelWithIntConst(nn.Module): def check_binary_elementwise_handler_with_int(rank, world_size, port, op, other_dim, model_cls): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") physical_mesh_id = torch.arange(0, 4) mesh_shape = (2, 2) diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bmm_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bmm_handler.py index 29df12832..ac54f1230 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bmm_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bmm_handler.py @@ -26,7 +26,7 @@ class BMMTorchFunctionModule(nn.Module): def check_2d_device_mesh(rank, module, world_size, port): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") model = module().cuda() physical_mesh_id = torch.arange(0, 4) mesh_shape = (2, 2) @@ -121,7 +121,7 @@ def check_2d_device_mesh(rank, module, world_size, port): def check_1d_device_mesh(rank, module, world_size, port): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") model = module().cuda() physical_mesh_id = torch.arange(0, 4) mesh_shape = (1, 4) diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_conv_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_conv_handler.py index 8a37dd925..407216f46 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_conv_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_conv_handler.py @@ -16,7 +16,7 @@ from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import n def check_conv_module_handler(rank, world_size, port, bias): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") model = nn.Sequential(nn.Conv2d(4, 16, 3, padding=1, bias=bias)).cuda() # graph(): # %input_1 : torch.Tensor [#users=1] = placeholder[target=input] @@ -153,7 +153,7 @@ class ConvModel(nn.Module): def check_conv_function_handler(rank, world_size, port, bias): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") model = ConvModel().cuda() physical_mesh_id = torch.arange(0, 4) mesh_shape = (2, 2) diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_embedding_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_embedding_handler.py index 9ac6ba95d..f9a5b40a0 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_embedding_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_embedding_handler.py @@ -33,7 +33,7 @@ class EmbeddingModule(nn.Module): def check_embedding_module_handler(rank, world_size, port): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") model = EmbeddingModule(num_embeddings=NUM_EMBEDDINGS, embedding_dims=EMBEDDING_DIMS).cuda() # graph(): # %input_1 : torch.Tensor [#users=1] = placeholder[target=input] @@ -150,7 +150,7 @@ class EmbeddingFunction(nn.Module): def check_embedding_function_handler(rank, world_size, port): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") model = EmbeddingFunction().cuda() physical_mesh_id = torch.arange(0, 4) mesh_shape = (2, 2) diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_getitem_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_getitem_handler.py index cf802a228..eb8e8ed3e 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_getitem_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_getitem_handler.py @@ -31,7 +31,7 @@ class GetItemFromTensorModel(nn.Module): def check_getitem_from_tensor_handler(rank, getitem_index, world_size, port): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") model = GetItemFromTensorModel(getitem_index=getitem_index) diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_layer_norm_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_layer_norm_handler.py index 59a66bc6a..45aae2ea9 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_layer_norm_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_layer_norm_handler.py @@ -17,7 +17,7 @@ from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import n def check_ln_module_handler(rank, world_size, port): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") model = nn.Sequential(nn.LayerNorm(16)).cuda() physical_mesh_id = torch.arange(0, 4) mesh_shape = (2, 2) diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_linear_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_linear_handler.py index da88b735f..ddabdb700 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_linear_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_linear_handler.py @@ -23,7 +23,7 @@ from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import n def check_linear_module_handler(rank, world_size, port, bias, input_shape): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") model = nn.Sequential(nn.Linear(16, 32, bias=bias)).cuda() physical_mesh_id = torch.arange(0, 4) mesh_shape = (2, 2) @@ -171,7 +171,7 @@ class LinearModel(nn.Module): def check_linear_function_handler(rank, world_size, port, bias, input_shape): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") model = LinearModel().cuda() physical_mesh_id = torch.arange(0, 4) mesh_shape = (2, 2) diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_permute_and_transpose_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_permute_and_transpose_handler.py index 958dc288f..09ad2ae32 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_permute_and_transpose_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_permute_and_transpose_handler.py @@ -51,7 +51,7 @@ class LinearReshapeModel(nn.Module): def check_view_handler(rank, world_size, port, call_function, reshape_dims, model_cls): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") if call_function == torch.permute: reshape_dims = reshape_dims[0] elif call_function == torch.transpose: diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_softmax_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_softmax_handler.py index 1a99c32eb..88f34ff10 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_softmax_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_softmax_handler.py @@ -29,7 +29,7 @@ class LinearSplitModel(nn.Module): def check_split_handler(rank, world_size, port, softmax_dim, model_cls): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") model = model_cls(softmax_dim=softmax_dim).cuda() input = torch.rand(8, 16, 64, 32).to("cuda") diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_split_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_split_handler.py index 0318023c8..225a729ef 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_split_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_split_handler.py @@ -42,7 +42,7 @@ class LinearSplitModel(nn.Module): def check_split_handler(rank, world_size, port, split_size, split_dim, model_cls): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") model = model_cls(split_size=split_size, split_dim=split_dim).cuda() if model_cls.__name__ == "ConvSplitModel": diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_sum_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_sum_handler.py index cbd3e4704..a79cfdf6f 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_sum_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_sum_handler.py @@ -32,7 +32,7 @@ class LinearSumModel(nn.Module): def check_sum_handler(rank, world_size, port, sum_dims, keepdim): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") model = LinearSumModel(sum_dims=sum_dims, keepdim=keepdim).cuda() physical_mesh_id = torch.arange(0, 4) mesh_shape = (2, 2) diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_view_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_view_handler.py index 466168c79..de483c997 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_view_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_view_handler.py @@ -41,7 +41,7 @@ class LinearViewModel(nn.Module): def check_view_handler(rank, tgt_shape, model_cls, world_size, port): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") model = model_cls(tgt_shape).cuda() if model_cls.__name__ == "ConvViewModel": diff --git a/tests/test_booster/test_mixed_precision/test_fp16_torch.py b/tests/test_booster/test_mixed_precision/test_fp16_torch.py index 3aefb3797..f6d6e8303 100644 --- a/tests/test_booster/test_mixed_precision/test_fp16_torch.py +++ b/tests/test_booster/test_mixed_precision/test_fp16_torch.py @@ -9,7 +9,7 @@ from tests.kit.model_zoo import model_zoo def run_torch_amp(rank, world_size, port): # init dist env - colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost") + colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost") sub_model_zoo = model_zoo.get_sub_registry("timm") for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in sub_model_zoo.items(): # dlrm_interactionarch has not parameters, so skip diff --git a/tests/test_booster/test_plugin/test_3d_plugin.py b/tests/test_booster/test_plugin/test_3d_plugin.py index 52cb8c46e..e57cadfd8 100644 --- a/tests/test_booster/test_plugin/test_3d_plugin.py +++ b/tests/test_booster/test_plugin/test_3d_plugin.py @@ -265,7 +265,7 @@ def run_grad_acc_test(test_args): def run_dist(rank, world_size, port, early_stop: bool = True): # init dist env - colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost") + colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost") check_3d_plugin(early_stop=early_stop) run_grad_acc_test() diff --git a/tests/test_booster/test_plugin/test_dp_plugin_base.py b/tests/test_booster/test_plugin/test_dp_plugin_base.py index fceb623fe..a2a4a0c07 100644 --- a/tests/test_booster/test_plugin/test_dp_plugin_base.py +++ b/tests/test_booster/test_plugin/test_dp_plugin_base.py @@ -85,7 +85,7 @@ def check_dataloader_sharding(): def run_dist(rank, world_size, port): # init dist env - colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost") + colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost") check_dataloader_sharding() diff --git a/tests/test_booster/test_plugin/test_gemini_plugin.py b/tests/test_booster/test_plugin/test_gemini_plugin.py index 892144772..b2790c0e7 100644 --- a/tests/test_booster/test_plugin/test_gemini_plugin.py +++ b/tests/test_booster/test_plugin/test_gemini_plugin.py @@ -161,7 +161,7 @@ def check_gemini_plugin( def run_dist(rank, world_size, port, early_stop: bool = True): # init dist env - colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost") + colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost") check_gemini_plugin(early_stop=early_stop) diff --git a/tests/test_booster/test_plugin/test_low_level_zero_plugin.py b/tests/test_booster/test_plugin/test_low_level_zero_plugin.py index cbfad6ef7..4908b2d4f 100644 --- a/tests/test_booster/test_plugin/test_low_level_zero_plugin.py +++ b/tests/test_booster/test_plugin/test_low_level_zero_plugin.py @@ -130,7 +130,7 @@ def check_low_level_zero_lora(stage, model_name, early_stop: bool = True): def run_dist(rank, world_size, port, early_stop: bool = True): # init dist env - colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost") + colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost") check_low_level_zero_plugin(early_stop=early_stop) check_low_level_zero_lora(early_stop=early_stop) diff --git a/tests/test_booster/test_plugin/test_torch_ddp_plugin.py b/tests/test_booster/test_plugin/test_torch_ddp_plugin.py index e785843fb..052782047 100644 --- a/tests/test_booster/test_plugin/test_torch_ddp_plugin.py +++ b/tests/test_booster/test_plugin/test_torch_ddp_plugin.py @@ -109,7 +109,7 @@ def check_torch_ddp_no_sync(): def run_dist(rank, world_size, port): # init dist env - colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost") + colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost") check_torch_ddp_plugin() check_torch_ddp_no_sync() diff --git a/tests/test_booster/test_plugin/test_torch_fsdp_plugin.py b/tests/test_booster/test_plugin/test_torch_fsdp_plugin.py index f69807046..90e98f325 100644 --- a/tests/test_booster/test_plugin/test_torch_fsdp_plugin.py +++ b/tests/test_booster/test_plugin/test_torch_fsdp_plugin.py @@ -73,7 +73,7 @@ def check_torch_fsdp_plugin(): def run_dist(rank, world_size, port): # init dist env - colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost") + colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost") check_torch_fsdp_plugin() diff --git a/tests/test_checkpoint_io/test_gemini_checkpoint_io.py b/tests/test_checkpoint_io/test_gemini_checkpoint_io.py index ac6f8caef..ade927e6e 100644 --- a/tests/test_checkpoint_io/test_gemini_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_gemini_checkpoint_io.py @@ -173,8 +173,7 @@ def exam_lazy_from_pretrained(): def run_dist(rank, world_size, port): - config = {} - colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") exam_state_dict() exam_state_dict_with_origin() exam_lazy_from_pretrained() diff --git a/tests/test_checkpoint_io/test_gemini_torch_compability.py b/tests/test_checkpoint_io/test_gemini_torch_compability.py index 44a000113..cd313c240 100644 --- a/tests/test_checkpoint_io/test_gemini_torch_compability.py +++ b/tests/test_checkpoint_io/test_gemini_torch_compability.py @@ -163,8 +163,7 @@ def exam_gemini_load_from_torch(shard: bool, model_name: str): def run_dist(rank, world_size, port): - config = {} - colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") exam_torch_load_from_gemini() exam_gemini_load_from_torch() diff --git a/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py b/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py index 4753ab637..1cf94433d 100644 --- a/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py @@ -132,8 +132,7 @@ def exam_state_dict(shard: bool, model_name: str, size_per_shard: int, test_conf def run_dist(rank, world_size, port): - config = {} - colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") exam_state_dict() diff --git a/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py b/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py index 4073cae0c..119e42e31 100644 --- a/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py @@ -172,7 +172,7 @@ def check_low_level_zero_lora_checkpointIO( def run_dist(rank, world_size, port): - colossalai.launch(config=(dict()), rank=rank, world_size=world_size, port=port, host="localhost") + colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost") check_low_level_zero_checkpointIO() check_low_level_zero_lora_checkpointIO() torch.cuda.empty_cache() diff --git a/tests/test_checkpoint_io/test_plugins_huggingface_compatibility.py b/tests/test_checkpoint_io/test_plugins_huggingface_compatibility.py index 0353ff115..da0d52d06 100644 --- a/tests/test_checkpoint_io/test_plugins_huggingface_compatibility.py +++ b/tests/test_checkpoint_io/test_plugins_huggingface_compatibility.py @@ -68,8 +68,7 @@ def exam_from_pretrained(plugin_type: str, model_name: str, shard=True, size_per def run_dist(rank, world_size, port): - config = {} - colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") exam_from_pretrained() diff --git a/tests/test_checkpoint_io/test_torch_ddp_checkpoint_io.py b/tests/test_checkpoint_io/test_torch_ddp_checkpoint_io.py index eeb04df0f..0b9a1605c 100644 --- a/tests/test_checkpoint_io/test_torch_ddp_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_torch_ddp_checkpoint_io.py @@ -61,7 +61,7 @@ def check_torch_ddp_checkpointIO(shard: bool, size_per_shard: int): def run_dist(rank, world_size, port): - colossalai.launch(config=(dict()), rank=rank, world_size=world_size, port=port, host="localhost") + colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost") check_torch_ddp_checkpointIO() diff --git a/tests/test_checkpoint_io/test_torch_fsdp_checkpoint_io.py b/tests/test_checkpoint_io/test_torch_fsdp_checkpoint_io.py index 1ea70368e..12b70cc04 100644 --- a/tests/test_checkpoint_io/test_torch_fsdp_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_torch_fsdp_checkpoint_io.py @@ -141,7 +141,7 @@ def check_torch_fsdp_ckpt(): def run_dist(rank, world_size, port): # init dist env - colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost") + colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost") check_torch_fsdp_ckpt() diff --git a/tests/test_cluster/test_device_mesh_manager.py b/tests/test_cluster/test_device_mesh_manager.py index ab61cdae5..5d140064b 100644 --- a/tests/test_cluster/test_device_mesh_manager.py +++ b/tests/test_cluster/test_device_mesh_manager.py @@ -6,7 +6,7 @@ from colossalai.testing import spawn def check_device_mesh_manager(rank, world_size, port): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") device_mesh_manager = DeviceMeshManager() # TODO(ver217): this test is strictly relies on hardware, temporary skip it # device_mesh_info_auto = DeviceMeshInfo(physical_ids=[0, 1, 2, 3],) diff --git a/tests/test_cluster/test_process_group_mesh.py b/tests/test_cluster/test_process_group_mesh.py index 3d206622d..3071c0f59 100644 --- a/tests/test_cluster/test_process_group_mesh.py +++ b/tests/test_cluster/test_process_group_mesh.py @@ -6,57 +6,6 @@ from colossalai.cluster import ProcessGroupMesh from colossalai.testing import spawn -def check_process_group_mesh_with_gpc(): - from colossalai.legacy.context import ParallelMode - from colossalai.legacy.core import global_context as gpc - - DP_DIM, PP_DIM, TP_DIM = 0, 1, 2 - pg_mesh = ProcessGroupMesh(1, 2, 2) - - # check world size - assert gpc.get_world_size(ParallelMode.TENSOR) == pg_mesh.size( - TP_DIM - ), f"{gpc.get_world_size(ParallelMode.TENSOR)} != {pg_mesh.size(TP_DIM)}" - assert gpc.get_world_size(ParallelMode.PIPELINE) == pg_mesh.size(PP_DIM) - assert gpc.get_world_size(ParallelMode.DATA) == pg_mesh.size(DP_DIM) - - # check locak rank (coordinate) - assert gpc.get_local_rank(ParallelMode.TENSOR) == pg_mesh.coordinate( - TP_DIM - ), f"{gpc.get_local_rank(ParallelMode.TENSOR)} != {pg_mesh.coordinate(TP_DIM)}" - assert gpc.get_local_rank(ParallelMode.PIPELINE) == pg_mesh.coordinate(PP_DIM) - assert gpc.get_local_rank(ParallelMode.DATA) == pg_mesh.coordinate(DP_DIM) - - # check ranks in group - tp_group = pg_mesh.get_group_along_axis(TP_DIM) - assert gpc.get_ranks_in_group(ParallelMode.TENSOR) == pg_mesh.get_ranks_in_group(tp_group) - pp_group = pg_mesh.get_group_along_axis(PP_DIM) - assert gpc.get_ranks_in_group(ParallelMode.PIPELINE) == pg_mesh.get_ranks_in_group(pp_group) - dp_group = pg_mesh.get_group_along_axis(DP_DIM) - assert gpc.get_ranks_in_group(ParallelMode.DATA) == pg_mesh.get_ranks_in_group(dp_group) - - # check prev rank - coord = pg_mesh.coordinate() - if not gpc.is_first_rank(ParallelMode.TENSOR): - assert coord[TP_DIM] != 0 - prev_coord = coord[:TP_DIM] + (coord[TP_DIM] - 1,) + coord[TP_DIM + 1 :] - assert gpc.get_prev_global_rank(ParallelMode.TENSOR) == pg_mesh.ravel(prev_coord, pg_mesh.shape) - if not gpc.is_first_rank(ParallelMode.PIPELINE): - assert coord[PP_DIM] != 0 - prev_coord = coord[:PP_DIM] + (coord[PP_DIM] - 1,) + coord[PP_DIM + 1 :] - assert gpc.get_prev_global_rank(ParallelMode.PIPELINE) == pg_mesh.ravel(prev_coord, pg_mesh.shape) - - # check next rank - if not gpc.is_last_rank(ParallelMode.TENSOR): - assert coord[TP_DIM] != pg_mesh.size(TP_DIM) - 1 - next_coord = coord[:TP_DIM] + (coord[TP_DIM] + 1,) + coord[TP_DIM + 1 :] - assert gpc.get_next_global_rank(ParallelMode.TENSOR) == pg_mesh.ravel(next_coord, pg_mesh.shape) - if not gpc.is_last_rank(ParallelMode.PIPELINE): - assert coord[PP_DIM] != pg_mesh.size(PP_DIM) - 1 - next_coord = coord[:PP_DIM] + (coord[PP_DIM] + 1,) + coord[PP_DIM + 1 :] - assert gpc.get_next_global_rank(ParallelMode.PIPELINE) == pg_mesh.ravel(next_coord, pg_mesh.shape) - - def check_process_group_mesh_with_cases(): DP_DIM, PP_DIM, TP_DIM = 0, 1, 2 DP_SIZE, PP_SIZE, TP_SIZE = 1, 2, 2 @@ -177,14 +126,11 @@ def check_process_group_mesh_with_cases(): def run_dist(rank, world_size, port): colossalai.launch( - config=dict(parallel=dict(data=1, pipeline=2, tensor=dict(mode="1d", size=2))), rank=rank, world_size=world_size, port=port, host="localhost", ) - # TODO(ver217): this function should be removed when gpc is removed - # check_process_group_mesh_with_gpc() check_process_group_mesh_with_cases() diff --git a/tests/test_device/test_alpha_beta.py b/tests/test_device/test_alpha_beta.py index f4a88f79c..3d9c6d7ce 100644 --- a/tests/test_device/test_alpha_beta.py +++ b/tests/test_device/test_alpha_beta.py @@ -8,7 +8,7 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn def check_alpha_beta(rank, world_size, port, physical_devices): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") profiler = AlphaBetaProfiler(physical_devices) ab_dict = profiler.profile_ab() for _, (alpha, beta) in ab_dict.items(): diff --git a/tests/test_device/test_device_mesh.py b/tests/test_device/test_device_mesh.py index af44af5d9..b2d057273 100644 --- a/tests/test_device/test_device_mesh.py +++ b/tests/test_device/test_device_mesh.py @@ -75,7 +75,7 @@ def check_2d_device_mesh(): def check_init_from_process_group(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") @pytest.mark.dist diff --git a/tests/test_device/test_extract_alpha_beta.py b/tests/test_device/test_extract_alpha_beta.py index 34f2aacc1..7633f59b9 100644 --- a/tests/test_device/test_extract_alpha_beta.py +++ b/tests/test_device/test_extract_alpha_beta.py @@ -8,7 +8,7 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn def check_extract_alpha_beta(rank, world_size, port, physical_devices): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") profiler = AlphaBetaProfiler(physical_devices) mesh_alpha, mesh_beta = profiler.extract_alpha_beta_for_device_mesh() diff --git a/tests/test_device/test_init_logical_pg.py b/tests/test_device/test_init_logical_pg.py index 3b398a917..d93f65698 100644 --- a/tests/test_device/test_init_logical_pg.py +++ b/tests/test_device/test_init_logical_pg.py @@ -9,7 +9,7 @@ from colossalai.testing import rerun_if_address_is_in_use, spawn def check_layer(rank, world_size, port): - launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") physical_mesh_id = torch.arange(0, 4) assert rank == dist.get_rank() diff --git a/tests/test_device/test_search_logical_device_mesh.py b/tests/test_device/test_search_logical_device_mesh.py index d9d4e79c1..a44b8e3d6 100644 --- a/tests/test_device/test_search_logical_device_mesh.py +++ b/tests/test_device/test_search_logical_device_mesh.py @@ -8,7 +8,7 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn def check_alpha_beta(rank, world_size, port, physical_devices): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") profiler = AlphaBetaProfiler(physical_devices) best_logical_mesh = profiler.search_best_logical_mesh() diff --git a/tests/test_fx/test_codegen/test_activation_checkpoint_codegen.py b/tests/test_fx/test_codegen/test_activation_checkpoint_codegen.py index 10fe98155..8a3e2d6ec 100644 --- a/tests/test_fx/test_codegen/test_activation_checkpoint_codegen.py +++ b/tests/test_fx/test_codegen/test_activation_checkpoint_codegen.py @@ -64,7 +64,7 @@ class MyModule(torch.nn.Module): def _run_act_ckpt_codegen(rank, world_size, port): # launch colossalai to make sure we could execute colossalai.utils.checkpoint currently - colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") # build model and run forward model = MyModule() @@ -127,7 +127,7 @@ def test_act_ckpt_codegen(): def _run_act_ckpt_python_code_torch11(rank, world_size, port): # launch colossalai to make sure we could execute colossalai.utils.checkpoint currently - colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") # build model and run forward model = MyModule() diff --git a/tests/test_fx/test_codegen/test_nested_activation_checkpoint_codegen.py b/tests/test_fx/test_codegen/test_nested_activation_checkpoint_codegen.py index f1e87e5ed..69767db2d 100644 --- a/tests/test_fx/test_codegen/test_nested_activation_checkpoint_codegen.py +++ b/tests/test_fx/test_codegen/test_nested_activation_checkpoint_codegen.py @@ -32,7 +32,7 @@ class MyModule(torch.nn.Module): def _run_act_ckpt_codegen(rank, world_size, port): # launch colossalai to make sure we could execute colossalai.utils.checkpoint currently - colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") # build model and run forward model = MyModule() @@ -96,7 +96,7 @@ def test_act_ckpt_codegen(): def _run_act_ckpt_python_code_torch11(rank, world_size, port): # launch colossalai to make sure we could execute colossalai.utils.checkpoint currently - colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") # build model and run forward model = MyModule() diff --git a/tests/test_fx/test_codegen/test_offload_codegen.py b/tests/test_fx/test_codegen/test_offload_codegen.py index da1e73ec3..9df4a6899 100644 --- a/tests/test_fx/test_codegen/test_offload_codegen.py +++ b/tests/test_fx/test_codegen/test_offload_codegen.py @@ -66,7 +66,7 @@ def _test_fwd_and_bwd(model: torch.nn.Module, gm: ColoGraphModule, data: torch.T def _run_offload_codegen(rank, world_size, port): # launch colossalai to make sure we could execute colossalai.utils.checkpoint currently - colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") # build model and input model = MyNet().cuda() @@ -124,7 +124,7 @@ def test_act_ckpt_codegen(): def _run_offload_codegen_torch11(rank, world_size, port): # launch colossalai to make sure we could execute colossalai.utils.checkpoint currently - colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") # build model and input model = MyNet().cuda() diff --git a/tests/test_fx/test_parallel_1d.py b/tests/test_fx/test_parallel_1d.py index 6d890f59d..6b0e12609 100644 --- a/tests/test_fx/test_parallel_1d.py +++ b/tests/test_fx/test_parallel_1d.py @@ -33,7 +33,7 @@ CONFIG = dict(parallel=dict(tensor=dict(mode="1d", size=2))) def check_layer(rank, world_size, port): disable_existing_loggers() - launch(config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") input_tensor = torch.rand(2, 16).cuda() model = MLP(16).cuda() symbolic_traced = symbolic_trace(model) diff --git a/tests/test_infer/test_hybrid_bloom.py b/tests/test_infer/test_hybrid_bloom.py index 8cad06dca..ef2aac1d1 100644 --- a/tests/test_infer/test_hybrid_bloom.py +++ b/tests/test_infer/test_hybrid_bloom.py @@ -89,18 +89,18 @@ def run_single_inference_test(tp_size, pp_size, max_output_len, micro_batch_size def check_tp_pp_inference(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") run_tp_pipeline_inference_test() def check_tp_or_pp_inference(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") run_tp_inference_test() run_pipeline_inference_test() def check_single_inference(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") run_single_inference_test diff --git a/tests/test_infer/test_hybrid_chatglm2.py b/tests/test_infer/test_hybrid_chatglm2.py index b53bb25f4..e80b3477f 100644 --- a/tests/test_infer/test_hybrid_chatglm2.py +++ b/tests/test_infer/test_hybrid_chatglm2.py @@ -97,18 +97,18 @@ def run_single_inference_test(tp_size, pp_size, max_output_len, micro_batch_size def check_tp_pp_inference(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") run_tp_pipeline_inference_test() def check_tp_or_pp_inference(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") run_tp_inference_test() run_pipeline_inference_test() def check_single_inference(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") run_single_inference_test diff --git a/tests/test_infer/test_hybrid_llama.py b/tests/test_infer/test_hybrid_llama.py index 30b8b0a99..a99794817 100644 --- a/tests/test_infer/test_hybrid_llama.py +++ b/tests/test_infer/test_hybrid_llama.py @@ -94,18 +94,18 @@ def run_single_inference_test(tp_size, pp_size, max_output_len, micro_batch_size def check_tp_pp_inference(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") run_tp_pipeline_inference_test() def check_tp_or_pp_inference(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") run_tp_inference_test() run_pipeline_inference_test() def check_single_inference(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") run_single_inference_test diff --git a/tests/test_legacy/test_amp/test_naive_fp16.py b/tests/test_legacy/test_amp/test_naive_fp16.py index fe16bc4d4..0df6335f5 100644 --- a/tests/test_legacy/test_amp/test_naive_fp16.py +++ b/tests/test_legacy/test_amp/test_naive_fp16.py @@ -77,7 +77,7 @@ def run_naive_amp(): def run_dist(rank, world_size, port): - colossalai.legacy.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost") + colossalai.legacy.launch(rank=rank, world_size=world_size, port=port, host="localhost") run_naive_amp() diff --git a/tests/test_legacy/test_amp/test_torch_fp16.py b/tests/test_legacy/test_amp/test_torch_fp16.py index 5e2e1ede5..dc47dfc72 100644 --- a/tests/test_legacy/test_amp/test_torch_fp16.py +++ b/tests/test_legacy/test_amp/test_torch_fp16.py @@ -76,7 +76,7 @@ def run_torch_amp(): def run_dist(rank, world_size, port): - colossalai.legacy.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost") + colossalai.legacy.launch(rank=rank, world_size=world_size, port=port, host="localhost") run_torch_amp() diff --git a/tests/test_legacy/test_comm/test_boardcast_send_recv_v2.py b/tests/test_legacy/test_comm/test_boardcast_send_recv_v2.py index bc243631a..bd15e10f3 100644 --- a/tests/test_legacy/test_comm/test_boardcast_send_recv_v2.py +++ b/tests/test_legacy/test_comm/test_boardcast_send_recv_v2.py @@ -16,7 +16,7 @@ torch.manual_seed(123) def check_layer(rank, world_size, port): disable_existing_loggers() - launch(config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl", verbose=False) + launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl", verbose=False) rank = gpc.get_local_rank(ParallelMode.PIPELINE) if rank == 0: diff --git a/tests/test_legacy/test_comm/test_comm.py b/tests/test_legacy/test_comm/test_comm.py index 079022e93..75955df69 100644 --- a/tests/test_legacy/test_comm/test_comm.py +++ b/tests/test_legacy/test_comm/test_comm.py @@ -48,7 +48,7 @@ def check_all_reduce(): def check_layer(rank, world_size, port): - launch(config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") assert dist.get_rank() == gpc.get_global_rank() print("Rank {} / {}".format(dist.get_rank(), dist.get_world_size())) diff --git a/tests/test_legacy/test_comm/test_object_list_p2p.py b/tests/test_legacy/test_comm/test_object_list_p2p.py index 69c68c715..1d618a65f 100644 --- a/tests/test_legacy/test_comm/test_object_list_p2p.py +++ b/tests/test_legacy/test_comm/test_object_list_p2p.py @@ -88,7 +88,7 @@ def check_send_recv_forward_backward(): def check_layer(rank, world_size, port): - launch(config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") check_send_recv_forward() check_send_recv_backward() check_send_recv_forward_backward() diff --git a/tests/test_legacy/test_comm/test_object_list_p2p_v2.py b/tests/test_legacy/test_comm/test_object_list_p2p_v2.py index eb05ea483..c272f51f4 100644 --- a/tests/test_legacy/test_comm/test_object_list_p2p_v2.py +++ b/tests/test_legacy/test_comm/test_object_list_p2p_v2.py @@ -104,7 +104,7 @@ def check_small_pipeline(): def check_layer(rank, world_size, port): disable_existing_loggers() - launch(config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") disable_existing_loggers() # check_send_recv_forward() diff --git a/tests/test_legacy/test_layers/test_1d/test_1d.py b/tests/test_legacy/test_layers/test_1d/test_1d.py index cebbedd30..9057c2c68 100644 --- a/tests/test_legacy/test_layers/test_1d/test_1d.py +++ b/tests/test_legacy/test_layers/test_1d/test_1d.py @@ -17,7 +17,7 @@ CONFIG = dict( def check_layer(rank, world_size, port): disable_existing_loggers() - launch(config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") check_linear_col() check_linear_row() diff --git a/tests/test_legacy/test_layers/test_2d/test_2d.py b/tests/test_legacy/test_layers/test_2d/test_2d.py index 77a4b281a..5be498f90 100644 --- a/tests/test_legacy/test_layers/test_2d/test_2d.py +++ b/tests/test_legacy/test_layers/test_2d/test_2d.py @@ -50,7 +50,7 @@ def check_layer(): def check_layer_and_operation(rank, world_size, port): disable_existing_loggers() - launch(config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") torch.backends.cuda.matmul.allow_tf32 = False torch.backends.cudnn.allow_tf32 = False diff --git a/tests/test_legacy/test_layers/test_2p5d/test_2p5d.py b/tests/test_legacy/test_layers/test_2p5d/test_2p5d.py index 437a8f8a7..029274570 100644 --- a/tests/test_legacy/test_layers/test_2p5d/test_2p5d.py +++ b/tests/test_legacy/test_layers/test_2p5d/test_2p5d.py @@ -38,7 +38,7 @@ def check_layer(): def check_layer_and_operation(rank, world_size, port): disable_existing_loggers() - launch(config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") torch.backends.cuda.matmul.allow_tf32 = False torch.backends.cudnn.allow_tf32 = False diff --git a/tests/test_legacy/test_layers/test_3d/test_3d.py b/tests/test_legacy/test_layers/test_3d/test_3d.py index 7057e2308..876aa7ba8 100644 --- a/tests/test_legacy/test_layers/test_3d/test_3d.py +++ b/tests/test_legacy/test_layers/test_3d/test_3d.py @@ -44,7 +44,7 @@ def check_layer(): def check_layer_and_operation(rank, world_size, port): disable_existing_loggers() - launch(config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") torch.backends.cuda.matmul.allow_tf32 = False torch.backends.cudnn.allow_tf32 = False torch.backends.cudnn.deterministic = True diff --git a/tests/test_legacy/test_layers/test_cache_embedding.py b/tests/test_legacy/test_layers/test_cache_embedding.py index d64ff56b8..c45097232 100644 --- a/tests/test_legacy/test_layers/test_cache_embedding.py +++ b/tests/test_legacy/test_layers/test_cache_embedding.py @@ -378,7 +378,7 @@ def run_parallel_freq_aware_embed_columnwise(rank, world_size): def run_dist(rank, world_size, port): - colossalai.legacy.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + colossalai.legacy.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") # run_parallel_freq_aware_embed_columnwise(rank, world_size) run_parallel_freq_aware_embed_tablewise(rank, world_size) diff --git a/tests/test_legacy/test_tensor/core/test_dist_spec_mgr.py b/tests/test_legacy/test_tensor/core/test_dist_spec_mgr.py index 506244447..bfedb779c 100644 --- a/tests/test_legacy/test_tensor/core/test_dist_spec_mgr.py +++ b/tests/test_legacy/test_tensor/core/test_dist_spec_mgr.py @@ -48,7 +48,7 @@ def check_mem(): def run_dist(rank, world_size, port): - colossalai.legacy.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + colossalai.legacy.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") check_mem() run() diff --git a/tests/test_legacy/test_tensor/test_parameter.py b/tests/test_legacy/test_tensor/test_parameter.py index 5217e22cc..eae3e0eb3 100644 --- a/tests/test_legacy/test_tensor/test_parameter.py +++ b/tests/test_legacy/test_tensor/test_parameter.py @@ -9,7 +9,7 @@ from colossalai.testing import free_port @pytest.mark.skip def test_multiinheritance(): - colossalai.legacy.launch(config={}, rank=0, world_size=1, host="localhost", port=free_port(), backend="nccl") + colossalai.legacy.launch(rank=0, world_size=1, host="localhost", port=free_port(), backend="nccl") colo_param = ColoParameter(None, requires_grad=True) assert colo_param.dist_spec.placement.value == "r" assert isinstance(colo_param, ColoTensor) diff --git a/tests/test_legacy/test_trainer/test_pipeline/test_p2p.py b/tests/test_legacy/test_trainer/test_pipeline/test_p2p.py index cab111358..ba8504d06 100644 --- a/tests/test_legacy/test_trainer/test_pipeline/test_p2p.py +++ b/tests/test_legacy/test_trainer/test_pipeline/test_p2p.py @@ -86,7 +86,7 @@ def check_comm(size, rank, prev_rank, next_rank, logger): def run_check(rank, world_size, port): - launch(config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") logger = get_dist_logger() rank = gpc.get_global_rank() prev_rank = gpc.get_prev_global_rank(ParallelMode.PIPELINE) diff --git a/tests/test_legacy/test_trainer/test_pipeline/test_pipeline_schedule.py b/tests/test_legacy/test_trainer/test_pipeline/test_pipeline_schedule.py index cd7fcfe56..ae7b961ae 100644 --- a/tests/test_legacy/test_trainer/test_pipeline/test_pipeline_schedule.py +++ b/tests/test_legacy/test_trainer/test_pipeline/test_pipeline_schedule.py @@ -23,7 +23,7 @@ CONFIG = dict(NUM_MICRO_BATCHES=2, parallel=dict(pipeline=dict(size=2), tensor=d def run_schedule(rank, world_size, port): - launch(config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") # build model model = resnet18(num_classes=10) diff --git a/tests/test_legacy/test_utils/test_checkpoint/test_checkpoint_1d.py b/tests/test_legacy/test_utils/test_checkpoint/test_checkpoint_1d.py index c07ff132b..e1b2128aa 100644 --- a/tests/test_legacy/test_utils/test_checkpoint/test_checkpoint_1d.py +++ b/tests/test_legacy/test_utils/test_checkpoint/test_checkpoint_1d.py @@ -43,7 +43,7 @@ def check_checkpoint_1d(rank, world_size, port): ) disable_existing_loggers() - launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") m1 = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 4)) sd1 = m1.state_dict() diff --git a/tests/test_legacy/test_utils/test_checkpoint/test_checkpoint_2d.py b/tests/test_legacy/test_utils/test_checkpoint/test_checkpoint_2d.py index 2ec1facf2..12747951b 100644 --- a/tests/test_legacy/test_utils/test_checkpoint/test_checkpoint_2d.py +++ b/tests/test_legacy/test_utils/test_checkpoint/test_checkpoint_2d.py @@ -43,7 +43,7 @@ def check_checkpoint_2d(rank, world_size, port): ) disable_existing_loggers() - launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") m1 = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 4)) sd1 = m1.state_dict() diff --git a/tests/test_legacy/test_utils/test_checkpoint/test_checkpoint_2p5d.py b/tests/test_legacy/test_utils/test_checkpoint/test_checkpoint_2p5d.py index a6bf702a8..f7e7b6fad 100644 --- a/tests/test_legacy/test_utils/test_checkpoint/test_checkpoint_2p5d.py +++ b/tests/test_legacy/test_utils/test_checkpoint/test_checkpoint_2p5d.py @@ -43,7 +43,7 @@ def check_checkpoint_2p5d(rank, world_size, port): ) disable_existing_loggers() - launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") m1 = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 4)) sd1 = m1.state_dict() diff --git a/tests/test_legacy/test_utils/test_checkpoint/test_checkpoint_3d.py b/tests/test_legacy/test_utils/test_checkpoint/test_checkpoint_3d.py index 12d928312..05666cc93 100644 --- a/tests/test_legacy/test_utils/test_checkpoint/test_checkpoint_3d.py +++ b/tests/test_legacy/test_utils/test_checkpoint/test_checkpoint_3d.py @@ -43,7 +43,7 @@ def check_checkpoint_3d(rank, world_size, port): ) disable_existing_loggers() - launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") m1 = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 4)) sd1 = m1.state_dict() diff --git a/tests/test_legacy/test_utils/test_memory.py b/tests/test_legacy/test_utils/test_memory.py index 4993df4f3..30fc17b8e 100644 --- a/tests/test_legacy/test_utils/test_memory.py +++ b/tests/test_legacy/test_utils/test_memory.py @@ -14,7 +14,7 @@ def _run_colo_set_process_memory_fraction_and_colo_device_memory_capacity(): def run_dist(rank, world_size, port): - colossalai.legacy.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + colossalai.legacy.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") _run_colo_set_process_memory_fraction_and_colo_device_memory_capacity() diff --git a/tests/test_legacy/test_utils/test_norm_gradient_clipping.py b/tests/test_legacy/test_utils/test_norm_gradient_clipping.py index 9975cc04f..c5fab49f4 100644 --- a/tests/test_legacy/test_utils/test_norm_gradient_clipping.py +++ b/tests/test_legacy/test_utils/test_norm_gradient_clipping.py @@ -62,7 +62,7 @@ def run_grad_clip_norm(world_size: int, dtype: torch.dtype, device: str, norm_ty def run_dist(rank, world_size, port): disable_existing_loggers() - colossalai.legacy.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + colossalai.legacy.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") run_grad_clip_norm(world_size=world_size) diff --git a/tests/test_legacy/test_zero/test_commons.py b/tests/test_legacy/test_zero/test_commons.py index 741f519e1..32b15706d 100644 --- a/tests/test_legacy/test_zero/test_commons.py +++ b/tests/test_legacy/test_zero/test_commons.py @@ -7,7 +7,7 @@ from colossalai.testing import rerun_if_address_is_in_use, spawn def run_tensor_move(rank, world_size, port): - colossalai.legacy.launch(config={}, rank=0, world_size=world_size, host="localhost", port=port, backend="nccl") + colossalai.legacy.launch(rank=0, world_size=world_size, host="localhost", port=port, backend="nccl") src_t = torch.ones(2, 3).cuda() tgt_t = torch.zeros(2, 3) diff --git a/tests/test_lora/test_lora.py b/tests/test_lora/test_lora.py index 69febff38..b8daf775d 100644 --- a/tests/test_lora/test_lora.py +++ b/tests/test_lora/test_lora.py @@ -96,8 +96,7 @@ def run_lora_test(): def run_dist(rank, world_size, port): - config = {} - colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") run_lora_test() diff --git a/tests/test_moe/test_grad_handler.py b/tests/test_moe/test_grad_handler.py index a349bc5a9..a88f5f9cc 100644 --- a/tests/test_moe/test_grad_handler.py +++ b/tests/test_moe/test_grad_handler.py @@ -16,7 +16,6 @@ DIM = 16 def run_test(rank, world_size, port): colossalai.launch( - config=dict(), rank=rank, world_size=world_size, host="localhost", diff --git a/tests/test_moe/test_kernel.py b/tests/test_moe/test_kernel.py index 62d61a3d4..30122d31a 100644 --- a/tests/test_moe/test_kernel.py +++ b/tests/test_moe/test_kernel.py @@ -20,7 +20,7 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f # Here we do not need TF32, since it brings absolute error on results torch.backends.cuda.matmul.allow_tf32 = False - colossalai.launch(config=dict(), rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") local_rank = dist.get_rank() MOE_MANAGER.setup(parallel="EP") # MOE environment initialization diff --git a/tests/test_moe/test_moe_ep_tp.py b/tests/test_moe/test_moe_ep_tp.py index 74feeeb59..660fbd358 100644 --- a/tests/test_moe/test_moe_ep_tp.py +++ b/tests/test_moe/test_moe_ep_tp.py @@ -128,7 +128,7 @@ def sync_local_from_ep(local_model: SparseMLP, ep_model: SparseMLP, assert_grad_ def run_test(rank: int, world_size: int, port: int, num_experts: int, batch_size: int, dim: int, config: Dict): assert batch_size % world_size == 0 - colossalai.launch(config=dict(), rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") MOE_MANAGER.__init__() MOE_MANAGER.setup(parallel=None) diff --git a/tests/test_moe/test_moe_group.py b/tests/test_moe/test_moe_group.py index 2f08a335d..b7be54d26 100644 --- a/tests/test_moe/test_moe_group.py +++ b/tests/test_moe/test_moe_group.py @@ -60,7 +60,6 @@ def run_moe_init(expert_parallel): def _run_test(rank, world_size, port, expert_parallel): colossalai.launch( - config=dict(), rank=rank, world_size=world_size, host="localhost", diff --git a/tests/test_moe/test_moe_hybrid_zero.py b/tests/test_moe/test_moe_hybrid_zero.py index 7ada4090f..7932fa8a7 100644 --- a/tests/test_moe/test_moe_hybrid_zero.py +++ b/tests/test_moe/test_moe_hybrid_zero.py @@ -81,7 +81,7 @@ def run_zero_optim_test(local_rank, world_size, stage=1): def run_dist(rank, world_size, port): - colossalai.launch(config=dict(), rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") run_zero_optim_test(rank, world_size, stage=1) run_zero_optim_test(rank, world_size, stage=2) diff --git a/tests/test_moe/test_moe_load_balance.py b/tests/test_moe/test_moe_load_balance.py index 717bb99fb..fae189bac 100644 --- a/tests/test_moe/test_moe_load_balance.py +++ b/tests/test_moe/test_moe_load_balance.py @@ -164,7 +164,6 @@ def run_hybrid_zero_optim_test(local_rank, world_size, stage=1): def run_dist(rank, world_size, port): colossalai.launch( - config=dict(), rank=rank, world_size=world_size, host="localhost", diff --git a/tests/test_moe/test_moe_zero_fwd_bwd.py b/tests/test_moe/test_moe_zero_fwd_bwd.py index 1bff21066..3bb08b49e 100644 --- a/tests/test_moe/test_moe_zero_fwd_bwd.py +++ b/tests/test_moe/test_moe_zero_fwd_bwd.py @@ -61,7 +61,7 @@ def run_zero_test(local_rank, stage=1): def run_dist(rank, world_size, port, stage): - colossalai.launch(config=dict(), rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") seed_all(42 + rank) run_zero_test(rank, stage=stage) diff --git a/tests/test_moe/test_moe_zero_optim.py b/tests/test_moe/test_moe_zero_optim.py index 4f6067aaa..224c5c3b9 100644 --- a/tests/test_moe/test_moe_zero_optim.py +++ b/tests/test_moe/test_moe_zero_optim.py @@ -66,7 +66,7 @@ def run_zero_test(local_rank, stage=1): def run_dist(rank, world_size, port, stage): - colossalai.launch(config=dict(), rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") seed_all(42 + rank) run_zero_test(rank, stage=stage) diff --git a/tests/test_optimizer/test_adam_kernel.py b/tests/test_optimizer/test_adam_kernel.py index 6d932156a..002649905 100644 --- a/tests/test_optimizer/test_adam_kernel.py +++ b/tests/test_optimizer/test_adam_kernel.py @@ -69,7 +69,7 @@ class FusedAdamKernel(AdamKernel): fused_optim = FusedOptimizerLoader().load() self.fused_adam = fused_optim.multi_tensor_adam - self.dummy_overflow_buf = torch.cuda.IntTensor([0]) + self.dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device=get_accelerator().get_current_device()) def update(self, step: int, param: Tensor, grad: Tensor, exp_avg: Tensor, exp_avg_sq: Tensor): multi_tensor_applier( diff --git a/tests/test_pipeline/test_p2p_communication.py b/tests/test_pipeline/test_p2p_communication.py index 6f5e734b7..48a8d12e0 100644 --- a/tests/test_pipeline/test_p2p_communication.py +++ b/tests/test_pipeline/test_p2p_communication.py @@ -71,7 +71,7 @@ def check_p2p_communication(): def run_dist(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host="localhost") + colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost") check_p2p_communication() diff --git a/tests/test_pipeline/test_schedule/test_interleaved.py b/tests/test_pipeline/test_schedule/test_interleaved.py index f8820688e..a626b834a 100644 --- a/tests/test_pipeline/test_schedule/test_interleaved.py +++ b/tests/test_pipeline/test_schedule/test_interleaved.py @@ -58,7 +58,7 @@ def run_pp( This test is to examine the correctness of interleaved 1F1B, compared with torch. Be aware it contains some hardcodes. """ - colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost") + colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost") # create model seed_all(1453) diff --git a/tests/test_pipeline/test_schedule/test_oneF_oneB.py b/tests/test_pipeline/test_schedule/test_oneF_oneB.py index 590800780..c4bfa7b69 100644 --- a/tests/test_pipeline/test_schedule/test_oneF_oneB.py +++ b/tests/test_pipeline/test_schedule/test_oneF_oneB.py @@ -148,7 +148,7 @@ def run_dist( num_microbatch: int, batch_size: int, ): - colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost") + colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost") examine_pp(num_microbatch, batch_size) diff --git a/tests/test_pipeline/test_stage_manager.py b/tests/test_pipeline/test_stage_manager.py index ed8284b3e..5146a86c8 100644 --- a/tests/test_pipeline/test_stage_manager.py +++ b/tests/test_pipeline/test_stage_manager.py @@ -64,7 +64,7 @@ def check_stage_manager(): def run_dist(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host="localhost") + colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost") check_stage_manager() diff --git a/tests/test_shardformer/test_hybrid_parallel_grad_clip_norm/test_amp_optimizer.py b/tests/test_shardformer/test_hybrid_parallel_grad_clip_norm/test_amp_optimizer.py index f652d18e9..b2c81f8ab 100644 --- a/tests/test_shardformer/test_hybrid_parallel_grad_clip_norm/test_amp_optimizer.py +++ b/tests/test_shardformer/test_hybrid_parallel_grad_clip_norm/test_amp_optimizer.py @@ -193,13 +193,13 @@ def run_3d_test(test_config): def check_grad_clip_norm(rank, world_size, port): disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") run_test() def check_grad_clip_norm_3d(rank, world_size, port): disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") run_3d_test() diff --git a/tests/test_shardformer/test_hybrid_parallel_grad_clip_norm/test_naive_optimizer.py b/tests/test_shardformer/test_hybrid_parallel_grad_clip_norm/test_naive_optimizer.py index a749a2966..ee1fd9333 100644 --- a/tests/test_shardformer/test_hybrid_parallel_grad_clip_norm/test_naive_optimizer.py +++ b/tests/test_shardformer/test_hybrid_parallel_grad_clip_norm/test_naive_optimizer.py @@ -151,13 +151,13 @@ def run_3d_test(test_config): def check_grad_clip_norm(rank, world_size, port): disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") run_test() def check_grad_clip_norm_3d(rank, world_size, port): disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") run_3d_test() diff --git a/tests/test_shardformer/test_hybrid_parallel_grad_clip_norm/test_zero_optimizer.py b/tests/test_shardformer/test_hybrid_parallel_grad_clip_norm/test_zero_optimizer.py index 41f06a4c3..be257e818 100644 --- a/tests/test_shardformer/test_hybrid_parallel_grad_clip_norm/test_zero_optimizer.py +++ b/tests/test_shardformer/test_hybrid_parallel_grad_clip_norm/test_zero_optimizer.py @@ -183,13 +183,13 @@ def run_3d_test(test_config): def check_grad_clip_norm(rank, world_size, port): disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") run_test() def check_grad_clip_norm_3d(rank, world_size, port): disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") run_3d_test() diff --git a/tests/test_shardformer/test_layer/test_dist_crossentropy.py b/tests/test_shardformer/test_layer/test_dist_crossentropy.py index 414157c22..8ace0e028 100644 --- a/tests/test_shardformer/test_layer/test_dist_crossentropy.py +++ b/tests/test_shardformer/test_layer/test_dist_crossentropy.py @@ -14,7 +14,7 @@ CONFIG = dict( def check_dist_crossentropy(rank, world_size, port, ignore_index): disable_existing_loggers() - colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, port=port, host="localhost", backend="nccl") + colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost", backend="nccl") # prepare data pred = torch.randn(2, 4, 8, requires_grad=True).cuda() diff --git a/tests/test_shardformer/test_layer/test_dropout.py b/tests/test_shardformer/test_layer/test_dropout.py index 576620e6c..f1e646ed2 100644 --- a/tests/test_shardformer/test_layer/test_dropout.py +++ b/tests/test_shardformer/test_layer/test_dropout.py @@ -56,7 +56,7 @@ def check_dropout_replicated_input(): def run_dist(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") check_dropout_parallel_input() check_dropout_replicated_input() diff --git a/tests/test_shardformer/test_layer/test_embedding.py b/tests/test_shardformer/test_layer/test_embedding.py index 3dbbcd766..3d7dc2088 100644 --- a/tests/test_shardformer/test_layer/test_embedding.py +++ b/tests/test_shardformer/test_layer/test_embedding.py @@ -43,7 +43,7 @@ def check_embedding_1d(lazy_init: bool): def run_dist(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") check_embedding_1d() diff --git a/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py b/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py index e9aa0dbed..5aa8584a0 100644 --- a/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py +++ b/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py @@ -143,7 +143,7 @@ def check_gpt2_qkv_fused_linear_1d(lazy_init: bool, seq_parallel_mode: bool, ove def run_dist(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") # test for linear conv check_gpt2_qkv_fused_linear_1d() diff --git a/tests/test_shardformer/test_layer/test_layernorm.py b/tests/test_shardformer/test_layer/test_layernorm.py index 3eb3bb2e5..b0deff6b8 100644 --- a/tests/test_shardformer/test_layer/test_layernorm.py +++ b/tests/test_shardformer/test_layer/test_layernorm.py @@ -41,7 +41,7 @@ def check_layernorm(lazy_init: bool): def run_dist(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") check_layernorm() diff --git a/tests/test_shardformer/test_layer/test_linear_1d.py b/tests/test_shardformer/test_layer/test_linear_1d.py index 21d3190de..541aa3251 100644 --- a/tests/test_shardformer/test_layer/test_linear_1d.py +++ b/tests/test_shardformer/test_layer/test_linear_1d.py @@ -185,7 +185,7 @@ def run_dist_linear_test(lazy_init, seq_parallel_mode, overlap): def check_dist_linear(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") run_dist_linear_test() diff --git a/tests/test_shardformer/test_layer/test_qkv_fused_linear_1d.py b/tests/test_shardformer/test_layer/test_qkv_fused_linear_1d.py index 5e996d2ba..dc14fd591 100644 --- a/tests/test_shardformer/test_layer/test_qkv_fused_linear_1d.py +++ b/tests/test_shardformer/test_layer/test_qkv_fused_linear_1d.py @@ -126,7 +126,7 @@ def check_linear_conv_1d_row(lazy_init: bool): def run_dist(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") # test for linear conv check_linear_conv_1d_col() diff --git a/tests/test_shardformer/test_layer/test_sequence_parallel.py b/tests/test_shardformer/test_layer/test_sequence_parallel.py index 13b1a13e7..a6cf61f8f 100644 --- a/tests/test_shardformer/test_layer/test_sequence_parallel.py +++ b/tests/test_shardformer/test_layer/test_sequence_parallel.py @@ -165,7 +165,7 @@ def run_seq_parallel_attn(seq_len, hidden_dim, head_num, batch_size): def check_all2all_attn(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") run_seq_parallel_attn() diff --git a/tests/test_shardformer/test_layer/test_vocab_parallel_embedding_1d.py b/tests/test_shardformer/test_layer/test_vocab_parallel_embedding_1d.py index 91cc1a987..fdd304256 100644 --- a/tests/test_shardformer/test_layer/test_vocab_parallel_embedding_1d.py +++ b/tests/test_shardformer/test_layer/test_vocab_parallel_embedding_1d.py @@ -45,7 +45,7 @@ def check_vocab_embedding_1d(lazy_init: bool): def run_dist(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") check_vocab_embedding_1d() diff --git a/tests/test_shardformer/test_model/test_shard_bert.py b/tests/test_shardformer/test_model/test_shard_bert.py index 919557797..3ec394768 100644 --- a/tests/test_shardformer/test_model/test_shard_bert.py +++ b/tests/test_shardformer/test_model/test_shard_bert.py @@ -231,13 +231,13 @@ def run_bert_3d_test(test_config): def check_bert(rank, world_size, port): disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") run_bert_test() def check_bert_3d(rank, world_size, port): disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") run_bert_3d_test() diff --git a/tests/test_shardformer/test_model/test_shard_blip2.py b/tests/test_shardformer/test_model/test_shard_blip2.py index 2c56b0435..712c5c1e1 100644 --- a/tests/test_shardformer/test_model/test_shard_blip2.py +++ b/tests/test_shardformer/test_model/test_shard_blip2.py @@ -99,7 +99,6 @@ def run_blip2_test( def check_blip2(rank, world_size, port): disable_existing_loggers() colossalai.launch( - config={}, rank=rank, world_size=world_size, host="localhost", diff --git a/tests/test_shardformer/test_model/test_shard_bloom.py b/tests/test_shardformer/test_model/test_shard_bloom.py index cc0786618..6ab0369e0 100644 --- a/tests/test_shardformer/test_model/test_shard_bloom.py +++ b/tests/test_shardformer/test_model/test_shard_bloom.py @@ -209,13 +209,13 @@ def run_bloom_3d_test(test_config): def check_bloom(rank, world_size, port): disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") run_bloom_test() def check_bloom_3d(rank, world_size, port): disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") run_bloom_3d_test() diff --git a/tests/test_shardformer/test_model/test_shard_chatglm2.py b/tests/test_shardformer/test_model/test_shard_chatglm2.py index 376d315c1..6ce020b68 100644 --- a/tests/test_shardformer/test_model/test_shard_chatglm2.py +++ b/tests/test_shardformer/test_model/test_shard_chatglm2.py @@ -259,7 +259,6 @@ def run_chatglm_3d_test(test_config): def check_chatglm(rank, world_size, port): disable_existing_loggers() colossalai.launch( - config={}, rank=rank, world_size=world_size, host="localhost", @@ -272,7 +271,6 @@ def check_chatglm(rank, world_size, port): def check_chatglm_3d(rank, world_size, port): disable_existing_loggers() colossalai.launch( - config={}, rank=rank, world_size=world_size, host="localhost", diff --git a/tests/test_shardformer/test_model/test_shard_falcon.py b/tests/test_shardformer/test_model/test_shard_falcon.py index 5e2efcd80..8074f9d61 100644 --- a/tests/test_shardformer/test_model/test_shard_falcon.py +++ b/tests/test_shardformer/test_model/test_shard_falcon.py @@ -176,13 +176,13 @@ def run_falcon_3d_test(test_config): def check_falcon(rank, world_size, port): disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") run_falcon_test() def check_falcon_3d(rank, world_size, port): disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") run_falcon_3d_test() diff --git a/tests/test_shardformer/test_model/test_shard_gpt2.py b/tests/test_shardformer/test_model/test_shard_gpt2.py index 4aac7f3d4..72ea2b089 100644 --- a/tests/test_shardformer/test_model/test_shard_gpt2.py +++ b/tests/test_shardformer/test_model/test_shard_gpt2.py @@ -275,7 +275,6 @@ def run_gpt2_3d_test(test_config): def check_gpt2(rank, world_size, port): disable_existing_loggers() colossalai.launch( - config={}, rank=rank, world_size=world_size, host="localhost", @@ -288,7 +287,6 @@ def check_gpt2(rank, world_size, port): def check_gpt2_3d(rank, world_size, port): disable_existing_loggers() colossalai.launch( - config={}, rank=rank, world_size=world_size, host="localhost", diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index 394592688..104ede981 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -319,13 +319,13 @@ def run_llama_3d_test(test_config): def check_llama(rank, world_size, port): disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") run_llama_test() def check_llama_3d(rank, world_size, port): disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") run_llama_3d_test() diff --git a/tests/test_shardformer/test_model/test_shard_mistral.py b/tests/test_shardformer/test_model/test_shard_mistral.py index 05c199814..deced9d56 100644 --- a/tests/test_shardformer/test_model/test_shard_mistral.py +++ b/tests/test_shardformer/test_model/test_shard_mistral.py @@ -170,7 +170,7 @@ def run_mistral_test(test_config): def check_mistral(rank, world_size, port): disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") run_mistral_test() diff --git a/tests/test_shardformer/test_model/test_shard_opt.py b/tests/test_shardformer/test_model/test_shard_opt.py index 523ed879b..b7c77d20b 100644 --- a/tests/test_shardformer/test_model/test_shard_opt.py +++ b/tests/test_shardformer/test_model/test_shard_opt.py @@ -233,7 +233,6 @@ def run_opt_3d_test(test_config): def check_OPTModel(rank, world_size, port): disable_existing_loggers() colossalai.launch( - config={}, rank=rank, world_size=world_size, host="localhost", @@ -246,7 +245,6 @@ def check_OPTModel(rank, world_size, port): def check_opt_3d(rank, world_size, port): disable_existing_loggers() colossalai.launch( - config={}, rank=rank, world_size=world_size, host="localhost", diff --git a/tests/test_shardformer/test_model/test_shard_sam.py b/tests/test_shardformer/test_model/test_shard_sam.py index a8d4cb635..e872d7f7b 100644 --- a/tests/test_shardformer/test_model/test_shard_sam.py +++ b/tests/test_shardformer/test_model/test_shard_sam.py @@ -57,7 +57,7 @@ def run_sam_test(enable_fused_normalization, enable_tensor_parallelism, enable_f def check_sam(rank, world_size, port): disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") run_sam_test() diff --git a/tests/test_shardformer/test_model/test_shard_t5.py b/tests/test_shardformer/test_model/test_shard_t5.py index a6fe2dd39..521dc9130 100644 --- a/tests/test_shardformer/test_model/test_shard_t5.py +++ b/tests/test_shardformer/test_model/test_shard_t5.py @@ -222,7 +222,6 @@ def run_t5_3d_test(test_config): def check_t5(rank, world_size, port): disable_existing_loggers() colossalai.launch( - config={}, rank=rank, world_size=world_size, host="localhost", @@ -235,7 +234,6 @@ def check_t5(rank, world_size, port): def check_t5_3d(rank, world_size, port): disable_existing_loggers() colossalai.launch( - config={}, rank=rank, world_size=world_size, host="localhost", diff --git a/tests/test_shardformer/test_model/test_shard_vit.py b/tests/test_shardformer/test_model/test_shard_vit.py index 3a8af2d6d..d33b52b42 100644 --- a/tests/test_shardformer/test_model/test_shard_vit.py +++ b/tests/test_shardformer/test_model/test_shard_vit.py @@ -168,13 +168,13 @@ def run_vit_3d_test(test_config): def check_vit(rank, world_size, port): disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") run_vit_test() def check_vit_3d(rank, world_size, port): disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") run_vit_3d_test() diff --git a/tests/test_shardformer/test_model/test_shard_whisper.py b/tests/test_shardformer/test_model/test_shard_whisper.py index af61e4640..beb2a6761 100644 --- a/tests/test_shardformer/test_model/test_shard_whisper.py +++ b/tests/test_shardformer/test_model/test_shard_whisper.py @@ -196,13 +196,13 @@ def run_whisper_3d_test(test_config): def check_whisper(rank, world_size, port): disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") run_whisper_test() def check_whisper_3d(rank, world_size, port): disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") run_whisper_3d_test() diff --git a/tests/test_shardformer/test_with_torch_ddp.py b/tests/test_shardformer/test_with_torch_ddp.py index 4b741c21b..4735df717 100644 --- a/tests/test_shardformer/test_with_torch_ddp.py +++ b/tests/test_shardformer/test_with_torch_ddp.py @@ -71,7 +71,7 @@ def check_shardformer_with_ddp(lazy_init: bool): def run_dist(rank, world_size, port): disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") check_shardformer_with_ddp() diff --git a/tests/test_tensor/test_comm_spec_apply.py b/tests/test_tensor/test_comm_spec_apply.py index 5e969b1aa..a2414d949 100644 --- a/tests/test_tensor/test_comm_spec_apply.py +++ b/tests/test_tensor/test_comm_spec_apply.py @@ -178,7 +178,7 @@ def check_all_reduce_in_flatten_device_mesh(device_mesh, rank): def check_comm(rank, world_size, port): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") physical_mesh_id = torch.arange(0, 4) assert rank == dist.get_rank() diff --git a/tests/test_tensor/test_dtensor/test_comm_spec.py b/tests/test_tensor/test_dtensor/test_comm_spec.py index 6d1640b4f..fd9996710 100644 --- a/tests/test_tensor/test_dtensor/test_comm_spec.py +++ b/tests/test_tensor/test_dtensor/test_comm_spec.py @@ -124,7 +124,7 @@ def check_all_reduce_bwd(process_groups_dict, rank): def check_comm(rank, world_size, port): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") physical_mesh_id = torch.arange(0, 4) assert rank == dist.get_rank() diff --git a/tests/test_tensor/test_dtensor/test_dtensor.py b/tests/test_tensor/test_dtensor/test_dtensor.py index 33ae59d01..60efa315e 100644 --- a/tests/test_tensor/test_dtensor/test_dtensor.py +++ b/tests/test_tensor/test_dtensor/test_dtensor.py @@ -21,7 +21,7 @@ class TestModel(torch.nn.Module): def check_dtensor(rank, world_size, port): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") test_model = TestModel(8, 8).to("cuda") original_tensor = torch.rand(4, 8).to("cuda") compare_output = test_model(original_tensor) diff --git a/tests/test_tensor/test_dtensor/test_layout_converter.py b/tests/test_tensor/test_dtensor/test_layout_converter.py index 3bface1d2..6e426d0e8 100644 --- a/tests/test_tensor/test_dtensor/test_layout_converter.py +++ b/tests/test_tensor/test_dtensor/test_layout_converter.py @@ -20,7 +20,7 @@ mesh_shape = (2, 2) def check_one_step_transform(rank, world_size, port): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") # [[0, 1], # [2, 3]] device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) @@ -82,7 +82,7 @@ def check_one_step_transform(rank, world_size, port): def check_layout_converting(rank, world_size, port): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") dim_partition_source = {1: [0, 1]} dim_partition_target = {0: [0, 1]} device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) @@ -141,7 +141,7 @@ def check_layout_converting(rank, world_size, port): def check_layout_converting_apply(rank, world_size, port): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") dim_partition_source = {1: [0, 1]} dim_partition_target = {0: [0, 1]} diff --git a/tests/test_tensor/test_mix_gather.py b/tests/test_tensor/test_mix_gather.py index 7d6f8979d..6dbbe5de6 100644 --- a/tests/test_tensor/test_mix_gather.py +++ b/tests/test_tensor/test_mix_gather.py @@ -296,7 +296,7 @@ def check_two_all_gather_RS01(device_mesh, rank): def check_comm(rank, world_size, port): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") physical_mesh_id = torch.arange(0, 8) assert rank == dist.get_rank() diff --git a/tests/test_tensor/test_padded_tensor.py b/tests/test_tensor/test_padded_tensor.py index 31a267c15..6d19845df 100644 --- a/tests/test_tensor/test_padded_tensor.py +++ b/tests/test_tensor/test_padded_tensor.py @@ -10,7 +10,7 @@ from colossalai.testing import rerun_if_address_is_in_use, spawn def check_padded_tensor(rank, world_size, port): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") original_tensor = torch.rand(32, 64).to("cuda") device_mesh = DeviceMesh(torch.Tensor([0, 1, 2, 3]), (2, 2), init_process_group=True) diff --git a/tests/test_tensor/test_shape_consistency_apply.py b/tests/test_tensor/test_shape_consistency_apply.py index b2bc84edd..8d8d8ef51 100644 --- a/tests/test_tensor/test_shape_consistency_apply.py +++ b/tests/test_tensor/test_shape_consistency_apply.py @@ -11,7 +11,7 @@ from colossalai.testing import rerun_if_address_is_in_use, spawn def check_apply(rank, world_size, port): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") physical_mesh_id = torch.arange(0, 4) mesh_shape = (2, 2) diff --git a/tests/test_zero/test_gemini/test_chunk_mgrv2.py b/tests/test_zero/test_gemini/test_chunk_mgrv2.py index 879eeccde..412a95f6a 100644 --- a/tests/test_zero/test_gemini/test_chunk_mgrv2.py +++ b/tests/test_zero/test_gemini/test_chunk_mgrv2.py @@ -49,7 +49,7 @@ def exam_chunk_memory(keep_gathered, pin_memory): def run_dist(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") exam_chunk_memory() diff --git a/tests/test_zero/test_gemini/test_chunkv2.py b/tests/test_zero/test_gemini/test_chunkv2.py index e4dc569b8..257311328 100644 --- a/tests/test_zero/test_gemini/test_chunkv2.py +++ b/tests/test_zero/test_gemini/test_chunkv2.py @@ -108,7 +108,7 @@ def exam_chunk_basic(init_device, keep_gathered, pin_memory): def run_dist(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") exam_chunk_basic() diff --git a/tests/test_zero/test_gemini/test_fwd_bwd.py b/tests/test_zero/test_gemini/test_fwd_bwd.py index 3a9742e01..d9084fd5a 100644 --- a/tests/test_zero/test_gemini/test_fwd_bwd.py +++ b/tests/test_zero/test_gemini/test_fwd_bwd.py @@ -100,8 +100,7 @@ def exam_gpt_fwd_bwd( def run_dist(rank, world_size, port): - config = {} - colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") exam_gpt_fwd_bwd() diff --git a/tests/test_zero/test_gemini/test_gemini_use_rmt.py b/tests/test_zero/test_gemini/test_gemini_use_rmt.py index 90ad62d1a..1e49f2851 100644 --- a/tests/test_zero/test_gemini/test_gemini_use_rmt.py +++ b/tests/test_zero/test_gemini/test_gemini_use_rmt.py @@ -80,8 +80,7 @@ def run_gemini_use_rmt(placement_policy, keep_gather, model_name: str, use_grad_ def run_dist(rank, world_size, port): - config = {} - colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") run_gemini_use_rmt() diff --git a/tests/test_zero/test_gemini/test_grad_accum.py b/tests/test_zero/test_gemini/test_grad_accum.py index 36a803492..fd0e9fd7c 100644 --- a/tests/test_zero/test_gemini/test_grad_accum.py +++ b/tests/test_zero/test_gemini/test_grad_accum.py @@ -138,8 +138,7 @@ def exam_gemini_grad_acc( def run_dist(rank, world_size, port): - config = {} - colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") exam_gemini_grad_acc() diff --git a/tests/test_zero/test_gemini/test_grad_clip.py b/tests/test_zero/test_gemini/test_grad_clip.py index 23b3504fd..0a9bac092 100644 --- a/tests/test_zero/test_gemini/test_grad_clip.py +++ b/tests/test_zero/test_gemini/test_grad_clip.py @@ -117,8 +117,7 @@ def exam_grad_clipping(placement_config, model_name: str, master_weights: bool): def run_dist(rank, world_size, port): - config = {} - colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") exam_grad_clipping() diff --git a/tests/test_zero/test_gemini/test_inference.py b/tests/test_zero/test_gemini/test_inference.py index 7f3c7176e..e54804fc5 100644 --- a/tests/test_zero/test_gemini/test_inference.py +++ b/tests/test_zero/test_gemini/test_inference.py @@ -107,8 +107,7 @@ def exam_inference(placement_config: dict, model_name: str, model_init_func: Cal def run_dist(rank, world_size, port): - config = {} - colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") exam_inference() diff --git a/tests/test_zero/test_gemini/test_optim.py b/tests/test_zero/test_gemini/test_optim.py index 71bb27b4a..a9366e7bc 100644 --- a/tests/test_zero/test_gemini/test_optim.py +++ b/tests/test_zero/test_gemini/test_optim.py @@ -183,8 +183,7 @@ def exam_tiny_example(placement_config, model_name: str, mixed_precision: torch. def run_dist(rank, world_size, port): - config = {} - colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") exam_model_step() exam_tiny_example() diff --git a/tests/test_zero/test_gemini/test_search.py b/tests/test_zero/test_gemini/test_search.py index cf3658bf9..9c8c497f3 100644 --- a/tests/test_zero/test_gemini/test_search.py +++ b/tests/test_zero/test_gemini/test_search.py @@ -47,7 +47,7 @@ def exam_chunk_manager(): def run_dist(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") exam_search_chunk_size() exam_chunk_manager() diff --git a/tests/test_zero/test_gemini/test_zeroddp_state_dict.py b/tests/test_zero/test_gemini/test_zeroddp_state_dict.py index cbf5169fc..23e2d8083 100644 --- a/tests/test_zero/test_gemini/test_zeroddp_state_dict.py +++ b/tests/test_zero/test_gemini/test_zeroddp_state_dict.py @@ -76,8 +76,7 @@ def exam_state_dict(placement_config, keep_gathered, model_name: str, master_wei def run_dist(rank, world_size, port): - config = {} - colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") exam_state_dict() diff --git a/tests/test_zero/test_gemini/test_zerooptim_state_dict.py b/tests/test_zero/test_gemini/test_zerooptim_state_dict.py index 87cb1cdfe..8d70ae3b1 100644 --- a/tests/test_zero/test_gemini/test_zerooptim_state_dict.py +++ b/tests/test_zero/test_gemini/test_zerooptim_state_dict.py @@ -68,8 +68,7 @@ def exam_zero_optim_state_dict(placement_config, keep_gathered): def run_dist(rank, world_size, port): - config = {} - colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") exam_zero_optim_state_dict() diff --git a/tests/test_zero/test_low_level/test_grad_acc.py b/tests/test_zero/test_low_level/test_grad_acc.py index 11f738615..ed12bb72d 100644 --- a/tests/test_zero/test_low_level/test_grad_acc.py +++ b/tests/test_zero/test_low_level/test_grad_acc.py @@ -130,7 +130,7 @@ def exam_zero_1_grad_acc(sync): def run_dist(rank, world_size, port): - colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost") + colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost") exam_zero_1_grad_acc(sync=True) exam_zero_1_grad_acc(sync=False) diff --git a/tests/test_zero/test_low_level/test_zero1_2.py b/tests/test_zero/test_low_level/test_zero1_2.py index e2196cfbf..06a29bd1d 100644 --- a/tests/test_zero/test_low_level/test_zero1_2.py +++ b/tests/test_zero/test_low_level/test_zero1_2.py @@ -178,7 +178,7 @@ def exam_zero_1_torch_ddp(world_size, dtype: torch.dtype, master_weights: bool): def run_dist(rank, world_size, port): - colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost") + colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost") exam_zero_1_torch_ddp(world_size=world_size) exam_zero_1_2() diff --git a/tests/test_zero/test_low_level/test_zero_ckpt.py b/tests/test_zero/test_low_level/test_zero_ckpt.py index e9fc8598a..8543dfba0 100644 --- a/tests/test_zero/test_low_level/test_zero_ckpt.py +++ b/tests/test_zero/test_low_level/test_zero_ckpt.py @@ -103,7 +103,7 @@ def exam_zero_1_torch_ddp_ckpt(): def run_dist(rank, world_size, port): - colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost") + colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost") exam_zero_1_torch_ddp_ckpt() From 6af6d6fc9fe72997af44cdf3cb7b930a365ab915 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Mon, 29 Apr 2024 15:33:51 +0800 Subject: [PATCH 26/28] [shardformer] support bias_gelu_jit_fused for models (#5647) * support gelu_bias_fused for gpt2 * support gelu_bias_fused for gpt2 fix fix fix * fix fix * fix --- colossalai/shardformer/modeling/bert.py | 13 +++++++++++++ colossalai/shardformer/modeling/blip2.py | 14 ++++++++++++++ colossalai/shardformer/modeling/gpt2.py | 15 +++++++++++++++ colossalai/shardformer/modeling/vit.py | 12 ++++++++++++ colossalai/shardformer/policies/bert.py | 12 ++++++++++++ colossalai/shardformer/policies/blip2.py | 14 ++++++++++++++ colossalai/shardformer/policies/gpt2.py | 15 ++++++++++++++- colossalai/shardformer/policies/vit.py | 22 +++++++++++++++++++++- 8 files changed, 115 insertions(+), 2 deletions(-) diff --git a/colossalai/shardformer/modeling/bert.py b/colossalai/shardformer/modeling/bert.py index 0838fcee6..e7679f0ec 100644 --- a/colossalai/shardformer/modeling/bert.py +++ b/colossalai/shardformer/modeling/bert.py @@ -1287,3 +1287,16 @@ def bert_sequence_parallel_forward_fn(shard_config: ShardConfig): ) return forward + + +def get_jit_fused_bert_intermediate_forward(): + from transformers.models.bert.modeling_bert import BertIntermediate + + from colossalai.kernel.jit.bias_gelu import GeLUFunction as JitGeLUFunction + + def forward(self: BertIntermediate, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states, bias = self.dense(hidden_states) + hidden_states = JitGeLUFunction.apply(hidden_states, bias) + return hidden_states + + return forward diff --git a/colossalai/shardformer/modeling/blip2.py b/colossalai/shardformer/modeling/blip2.py index bd84c87c6..96e8a9d0c 100644 --- a/colossalai/shardformer/modeling/blip2.py +++ b/colossalai/shardformer/modeling/blip2.py @@ -129,3 +129,17 @@ def get_jit_fused_blip2_QFormer_output_forward(): return hidden_states return forward + + +def get_jit_fused_blip2_mlp_forward(): + from transformers.models.blip_2.modeling_blip_2 import Blip2MLP + + from colossalai.kernel.jit.bias_gelu import GeLUFunction as JitGeLUFunction + + def forward(self: Blip2MLP, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states, bias = self.fc1(hidden_states) + hidden_states = JitGeLUFunction.apply(hidden_states, bias) + hidden_states = self.fc2(hidden_states) + return hidden_states + + return forward diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py index 17acdf7fc..bfa995645 100644 --- a/colossalai/shardformer/modeling/gpt2.py +++ b/colossalai/shardformer/modeling/gpt2.py @@ -1310,3 +1310,18 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig): ) return forward + + +def get_jit_fused_gpt2_mlp_forward(): + from transformers.models.gpt2.modeling_gpt2 import GPT2MLP + + from colossalai.kernel.jit.bias_gelu import GeLUFunction as JitGeLUFunction + + def forward(self: GPT2MLP, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor: + hidden_states, bias = self.c_fc(hidden_states) + hidden_states = JitGeLUFunction.apply(hidden_states, bias) + hidden_states = self.c_proj(hidden_states) + hidden_states = self.dropout(hidden_states) + return hidden_states + + return forward diff --git a/colossalai/shardformer/modeling/vit.py b/colossalai/shardformer/modeling/vit.py index 67b10988d..b1a5c4143 100644 --- a/colossalai/shardformer/modeling/vit.py +++ b/colossalai/shardformer/modeling/vit.py @@ -372,3 +372,15 @@ def get_jit_fused_vit_output_forward(): return hidden_states return forward + + +def get_jit_fused_vit_intermediate_forward(): + from colossalai.kernel.jit.bias_gelu import GeLUFunction as JitGeLUFunction + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states, bias = self.dense(hidden_states) + hidden_states = JitGeLUFunction.apply(hidden_states, bias) + + return hidden_states + + return forward diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py index d43fc893a..ad40e0e56 100644 --- a/colossalai/shardformer/policies/bert.py +++ b/colossalai/shardformer/policies/bert.py @@ -12,6 +12,7 @@ from ..modeling.bert import ( BertPipelineForwards, bert_sequence_parallel_forward_fn, get_bert_flash_attention_forward, + get_jit_fused_bert_intermediate_forward, get_jit_fused_bert_output_forward, get_jit_fused_bert_self_output_forward, ) @@ -38,11 +39,13 @@ class BertPolicy(Policy): def preprocess(self): self.tie_weight = self.tie_weight_check() + self.enable_bias_gelu_fused = self.shard_config.enable_jit_fused and self.model.config.hidden_act == "gelu" return self.model def module_policy(self): from transformers.models.bert.modeling_bert import ( BertEmbeddings, + BertIntermediate, BertLayer, BertModel, BertOutput, @@ -131,6 +134,7 @@ class BertPolicy(Policy): kwargs={ "seq_parallel_mode": sp_mode, "overlap": overlap, + "skip_bias_add": self.enable_bias_gelu_fused, }, ), SubModuleReplacementDescription( @@ -153,6 +157,14 @@ class BertPolicy(Policy): ), ] ) + if self.enable_bias_gelu_fused: + self.append_or_create_method_replacement( + description={ + "forward": get_jit_fused_bert_intermediate_forward(), + }, + policy=policy, + target_key=BertIntermediate, + ) if sp_mode == "split_gather": self.append_or_create_method_replacement( diff --git a/colossalai/shardformer/policies/blip2.py b/colossalai/shardformer/policies/blip2.py index b845e9336..9d1f6a306 100644 --- a/colossalai/shardformer/policies/blip2.py +++ b/colossalai/shardformer/policies/blip2.py @@ -3,6 +3,7 @@ import colossalai.shardformer.layer as col_nn from ..modeling.blip2 import ( forward_fn, get_blip2_flash_attention_forward, + get_jit_fused_blip2_mlp_forward, get_jit_fused_blip2_QFormer_output_forward, get_jit_fused_blip2_QFormer_self_output_forward, ) @@ -18,12 +19,16 @@ class BlipPolicy(Policy): def preprocess(self): self.tie_weight = self.tie_weight_check() + self.enable_bias_gelu_fused = ( + self.shard_config.enable_jit_fused and self.model.config.vision_config.hidden_act == "gelu" + ) return self.model def module_policy(self): from transformers.models.blip_2.modeling_blip_2 import ( Blip2Attention, Blip2EncoderLayer, + Blip2MLP, Blip2QFormerLayer, Blip2QFormerModel, Blip2QFormerOutput, @@ -73,6 +78,7 @@ class BlipPolicy(Policy): SubModuleReplacementDescription( suffix="mlp.fc1", target_module=col_nn.Linear1D_Col, + kwargs={"skip_bias_add": self.enable_bias_gelu_fused}, ), SubModuleReplacementDescription( suffix="mlp.fc2", @@ -201,6 +207,14 @@ class BlipPolicy(Policy): ) policy[Blip2Attention] = ModulePolicyDescription(method_replacement={"forward": forward_fn()}) + if self.enable_bias_gelu_fused: + self.append_or_create_method_replacement( + description={ + "forward": get_jit_fused_blip2_mlp_forward(), + }, + policy=policy, + target_key=Blip2MLP, + ) if embedding_cls is not None: self.append_or_create_submodule_replacement( diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py index 6f4f835a8..531c2153b 100644 --- a/colossalai/shardformer/policies/gpt2.py +++ b/colossalai/shardformer/policies/gpt2.py @@ -10,6 +10,7 @@ from ..modeling.gpt2 import ( GPT2PipelineForwards, get_gpt2_flash_attention_forward, get_gpt_model_forward_for_flash_attn, + get_jit_fused_gpt2_mlp_forward, get_lm_forward_with_dist_cross_entropy, gpt2_sequence_parallel_forward_fn, ) @@ -36,10 +37,13 @@ class GPT2Policy(Policy): """ self.tie_weight = self.tie_weight_check() self.origin_attn_implement = self.model.config._attn_implementation + self.enable_bias_gelu_fused = ( + self.shard_config.enable_jit_fused and self.model.config.activation_function == "gelu" + ) return self.model def module_policy(self): - from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Block, GPT2Model + from transformers.models.gpt2.modeling_gpt2 import GPT2MLP, GPT2Attention, GPT2Block, GPT2Model ATTN_IMPLEMENTATION = { "eager": GPT2Attention, @@ -119,6 +123,7 @@ class GPT2Policy(Policy): "n_fused": 1, "seq_parallel_mode": sp_mode, "overlap": overlap, + "skip_bias_add": self.enable_bias_gelu_fused, }, ), SubModuleReplacementDescription( @@ -142,6 +147,14 @@ class GPT2Policy(Policy): ), ], ) + if self.enable_bias_gelu_fused: + self.append_or_create_method_replacement( + description={ + "forward": get_jit_fused_gpt2_mlp_forward(), + }, + policy=policy, + target_key=GPT2MLP, + ) if embedding_cls is not None: # padding vocabulary size when using pp to make it divisible by shard_config.make_vocab_size_divisible_by self.append_or_create_submodule_replacement( diff --git a/colossalai/shardformer/policies/vit.py b/colossalai/shardformer/policies/vit.py index 905398c4d..b7883af9f 100644 --- a/colossalai/shardformer/policies/vit.py +++ b/colossalai/shardformer/policies/vit.py @@ -11,6 +11,7 @@ from ..modeling.vit import ( ViTForImageClassification_pipeline_forward, ViTForMaskedImageModeling_pipeline_forward, ViTModel_pipeline_forward, + get_jit_fused_vit_intermediate_forward, get_jit_fused_vit_output_forward, get_vit_flash_self_attention_forward, ) @@ -24,10 +25,17 @@ class ViTPolicy(Policy): pass def preprocess(self): + self.enable_bias_gelu_fused = self.shard_config.enable_jit_fused and self.model.config.hidden_act == "gelu" return self.model def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: - from transformers.models.vit.modeling_vit import ViTEmbeddings, ViTLayer, ViTOutput, ViTSelfAttention + from transformers.models.vit.modeling_vit import ( + ViTEmbeddings, + ViTIntermediate, + ViTLayer, + ViTOutput, + ViTSelfAttention, + ) policy = {} @@ -83,6 +91,9 @@ class ViTPolicy(Policy): SubModuleReplacementDescription( suffix="intermediate.dense", target_module=col_nn.Linear1D_Col, + kwargs={ + "skip_bias_add": self.enable_bias_gelu_fused, + }, ), SubModuleReplacementDescription( suffix="output.dense", @@ -94,6 +105,14 @@ class ViTPolicy(Policy): ), ], ) + if self.enable_bias_gelu_fused: + self.append_or_create_method_replacement( + description={ + "forward": get_jit_fused_vit_intermediate_forward(), + }, + policy=policy, + target_key=ViTIntermediate, + ) # use flash attention if self.shard_config.enable_flash_attention: @@ -115,6 +134,7 @@ class ViTPolicy(Policy): policy=policy, target_key=ViTOutput, ) + return policy def new_model_class(self): From d3f34ee8cc48b089c8b7dbc55697f77719f33079 Mon Sep 17 00:00:00 2001 From: Wang Binluo <32676639+wangbluo@users.noreply.github.com> Date: Mon, 29 Apr 2024 05:47:47 -0500 Subject: [PATCH 27/28] [Shardformer] add assert for num of attention heads divisible by tp_size (#5670) * add assert for num of attention heads divisible by tp_size * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- colossalai/shardformer/policies/bert.py | 3 +++ colossalai/shardformer/policies/blip2.py | 3 +++ colossalai/shardformer/policies/bloom.py | 3 +++ colossalai/shardformer/policies/falcon.py | 6 ++++++ colossalai/shardformer/policies/gpt2.py | 3 +++ colossalai/shardformer/policies/gptj.py | 3 +++ colossalai/shardformer/policies/llama.py | 6 ++++++ colossalai/shardformer/policies/mistral.py | 6 ++++++ colossalai/shardformer/policies/opt.py | 3 +++ colossalai/shardformer/policies/sam.py | 3 +++ colossalai/shardformer/policies/t5.py | 3 +++ colossalai/shardformer/policies/vit.py | 3 +++ colossalai/shardformer/policies/whisper.py | 3 +++ 13 files changed, 48 insertions(+) diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py index ad40e0e56..0c04f7d38 100644 --- a/colossalai/shardformer/policies/bert.py +++ b/colossalai/shardformer/policies/bert.py @@ -79,6 +79,9 @@ class BertPolicy(Policy): sp_partial_derived = sp_mode == "split_gather" if self.shard_config.enable_tensor_parallelism: + assert ( + self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0 + ), f"The number of attention heads must be divisible by tensor parallel size." policy[BertLayer] = ModulePolicyDescription( attribute_replacement={ "attention.self.all_head_size": self.model.config.hidden_size diff --git a/colossalai/shardformer/policies/blip2.py b/colossalai/shardformer/policies/blip2.py index 9d1f6a306..32d4edadb 100644 --- a/colossalai/shardformer/policies/blip2.py +++ b/colossalai/shardformer/policies/blip2.py @@ -52,6 +52,9 @@ class BlipPolicy(Policy): norm_cls = col_nn.LayerNorm if self.shard_config.enable_tensor_parallelism: + assert ( + self.model.config.vision_config.num_attention_heads % self.shard_config.tensor_parallel_size == 0 + ), f"The number of attention heads must be divisible by tensor parallel size." policy[Blip2EncoderLayer] = ModulePolicyDescription( attribute_replacement={ "self_attn.num_heads": self.model.config.vision_config.num_attention_heads diff --git a/colossalai/shardformer/policies/bloom.py b/colossalai/shardformer/policies/bloom.py index 4894bda35..4f076d233 100644 --- a/colossalai/shardformer/policies/bloom.py +++ b/colossalai/shardformer/policies/bloom.py @@ -61,6 +61,9 @@ class BloomPolicy(Policy): sp_partial_derived = sp_mode == "split_gather" if self.shard_config.enable_tensor_parallelism: + assert ( + self.model.config.n_head % self.shard_config.tensor_parallel_size == 0 + ), f"The number of attention heads must be divisible by tensor parallel size." policy[BloomBlock] = ModulePolicyDescription( attribute_replacement={ "self_attention.hidden_size": self.model.config.hidden_size diff --git a/colossalai/shardformer/policies/falcon.py b/colossalai/shardformer/policies/falcon.py index e72a97e4b..23d6efbeb 100644 --- a/colossalai/shardformer/policies/falcon.py +++ b/colossalai/shardformer/policies/falcon.py @@ -47,6 +47,12 @@ class FalconPolicy(Policy): embedding_cls = col_nn.PaddingEmbedding if self.shard_config.enable_tensor_parallelism: + assert ( + self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0 + ), f"The number of attention heads must be divisible by tensor parallel size." + assert ( + self.model.config.num_kv_heads % self.shard_config.tensor_parallel_size == 0 + ), f"The number of key_value heads must be divisible by tensor parallel size." attn_attribute_replacement = { "self_attention.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, "self_attention.split_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py index 531c2153b..281ea88c2 100644 --- a/colossalai/shardformer/policies/gpt2.py +++ b/colossalai/shardformer/policies/gpt2.py @@ -84,6 +84,9 @@ class GPT2Policy(Policy): self.shard_config.enable_flash_attention = False use_flash_attention = False if self.shard_config.enable_tensor_parallelism: + assert ( + self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0 + ), f"The number of attention heads must be divisible by tensor parallel size." policy[GPT2Model] = ModulePolicyDescription( sub_module_replacement=[ SubModuleReplacementDescription( diff --git a/colossalai/shardformer/policies/gptj.py b/colossalai/shardformer/policies/gptj.py index 25e5b66dc..3315eb1e9 100644 --- a/colossalai/shardformer/policies/gptj.py +++ b/colossalai/shardformer/policies/gptj.py @@ -57,6 +57,9 @@ class GPTJPolicy(Policy): overlap = self.shard_config.enable_sequence_overlap if self.shard_config.enable_tensor_parallelism: + assert ( + self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0 + ), f"The number of attention heads must be divisible by tensor parallel size." policy[GPTJModel] = ModulePolicyDescription( sub_module_replacement=[ SubModuleReplacementDescription( diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index 0a95284bc..6e541f792 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -138,6 +138,12 @@ class LlamaPolicy(Policy): ) if self.shard_config.enable_tensor_parallelism: + assert ( + self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0 + ), f"The number of attention heads must be divisible by tensor parallel size." + assert ( + self.model.config.num_key_value_heads % self.shard_config.tensor_parallel_size == 0 + ), f"The number of key_value heads must be divisible by tensor parallel size." decoder_attribute_replacement = { "self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, "self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, diff --git a/colossalai/shardformer/policies/mistral.py b/colossalai/shardformer/policies/mistral.py index b5018e47d..984b71646 100644 --- a/colossalai/shardformer/policies/mistral.py +++ b/colossalai/shardformer/policies/mistral.py @@ -66,6 +66,12 @@ class MistralPolicy(Policy): ) if self.shard_config.enable_tensor_parallelism: + assert ( + self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0 + ), f"The number of attention heads must be divisible by tensor parallel size." + assert ( + self.model.config.num_key_value_heads % self.shard_config.tensor_parallel_size == 0 + ), f"The number of key_value heads must be divisible by tensor parallel size." decoder_attribute_replacement = { "self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, "self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, diff --git a/colossalai/shardformer/policies/opt.py b/colossalai/shardformer/policies/opt.py index 2f6eabd5f..9619b3d41 100644 --- a/colossalai/shardformer/policies/opt.py +++ b/colossalai/shardformer/policies/opt.py @@ -76,6 +76,9 @@ class OPTPolicy(Policy): warnings.warn("OPT doesn't support sequence parallelism now, will ignore the sequence parallelism flag.") if self.shard_config.enable_tensor_parallelism: + assert ( + self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0 + ), f"The number of attention heads must be divisible by tensor parallel size." policy[OPTDecoderLayer] = ModulePolicyDescription( sub_module_replacement=[ SubModuleReplacementDescription( diff --git a/colossalai/shardformer/policies/sam.py b/colossalai/shardformer/policies/sam.py index ce33925ff..c224d7769 100644 --- a/colossalai/shardformer/policies/sam.py +++ b/colossalai/shardformer/policies/sam.py @@ -31,6 +31,9 @@ class SamPolicy(Policy): norm_cls = col_nn.LayerNorm if self.shard_config.enable_tensor_parallelism: + assert ( + self.model.config.vision_config.num_attention_heads % self.shard_config.tensor_parallel_size == 0 + ), f"The number of attention heads must be divisible by tensor parallel size." policy[SamVisionLayer] = ModulePolicyDescription( attribute_replacement={ "attn.num_attention_heads": self.model.config.vision_config.num_attention_heads diff --git a/colossalai/shardformer/policies/t5.py b/colossalai/shardformer/policies/t5.py index 3c7e92b47..1298f0af3 100644 --- a/colossalai/shardformer/policies/t5.py +++ b/colossalai/shardformer/policies/t5.py @@ -72,6 +72,9 @@ class T5BasePolicy(Policy): warnings.warn("T5 doesn't support sequence parallelism now, will ignore the sequence parallelism flag.") if self.shard_config.enable_tensor_parallelism: + assert ( + self.model.config.num_heads % self.shard_config.tensor_parallel_size == 0 + ), f"The number of attention heads must be divisible by tensor parallel size." policy[T5Stack] = ModulePolicyDescription( sub_module_replacement=[ SubModuleReplacementDescription( diff --git a/colossalai/shardformer/policies/vit.py b/colossalai/shardformer/policies/vit.py index b7883af9f..069ad0c26 100644 --- a/colossalai/shardformer/policies/vit.py +++ b/colossalai/shardformer/policies/vit.py @@ -44,6 +44,9 @@ class ViTPolicy(Policy): warnings.warn("Vit doesn't support sequence parallelism now, will ignore the sequence parallelism flag.") if self.shard_config.enable_tensor_parallelism: + assert ( + self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0 + ), f"The number of attention heads must be divisible by tensor parallel size." policy[ViTEmbeddings] = ModulePolicyDescription( attribute_replacement={}, param_replacement=[], diff --git a/colossalai/shardformer/policies/whisper.py b/colossalai/shardformer/policies/whisper.py index aeb668797..441e512bb 100644 --- a/colossalai/shardformer/policies/whisper.py +++ b/colossalai/shardformer/policies/whisper.py @@ -78,6 +78,9 @@ class WhisperPolicy(Policy): warnings.warn("Whisper doesn't support jit fused operator now, will ignore the jit fused operator flag.") if self.shard_config.enable_tensor_parallelism: + assert ( + self.model.config.encoder_attention_heads % self.shard_config.tensor_parallel_size == 0 + ), f"The number of attention heads must be divisible by tensor parallel size." policy[WhisperEncoderLayer] = ModulePolicyDescription( attribute_replacement={ "self_attn.embed_dim": self.model.config.d_model // self.shard_config.tensor_parallel_size, From 8754abae24dbcc492d2992d1091428592b615285 Mon Sep 17 00:00:00 2001 From: Yuanheng Zhao Date: Sun, 5 May 2024 16:28:56 +0000 Subject: [PATCH 28/28] [Fix] Fix & Update Inference Tests (compatibility w/ main) --- colossalai/inference/modeling/models/nopadding_llama.py | 4 ++-- .../benchmark_ops/benchmark_context_attn_unpad.py | 2 +- .../inference/benchmark_ops/benchmark_decoding_attn.py | 4 ++-- .../benchmark_ops/benchmark_flash_decoding_attention.py | 2 +- .../benchmark_ops/benchmark_fused_rotary_embdding_unpad.py | 2 +- .../inference/benchmark_ops/benchmark_kv_cache_memcopy.py | 4 ++-- examples/inference/benchmark_ops/benchmark_xine_copy.py | 2 +- tests/test_infer/test_config_and_struct.py | 2 +- tests/test_infer/test_cuda_graph.py | 2 +- tests/test_infer/test_inference_engine.py | 2 +- tests/test_infer/{test_ops => test_kernels}/__init__.py | 0 .../test_infer/{test_ops => test_kernels}/cuda/__init__.py | 0 .../cuda/test_flash_decoding_attention.py | 4 ++-- .../cuda/test_get_cos_and_sin.py | 2 +- .../cuda/test_kv_cache_memcpy.py | 5 ++++- .../{test_ops => test_kernels}/cuda/test_rms_layernorm.py | 0 .../cuda/test_rotary_embdding_unpad.py | 4 ++-- .../{test_ops => test_kernels}/cuda/test_silu_and_mul.py | 0 .../{test_ops => test_kernels}/triton/__init__.py | 0 .../{test_ops => test_kernels}/triton/kernel_utils.py | 0 .../triton/test_context_attn_unpad.py | 2 +- .../triton/test_decoding_attn.py | 4 ++-- .../triton/test_fused_rotary_embedding.py | 0 .../{test_ops => test_kernels}/triton/test_kvcache_copy.py | 2 +- .../triton/test_rmsnorm_triton.py | 0 .../triton/test_rotary_embdding_unpad.py | 2 +- .../{test_ops => test_kernels}/triton/test_xine_copy.py | 0 tests/test_infer/test_kvcache_manager.py | 2 +- tests/test_infer/test_models/test_baichuan.py | 7 +++---- tests/test_infer/test_request_handler.py | 2 +- 30 files changed, 32 insertions(+), 30 deletions(-) rename tests/test_infer/{test_ops => test_kernels}/__init__.py (100%) rename tests/test_infer/{test_ops => test_kernels}/cuda/__init__.py (100%) rename tests/test_infer/{test_ops => test_kernels}/cuda/test_flash_decoding_attention.py (98%) rename tests/test_infer/{test_ops => test_kernels}/cuda/test_get_cos_and_sin.py (95%) rename tests/test_infer/{test_ops => test_kernels}/cuda/test_kv_cache_memcpy.py (97%) rename tests/test_infer/{test_ops => test_kernels}/cuda/test_rms_layernorm.py (100%) rename tests/test_infer/{test_ops => test_kernels}/cuda/test_rotary_embdding_unpad.py (96%) rename tests/test_infer/{test_ops => test_kernels}/cuda/test_silu_and_mul.py (100%) rename tests/test_infer/{test_ops => test_kernels}/triton/__init__.py (100%) rename tests/test_infer/{test_ops => test_kernels}/triton/kernel_utils.py (100%) rename tests/test_infer/{test_ops => test_kernels}/triton/test_context_attn_unpad.py (99%) rename tests/test_infer/{test_ops => test_kernels}/triton/test_decoding_attn.py (97%) rename tests/test_infer/{test_ops => test_kernels}/triton/test_fused_rotary_embedding.py (100%) rename tests/test_infer/{test_ops => test_kernels}/triton/test_kvcache_copy.py (99%) rename tests/test_infer/{test_ops => test_kernels}/triton/test_rmsnorm_triton.py (100%) rename tests/test_infer/{test_ops => test_kernels}/triton/test_rotary_embdding_unpad.py (98%) rename tests/test_infer/{test_ops => test_kernels}/triton/test_xine_copy.py (100%) diff --git a/colossalai/inference/modeling/models/nopadding_llama.py b/colossalai/inference/modeling/models/nopadding_llama.py index 557ca0d12..5b8b43d4e 100644 --- a/colossalai/inference/modeling/models/nopadding_llama.py +++ b/colossalai/inference/modeling/models/nopadding_llama.py @@ -270,7 +270,7 @@ def llama_rmsnorm_forward( return rms_layernorm(hidden_states, self.weight.data, self.variance_epsilon, norm_output, residual) -class NopadLlamaMLP(ParallelModule, LlamaMLP): +class NopadLlamaMLP(LlamaMLP, ParallelModule): def __init__( self, config: LlamaConfig, @@ -392,7 +392,7 @@ class NopadLlamaMLP(ParallelModule, LlamaMLP): return f"gate_up_proj MergedLinear1D_Col: in_features={self.gate_up_weight.shape[1]}x2, out_features={self.gate_up_weight.shape[2]}, bias=False" -class NopadLlamaAttention(ParallelModule, LlamaAttention): +class NopadLlamaAttention(LlamaAttention, ParallelModule): def __init__( self, config: LlamaConfig, diff --git a/examples/inference/benchmark_ops/benchmark_context_attn_unpad.py b/examples/inference/benchmark_ops/benchmark_context_attn_unpad.py index 498282ba3..18fe76cf0 100644 --- a/examples/inference/benchmark_ops/benchmark_context_attn_unpad.py +++ b/examples/inference/benchmark_ops/benchmark_context_attn_unpad.py @@ -4,7 +4,7 @@ from transformers.modeling_attn_mask_utils import AttentionMaskConverter from colossalai.inference.modeling.layers.attention import PagedAttention from colossalai.kernel.triton import context_attention_unpadded from colossalai.utils import get_current_device -from tests.test_infer.test_ops.triton.kernel_utils import generate_caches_and_block_tables_v2, torch_attn_ref +from tests.test_infer.test_kernels.triton.kernel_utils import generate_caches_and_block_tables_v2, torch_attn_ref try: import triton # noqa diff --git a/examples/inference/benchmark_ops/benchmark_decoding_attn.py b/examples/inference/benchmark_ops/benchmark_decoding_attn.py index 1a80961a7..4471ddada 100644 --- a/examples/inference/benchmark_ops/benchmark_decoding_attn.py +++ b/examples/inference/benchmark_ops/benchmark_decoding_attn.py @@ -2,14 +2,14 @@ import torch from colossalai.kernel.triton import flash_decoding_attention from colossalai.utils import get_current_device -from tests.test_infer.test_ops.triton.kernel_utils import ( +from tests.test_infer.test_kernels.triton.kernel_utils import ( convert_kv_unpad_to_padded, create_attention_mask, generate_caches_and_block_tables_v2, generate_caches_and_block_tables_v3, torch_attn_ref, ) -from tests.test_infer.test_ops.triton.test_decoding_attn import prepare_data +from tests.test_infer.test_kernels.triton.test_decoding_attn import prepare_data try: import triton # noqa diff --git a/examples/inference/benchmark_ops/benchmark_flash_decoding_attention.py b/examples/inference/benchmark_ops/benchmark_flash_decoding_attention.py index 35eae69b6..d90de6664 100644 --- a/examples/inference/benchmark_ops/benchmark_flash_decoding_attention.py +++ b/examples/inference/benchmark_ops/benchmark_flash_decoding_attention.py @@ -3,7 +3,7 @@ import torch from colossalai.kernel.kernel_loader import InferenceOpsLoader from colossalai.kernel.triton import flash_decoding_attention from colossalai.utils import get_current_device -from tests.test_infer.test_ops.triton.kernel_utils import ( +from tests.test_infer.test_kernels.triton.kernel_utils import ( generate_caches_and_block_tables_v2, generate_caches_and_block_tables_v3, generate_caches_and_block_tables_vllm, diff --git a/examples/inference/benchmark_ops/benchmark_fused_rotary_embdding_unpad.py b/examples/inference/benchmark_ops/benchmark_fused_rotary_embdding_unpad.py index 6a499ccf2..80939f5a1 100644 --- a/examples/inference/benchmark_ops/benchmark_fused_rotary_embdding_unpad.py +++ b/examples/inference/benchmark_ops/benchmark_fused_rotary_embdding_unpad.py @@ -2,7 +2,7 @@ import torch from colossalai.kernel.kernel_loader import InferenceOpsLoader from colossalai.kernel.triton import copy_kv_to_blocked_cache, decoding_fused_rotary_embedding, rotary_embedding -from tests.test_infer.test_ops.triton.kernel_utils import ( +from tests.test_infer.test_kernels.triton.kernel_utils import ( mock_alloc_block_table_and_kvcache_v2, mock_alloc_block_table_and_kvcache_v3, mock_alloc_single_token, diff --git a/examples/inference/benchmark_ops/benchmark_kv_cache_memcopy.py b/examples/inference/benchmark_ops/benchmark_kv_cache_memcopy.py index 03f797308..0232cb90e 100644 --- a/examples/inference/benchmark_ops/benchmark_kv_cache_memcopy.py +++ b/examples/inference/benchmark_ops/benchmark_kv_cache_memcopy.py @@ -4,8 +4,8 @@ from colossalai.inference.modeling.layers.attention import copy_to_cache from colossalai.kernel.kernel_loader import InferenceOpsLoader from colossalai.kernel.triton import copy_kv_to_blocked_cache from colossalai.utils import get_current_device -from tests.test_infer.test_ops.cuda.test_kv_cache_memcpy import prepare_data as prepare_data_new_kcache_layout -from tests.test_infer.test_ops.triton.test_kvcache_copy import prepare_data +from tests.test_infer.test_kernels.cuda.test_kv_cache_memcpy import prepare_data as prepare_data_new_kcache_layout +from tests.test_infer.test_kernels.triton.test_kvcache_copy import prepare_data try: import triton # noqa diff --git a/examples/inference/benchmark_ops/benchmark_xine_copy.py b/examples/inference/benchmark_ops/benchmark_xine_copy.py index b15232b91..633ceb6f1 100644 --- a/examples/inference/benchmark_ops/benchmark_xine_copy.py +++ b/examples/inference/benchmark_ops/benchmark_xine_copy.py @@ -1,7 +1,7 @@ import torch from colossalai.kernel.triton import get_xine_cache -from tests.test_infer.test_ops.triton.test_xine_copy import get_cos_sin +from tests.test_infer.test_kernels.triton.test_xine_copy import get_cos_sin try: import triton # noqa diff --git a/tests/test_infer/test_config_and_struct.py b/tests/test_infer/test_config_and_struct.py index 046ee932d..cc0389af9 100755 --- a/tests/test_infer/test_config_and_struct.py +++ b/tests/test_infer/test_config_and_struct.py @@ -80,7 +80,7 @@ def check_config_and_inference(): def run_dist(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host="localhost") + colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost") check_config_and_inference() diff --git a/tests/test_infer/test_cuda_graph.py b/tests/test_infer/test_cuda_graph.py index a0a55d3ad..4cdc62fbe 100644 --- a/tests/test_infer/test_cuda_graph.py +++ b/tests/test_infer/test_cuda_graph.py @@ -80,7 +80,7 @@ def check_output_consistency(batch_size): def run_dist(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host="localhost") + colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost") check_output_consistency(32) check_output_consistency(64) check_output_consistency(128) diff --git a/tests/test_infer/test_inference_engine.py b/tests/test_infer/test_inference_engine.py index 25413a292..a0ddbbc7b 100644 --- a/tests/test_infer/test_inference_engine.py +++ b/tests/test_infer/test_inference_engine.py @@ -157,7 +157,7 @@ def check_spec_dec(num_layers, max_length): def run_dist(rank, world_size, port, func_to_run, ret=None, **kwargs): - colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host="localhost") + colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost") if ret: ret[rank] = func_to_run(**kwargs) diff --git a/tests/test_infer/test_ops/__init__.py b/tests/test_infer/test_kernels/__init__.py similarity index 100% rename from tests/test_infer/test_ops/__init__.py rename to tests/test_infer/test_kernels/__init__.py diff --git a/tests/test_infer/test_ops/cuda/__init__.py b/tests/test_infer/test_kernels/cuda/__init__.py similarity index 100% rename from tests/test_infer/test_ops/cuda/__init__.py rename to tests/test_infer/test_kernels/cuda/__init__.py diff --git a/tests/test_infer/test_ops/cuda/test_flash_decoding_attention.py b/tests/test_infer/test_kernels/cuda/test_flash_decoding_attention.py similarity index 98% rename from tests/test_infer/test_ops/cuda/test_flash_decoding_attention.py rename to tests/test_infer/test_kernels/cuda/test_flash_decoding_attention.py index b3bd503bb..80a5d067b 100644 --- a/tests/test_infer/test_ops/cuda/test_flash_decoding_attention.py +++ b/tests/test_infer/test_kernels/cuda/test_flash_decoding_attention.py @@ -7,11 +7,11 @@ import torch from colossalai.inference.modeling.models.nopadding_baichuan import get_alibi_slopes from colossalai.kernel.kernel_loader import InferenceOpsLoader from colossalai.utils import get_current_device -from tests.test_infer.test_ops.triton.test_context_attn_unpad import generate_alibi_mask +from tests.test_infer.test_kernels.triton.test_context_attn_unpad import generate_alibi_mask inference_ops = InferenceOpsLoader().load() -from tests.test_infer.test_ops.triton.kernel_utils import ( +from tests.test_infer.test_kernels.triton.kernel_utils import ( convert_kv_unpad_to_padded, create_attention_mask, generate_caches_and_block_tables_v3, diff --git a/tests/test_infer/test_ops/cuda/test_get_cos_and_sin.py b/tests/test_infer/test_kernels/cuda/test_get_cos_and_sin.py similarity index 95% rename from tests/test_infer/test_ops/cuda/test_get_cos_and_sin.py rename to tests/test_infer/test_kernels/cuda/test_get_cos_and_sin.py index c632cfe30..b6ba1a01b 100644 --- a/tests/test_infer/test_ops/cuda/test_get_cos_and_sin.py +++ b/tests/test_infer/test_kernels/cuda/test_get_cos_and_sin.py @@ -3,7 +3,7 @@ import pytest import torch from colossalai.kernel.kernel_loader import InferenceOpsLoader -from tests.test_infer.test_ops.triton.test_xine_copy import get_cos_sin +from tests.test_infer.test_kernels.triton.test_xine_copy import get_cos_sin inference_ops = InferenceOpsLoader().load() diff --git a/tests/test_infer/test_ops/cuda/test_kv_cache_memcpy.py b/tests/test_infer/test_kernels/cuda/test_kv_cache_memcpy.py similarity index 97% rename from tests/test_infer/test_ops/cuda/test_kv_cache_memcpy.py rename to tests/test_infer/test_kernels/cuda/test_kv_cache_memcpy.py index e9c99ddc7..d90f64690 100644 --- a/tests/test_infer/test_ops/cuda/test_kv_cache_memcpy.py +++ b/tests/test_infer/test_kernels/cuda/test_kv_cache_memcpy.py @@ -4,7 +4,10 @@ import torch.nn.functional as F from colossalai.kernel.kernel_loader import InferenceOpsLoader from colossalai.utils import get_current_device -from tests.test_infer.test_ops.triton.kernel_utils import generate_caches_and_block_tables_v3, mock_alloc_single_token +from tests.test_infer.test_kernels.triton.kernel_utils import ( + generate_caches_and_block_tables_v3, + mock_alloc_single_token, +) inference_ops = InferenceOpsLoader().load() diff --git a/tests/test_infer/test_ops/cuda/test_rms_layernorm.py b/tests/test_infer/test_kernels/cuda/test_rms_layernorm.py similarity index 100% rename from tests/test_infer/test_ops/cuda/test_rms_layernorm.py rename to tests/test_infer/test_kernels/cuda/test_rms_layernorm.py diff --git a/tests/test_infer/test_ops/cuda/test_rotary_embdding_unpad.py b/tests/test_infer/test_kernels/cuda/test_rotary_embdding_unpad.py similarity index 96% rename from tests/test_infer/test_ops/cuda/test_rotary_embdding_unpad.py rename to tests/test_infer/test_kernels/cuda/test_rotary_embdding_unpad.py index 501bf65d8..8237384c0 100644 --- a/tests/test_infer/test_ops/cuda/test_rotary_embdding_unpad.py +++ b/tests/test_infer/test_kernels/cuda/test_rotary_embdding_unpad.py @@ -7,8 +7,8 @@ from colossalai.kernel.kernel_loader import InferenceOpsLoader inference_ops = InferenceOpsLoader().load() -from tests.test_infer.test_ops.triton.kernel_utils import mock_alloc_block_table_and_kvcache_v3 -from tests.test_infer.test_ops.triton.test_rotary_embdding_unpad import torch_rotary_emb +from tests.test_infer.test_kernels.triton.kernel_utils import mock_alloc_block_table_and_kvcache_v3 +from tests.test_infer.test_kernels.triton.test_rotary_embdding_unpad import torch_rotary_emb def numpy_allclose(x, y, rtol, atol): diff --git a/tests/test_infer/test_ops/cuda/test_silu_and_mul.py b/tests/test_infer/test_kernels/cuda/test_silu_and_mul.py similarity index 100% rename from tests/test_infer/test_ops/cuda/test_silu_and_mul.py rename to tests/test_infer/test_kernels/cuda/test_silu_and_mul.py diff --git a/tests/test_infer/test_ops/triton/__init__.py b/tests/test_infer/test_kernels/triton/__init__.py similarity index 100% rename from tests/test_infer/test_ops/triton/__init__.py rename to tests/test_infer/test_kernels/triton/__init__.py diff --git a/tests/test_infer/test_ops/triton/kernel_utils.py b/tests/test_infer/test_kernels/triton/kernel_utils.py similarity index 100% rename from tests/test_infer/test_ops/triton/kernel_utils.py rename to tests/test_infer/test_kernels/triton/kernel_utils.py diff --git a/tests/test_infer/test_ops/triton/test_context_attn_unpad.py b/tests/test_infer/test_kernels/triton/test_context_attn_unpad.py similarity index 99% rename from tests/test_infer/test_ops/triton/test_context_attn_unpad.py rename to tests/test_infer/test_kernels/triton/test_context_attn_unpad.py index 76785d530..e34fada97 100644 --- a/tests/test_infer/test_ops/triton/test_context_attn_unpad.py +++ b/tests/test_infer/test_kernels/triton/test_context_attn_unpad.py @@ -5,7 +5,7 @@ from packaging import version from colossalai.inference.modeling.models.nopadding_baichuan import get_alibi_slopes from colossalai.kernel.triton import context_attention_unpadded from colossalai.utils import get_current_device -from tests.test_infer.test_ops.triton.kernel_utils import ( +from tests.test_infer.test_kernels.triton.kernel_utils import ( generate_caches_and_block_tables_v2, generate_caches_and_block_tables_v3, torch_attn_ref, diff --git a/tests/test_infer/test_ops/triton/test_decoding_attn.py b/tests/test_infer/test_kernels/triton/test_decoding_attn.py similarity index 97% rename from tests/test_infer/test_ops/triton/test_decoding_attn.py rename to tests/test_infer/test_kernels/triton/test_decoding_attn.py index 616d7868b..24741fecf 100644 --- a/tests/test_infer/test_ops/triton/test_decoding_attn.py +++ b/tests/test_infer/test_kernels/triton/test_decoding_attn.py @@ -6,14 +6,14 @@ from packaging import version from colossalai.inference.modeling.models.nopadding_baichuan import get_alibi_slopes from colossalai.kernel.triton import flash_decoding_attention from colossalai.utils import get_current_device -from tests.test_infer.test_ops.triton.kernel_utils import ( +from tests.test_infer.test_kernels.triton.kernel_utils import ( convert_kv_unpad_to_padded, create_attention_mask, generate_caches_and_block_tables_v2, generate_caches_and_block_tables_v3, torch_attn_ref, ) -from tests.test_infer.test_ops.triton.test_context_attn_unpad import generate_alibi_mask +from tests.test_infer.test_kernels.triton.test_context_attn_unpad import generate_alibi_mask try: import triton # noqa diff --git a/tests/test_infer/test_ops/triton/test_fused_rotary_embedding.py b/tests/test_infer/test_kernels/triton/test_fused_rotary_embedding.py similarity index 100% rename from tests/test_infer/test_ops/triton/test_fused_rotary_embedding.py rename to tests/test_infer/test_kernels/triton/test_fused_rotary_embedding.py diff --git a/tests/test_infer/test_ops/triton/test_kvcache_copy.py b/tests/test_infer/test_kernels/triton/test_kvcache_copy.py similarity index 99% rename from tests/test_infer/test_ops/triton/test_kvcache_copy.py rename to tests/test_infer/test_kernels/triton/test_kvcache_copy.py index 95126c087..336eb256b 100644 --- a/tests/test_infer/test_ops/triton/test_kvcache_copy.py +++ b/tests/test_infer/test_kernels/triton/test_kvcache_copy.py @@ -4,7 +4,7 @@ from packaging import version from colossalai.kernel.triton import copy_k_to_blocked_cache, copy_kv_to_blocked_cache from colossalai.utils import get_current_device -from tests.test_infer.test_ops.triton.kernel_utils import ( +from tests.test_infer.test_kernels.triton.kernel_utils import ( generate_caches_and_block_tables_v2, generate_caches_and_block_tables_v3, mock_alloc_single_token, diff --git a/tests/test_infer/test_ops/triton/test_rmsnorm_triton.py b/tests/test_infer/test_kernels/triton/test_rmsnorm_triton.py similarity index 100% rename from tests/test_infer/test_ops/triton/test_rmsnorm_triton.py rename to tests/test_infer/test_kernels/triton/test_rmsnorm_triton.py diff --git a/tests/test_infer/test_ops/triton/test_rotary_embdding_unpad.py b/tests/test_infer/test_kernels/triton/test_rotary_embdding_unpad.py similarity index 98% rename from tests/test_infer/test_ops/triton/test_rotary_embdding_unpad.py rename to tests/test_infer/test_kernels/triton/test_rotary_embdding_unpad.py index 87eb38135..570093693 100644 --- a/tests/test_infer/test_ops/triton/test_rotary_embdding_unpad.py +++ b/tests/test_infer/test_kernels/triton/test_rotary_embdding_unpad.py @@ -4,7 +4,7 @@ from packaging import version from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb from colossalai.kernel.triton import decoding_fused_rotary_embedding -from tests.test_infer.test_ops.triton.kernel_utils import ( +from tests.test_infer.test_kernels.triton.kernel_utils import ( mock_alloc_block_table_and_kvcache_v2, mock_alloc_block_table_and_kvcache_v3, ) diff --git a/tests/test_infer/test_ops/triton/test_xine_copy.py b/tests/test_infer/test_kernels/triton/test_xine_copy.py similarity index 100% rename from tests/test_infer/test_ops/triton/test_xine_copy.py rename to tests/test_infer/test_kernels/triton/test_xine_copy.py diff --git a/tests/test_infer/test_kvcache_manager.py b/tests/test_infer/test_kvcache_manager.py index 321047706..bca9a1a84 100755 --- a/tests/test_infer/test_kvcache_manager.py +++ b/tests/test_infer/test_kvcache_manager.py @@ -164,7 +164,7 @@ def check_cache_manager(test_config): def run_dist(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host="localhost") + colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost") check_cache_manager() diff --git a/tests/test_infer/test_models/test_baichuan.py b/tests/test_infer/test_models/test_baichuan.py index 5d6be5cb1..3d6fc3bdb 100644 --- a/tests/test_infer/test_models/test_baichuan.py +++ b/tests/test_infer/test_models/test_baichuan.py @@ -14,7 +14,6 @@ from colossalai.inference.core.engine import InferenceEngine from colossalai.inference.modeling.policy import NoPaddingBaichuanModelInferPolicy from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn -# BAICHUAN_MODEL_NAME_OR_PATH = "baichuan-inc/Baichuan2-7B-Base" BAICHUAN_MODEL_NAME_OR_PATH = "baichuan-inc/Baichuan2-13B-Base" @@ -87,7 +86,7 @@ def run_engine(world_size, **kwargs): def run_dist(rank, world_size, port, func_to_run, ret=None, **kwargs): - colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host="localhost") + colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost") if ret: ret[rank] = func_to_run(**kwargs) @@ -99,7 +98,7 @@ def run_dist(rank, world_size, port, func_to_run, ret=None, **kwargs): @parameterize("prompt_template", [None, "baichuan"]) @parameterize("do_sample", [False]) @parameterize("use_cuda_kernel", [True]) -def test_tp_engine(prompt_template, do_sample, use_cuda_kernel): +def check_tp_engine(prompt_template, do_sample, use_cuda_kernel): kwargs1 = { "use_engine": True, "prompt_template": prompt_template, @@ -132,7 +131,7 @@ def test_tp_engine(prompt_template, do_sample, use_cuda_kernel): @pytest.mark.dist @rerun_if_address_is_in_use() def test_inference_engine(): - test_tp_engine() + check_tp_engine() if __name__ == "__main__": diff --git a/tests/test_infer/test_request_handler.py b/tests/test_infer/test_request_handler.py index c7a35ebbe..912fdbf11 100644 --- a/tests/test_infer/test_request_handler.py +++ b/tests/test_infer/test_request_handler.py @@ -90,7 +90,7 @@ def check_request_handler(): def run_dist(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host="localhost") + colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost") check_running_list() check_request_handler()