[shardformer] llama support DistCrossEntropy (#5176)

* llama support dist-cross: compute the LlamaForCausalLM loss with a
  distributed cross-entropy over the vocab-sharded lm_head output instead of
  gathering the full logits first (a conceptual sketch follows below)

* pass shard_config into the pipeline-stage forward method replacement

* drop gather_output=True from the lm_head Linear1D_Col replacement and install
  get_lm_forward_with_dist_cross_entropy as the LlamaForCausalLM forward

* assorted fixes and CI reruns
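For reference, the idea behind a distributed (vocab-parallel) cross-entropy is
that each tensor-parallel rank keeps only its slice of the vocabulary dimension
of the logits and exchanges softmax statistics via all-reduce instead of
all-gathering the full logits. The sketch below is an illustrative forward-only
reconstruction of that idea, not ColossalAI's DistCrossEntropy API: the function
and parameter names (vocab_parallel_cross_entropy, vocab_start, vocab_end) are
assumptions, and a real implementation wraps the collectives in a custom
autograd.Function to get a correct backward pass.

import torch
import torch.distributed as dist

def vocab_parallel_cross_entropy(logits, targets, vocab_start, vocab_end, group=None):
    # logits:  [N, V_local] -- this rank's slice of the full [N, V] logits
    # targets: [N]          -- global vocabulary indices
    # Global max over the full vocabulary for numerical stability.
    logits_max = logits.max(dim=-1).values
    dist.all_reduce(logits_max, op=dist.ReduceOp.MAX, group=group)
    logits = logits - logits_max.unsqueeze(-1)

    # Softmax denominator: sum of exp over every rank's vocab slice.
    sum_exp = logits.exp().sum(dim=-1)
    dist.all_reduce(sum_exp, op=dist.ReduceOp.SUM, group=group)

    # The target logit lives on exactly one rank; the others contribute zero.
    in_shard = (targets >= vocab_start) & (targets < vocab_end)
    local_idx = (targets - vocab_start).clamp(0, logits.size(-1) - 1)
    target_logits = logits.gather(-1, local_idx.unsqueeze(-1)).squeeze(-1)
    target_logits = torch.where(in_shard, target_logits, torch.zeros_like(target_logits))
    dist.all_reduce(target_logits, op=dist.ReduceOp.SUM, group=group)

    # loss_i = log(sum_j exp(logit_ij)) - logit_{i, target_i}
    return (sum_exp.log() - target_logits).mean()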
* [Colossal-Llama-2] Add finetuning Colossal-Llama-2 example (#4878)

* Add finetuning Colossal-Llama-2 example

* Add finetuning Colossal-Llama-2 example 2

* Add finetuning Colossal-Llama-2 example and support NEFTuning (noisy-embedding finetuning; sketched below)

* Add inference example and refine neftune

* Modify readme file

* update the imports

---------

Co-authored-by: Xu Yuanchen <yuanchen.xu00@gmail.com>
Co-authored-by: Camille Zhong <44392324+Camille7777@users.noreply.github.com>
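The NEFTuning mentioned above is noisy-embedding finetuning: during training
only, uniform noise scaled by alpha / sqrt(seq_len * hidden_dim) is added to the
token embeddings. Below is a hedged sketch assuming a forward-hook based wiring;
the hook name and the neft_alpha default are illustrative, not the example's
actual API.

import math
import torch

def neftune_forward_hook(module, inputs, output, neft_alpha: float = 5.0):
    # Post-forward hook for the input nn.Embedding: perturb embeddings only in training mode.
    if module.training:
        seq_len, hidden_dim = output.shape[-2], output.shape[-1]
        scale = neft_alpha / math.sqrt(seq_len * hidden_dim)
        output = output + torch.empty_like(output).uniform_(-scale, scale)
    return output

# Usage: attach to the model's input embedding layer before finetuning, e.g.
#   handle = model.get_input_embeddings().register_forward_hook(neftune_forward_hook)
#   ... train ...
#   handle.remove()   # disable the noise for inference/evaluation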


---------

Co-authored-by: Yuanchen <70520919+chengeharrison@users.noreply.github.com>
Co-authored-by: Xu Yuanchen <yuanchen.xu00@gmail.com>
Co-authored-by: Camille Zhong <44392324+Camille7777@users.noreply.github.com>
Author: flybird11111
Committed by: GitHub on 2023-12-13 01:39:14 +08:00
Commit: 79718fae04 (parent: cefdc32615)
5 changed files with 143 additions and 13 deletions


@@ -8,7 +8,7 @@ from torch.nn import Module
 from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col, Linear1D_Row, RMSNorm, VocabParallelEmbedding1D
 
-from ..modeling.llama import LlamaPipelineForwards, get_llama_flash_attention_forward
+from ..modeling.llama import LlamaPipelineForwards, get_llama_flash_attention_forward, get_lm_forward_with_dist_cross_entropy
 from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
 
 __all__ = ["LlamaPolicy", "LlamaForCausalLMPolicy", "LlamaForSequenceClassificationPolicy"]
@@ -149,7 +149,7 @@ class LlamaPolicy(Policy):
             layers_per_stage = Policy.distribute_layers(len(module.layers), stage_manager.num_stages)
             stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
-            method_replacement = {"forward": partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)}
+            method_replacement = {"forward": partial(new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config)}
             self.append_or_create_method_replacement(
                 description=method_replacement, policy=policy, target_key=model_cls
             )
@@ -212,9 +212,10 @@ class LlamaForCausalLMPolicy(LlamaPolicy):
             LlamaForCausalLM: ModulePolicyDescription(
                 sub_module_replacement=[
                     SubModuleReplacementDescription(
-                        suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True)
+                        suffix="lm_head", target_module=Linear1D_Col
                     )
-                ]
+                ],
+                method_replacement={"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)}
             )
         }
         policy.update(new_item)
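Conceptually, the two policy changes above work together: with gather_output
left at its default, the Linear1D_Col lm_head keeps its output vocab-sharded,
and the replaced forward feeds those sharded logits straight into the
distributed loss. The sketch below is a reconstruction of that shape, not the
body of get_lm_forward_with_dist_cross_entropy; the real helper closes over
shard_config only, so the explicit dist_loss_fn parameter here is just a
stand-in (e.g. the vocab_parallel_cross_entropy sketch earlier) to keep the
example self-contained.

from typing import Callable

def lm_forward_with_dist_cross_entropy(shard_config, dist_loss_fn: Callable):
    # Returns a LlamaForCausalLM.forward replacement that never gathers the logits.
    def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
        hidden_states = self.model(input_ids, attention_mask=attention_mask, **kwargs)[0]
        logits = self.lm_head(hidden_states)  # stays vocab-sharded: [B, S, V / tp_size]
        loss = None
        if labels is not None:
            # Shift so each position predicts the next token, as in the HF causal-LM forward.
            shift_logits = logits[..., :-1, :].contiguous().view(-1, logits.size(-1))
            shift_labels = labels[..., 1:].contiguous().view(-1)
            loss = dist_loss_fn(shift_logits, shift_labels)
        return {"loss": loss, "logits": logits}
    return forward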