[shardformer] llama support DistCrossEntropy (#5176)
* fix aaa fix fix fix
* fix
* fix
* test ci
* fix ci fix
* llama support dist-cross fix fix fix fix fix fix fix fix
* fix
* fix
* fix fix
* test ci
* test ci
* fix
* [Colossal-Llama-2] Add finetuning Colossal-Llama-2 example (#4878)
* Add finetuning Colossal-Llama-2 example
* Add finetuning Colossal-Llama-2 example 2
* Add finetuning Colossal-Llama-2 example and support NEFTuning
* Add inference example and refine neftune
* Modify readme file
* update the imports

---------

Co-authored-by: Xu Yuanchen <yuanchen.xu00@gmail.com>
Co-authored-by: Camille Zhong <44392324+Camille7777@users.noreply.github.com>

* llama support dist-cross fix fix fix fix fix fix fix fix
* fix
* fix
* fix fix
* test ci
* test ci
* fix
* fix ci
* fix ci

---------

Co-authored-by: Yuanchen <70520919+chengeharrison@users.noreply.github.com>
Co-authored-by: Xu Yuanchen <yuanchen.xu00@gmail.com>
Co-authored-by: Camille Zhong <44392324+Camille7777@users.noreply.github.com>
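The change wires a distributed (vocab-parallel) cross entropy into the Llama policy: the lm_head output is left sharded along the vocabulary dimension, and the causal-LM forward is swapped for one that computes the loss on those shards. For orientation, here is a minimal sketch of how such a loss can be computed without gathering the full vocabulary on any rank. It is illustrative only and is not the ColossalAI DistCrossEntropy implementation: the name dist_cross_entropy and its arguments are invented for the example, it assumes torch.distributed is already initialized with the vocabulary split evenly across the ranks of group, and it omits ignore_index handling and the custom backward a real implementation needs.

import torch
import torch.distributed as dist


def dist_cross_entropy(logits: torch.Tensor, labels: torch.Tensor, vocab_start: int, group=None) -> torch.Tensor:
    # logits: (tokens, local_vocab) shard held by this rank; labels: (tokens,) global token ids
    vocab_end = vocab_start + logits.size(-1)

    # 1) subtract the global max over all shards for numerical stability
    logits_max = logits.max(dim=-1).values
    dist.all_reduce(logits_max, op=dist.ReduceOp.MAX, group=group)
    logits = logits - logits_max.unsqueeze(-1)

    # 2) global softmax denominator, summed across all vocabulary shards
    sum_exp = logits.exp().sum(dim=-1)
    dist.all_reduce(sum_exp, op=dist.ReduceOp.SUM, group=group)

    # 3) the target logit lives on exactly one rank per token; other ranks contribute zero
    in_shard = (labels >= vocab_start) & (labels < vocab_end)
    local_labels = (labels - vocab_start).clamp(0, logits.size(-1) - 1)
    target_logit = logits.gather(-1, local_labels.unsqueeze(-1)).squeeze(-1)
    target_logit = torch.where(in_shard, target_logit, torch.zeros_like(target_logit))
    dist.all_reduce(target_logit, op=dist.ReduceOp.SUM, group=group)

    # 4) per-token loss = log(sum of exps) - target logit, averaged over tokens
    return (sum_exp.log() - target_logit).mean()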
@@ -8,7 +8,7 @@ from torch.nn import Module
 from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col, Linear1D_Row, RMSNorm, VocabParallelEmbedding1D
 
-from ..modeling.llama import LlamaPipelineForwards, get_llama_flash_attention_forward
+from ..modeling.llama import LlamaPipelineForwards, get_llama_flash_attention_forward, get_lm_forward_with_dist_cross_entropy
 from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
 
 __all__ = ["LlamaPolicy", "LlamaForCausalLMPolicy", "LlamaForSequenceClassificationPolicy"]
@@ -149,7 +149,7 @@ class LlamaPolicy(Policy):
|
||||
|
||||
layers_per_stage = Policy.distribute_layers(len(module.layers), stage_manager.num_stages)
|
||||
stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
|
||||
method_replacement = {"forward": partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)}
|
||||
method_replacement = {"forward": partial(new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config)}
|
||||
self.append_or_create_method_replacement(
|
||||
description=method_replacement, policy=policy, target_key=model_cls
|
||||
)
|
||||
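This hunk only threads shard_config into the pipeline forward. The pattern, simplified with made-up names (TinyCausalLM, new_forward, the dict used as a stand-in shard_config): functools.partial binds the extra keyword arguments once, and the bound callable is installed as the module's forward. In the real policy that installation is handled by append_or_create_method_replacement during sharding rather than by hand.

from functools import partial
from types import MethodType


class TinyCausalLM:  # stand-in module, not a real ColossalAI class
    def forward(self, x):
        return x


def new_forward(self, x, stage_manager=None, stage_index=None, shard_config=None):
    # the replacement forward now has direct access to the shard configuration
    print("stage index:", stage_index, "tp size:", shard_config["tp_size"])
    return x


model = TinyCausalLM()
# bind the extra arguments once, mirroring the method_replacement dict above
bound = partial(new_forward, stage_manager=object(), stage_index=[0, 2], shard_config={"tp_size": 2})
model.forward = MethodType(bound, model)  # install the bound callable on the instance
model.forward("hello")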
@@ -212,9 +212,10 @@ class LlamaForCausalLMPolicy(LlamaPolicy):
|
||||
LlamaForCausalLM: ModulePolicyDescription(
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True)
|
||||
suffix="lm_head", target_module=Linear1D_Col
|
||||
)
|
||||
]
|
||||
],
|
||||
method_replacement={"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)}
|
||||
)
|
||||
}
|
||||
policy.update(new_item)
|
||||
|
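Dropping gather_output=True means Linear1D_Col leaves the lm_head logits split across the tensor-parallel ranks, and get_lm_forward_with_dist_cross_entropy(self.shard_config) installs a forward that consumes them directly. The sketch below shows only the rough shape such a forward factory might take; it is not the code in ..modeling.llama. It reuses the dist_cross_entropy sketch from earlier on this page, assumes equal vocabulary shards, and tensor_parallel_process_group is the one attribute assumed to come from the real ShardConfig.

import torch.distributed as dist


def lm_forward_with_dist_cross_entropy_sketch(shard_config):
    pg = shard_config.tensor_parallel_process_group  # tensor-parallel process group

    def forward(self, input_ids, labels=None, **kwargs):
        hidden_states = self.model(input_ids, **kwargs)[0]
        # lm_head is Linear1D_Col without gather_output=True, so the logits stay
        # split along the vocabulary dimension on each tensor-parallel rank
        logits = self.lm_head(hidden_states)
        loss = None
        if labels is not None:
            shift_logits = logits[..., :-1, :].contiguous().view(-1, logits.size(-1))
            shift_labels = labels[..., 1:].contiguous().view(-1)
            vocab_start = dist.get_rank(pg) * shift_logits.size(-1)  # assumes equal shards
            # dist_cross_entropy is the illustrative sketch shown earlier
            loss = dist_cross_entropy(shift_logits, shift_labels, vocab_start, group=pg)
        return {"loss": loss, "logits": logits}

    return forward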