[shardformer] Fix lm parallel. (#5480)

* fix

* padding vocab_size when using pipeline parallelism

* fix

* fix

* fix gather output

* fix

* fix

* fix

* fix resize embedding

* revert

* revert

* revert

* fix lm forward distribution

* fix

* test ci

* fix
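
The "padding vocab_size" and "resize embedding" bullets above refer to rounding the vocabulary up so the embedding and lm_head split evenly across tensor-parallel ranks. Below is a minimal sketch of that padding rule; pad_vocab_for_tp, resize_embedding, and the make_divisible_by knob are illustrative assumptions, not ColossalAI's exact API.

    import torch
    import torch.nn as nn

    def pad_vocab_for_tp(vocab_size: int, tp_size: int, make_divisible_by: int = 64) -> int:
        # Round vocab_size up to a multiple of (make_divisible_by * tp_size) so
        # each tensor-parallel rank receives an equal, well-aligned shard of rows.
        multiple = make_divisible_by * tp_size
        return ((vocab_size + multiple - 1) // multiple) * multiple

    def resize_embedding(embedding: nn.Embedding, tp_size: int) -> nn.Embedding:
        # Grow the embedding to the padded size; the extra rows are never indexed
        # by real token ids, so they only change the memory layout.
        padded = pad_vocab_for_tp(embedding.num_embeddings, tp_size)
        if padded == embedding.num_embeddings:
            return embedding
        resized = nn.Embedding(padded, embedding.embedding_dim)
        with torch.no_grad():
            resized.weight[: embedding.num_embeddings] = embedding.weight
        return resized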
Author: flybird11111
Date: 2024-03-25 17:21:51 +08:00
Committed by: GitHub
Parent: 34e909256c
Commit: 0688d92e2d
5 changed files with 20 additions and 33 deletions


@@ -250,18 +250,17 @@ class LlamaForCausalLMPolicy(LlamaPolicy):
         policy = super().module_policy()

         setattr(self.shard_config, "causal_lm", True)

         if self.shard_config.enable_tensor_parallelism:
             # add a new item for causal lm
             new_item = {
                 LlamaForCausalLM: ModulePolicyDescription(
                     sub_module_replacement=[
-                        SubModuleReplacementDescription(suffix="lm_head", target_module=Linear1D_Col)
+                        SubModuleReplacementDescription(suffix="lm_head", target_module=Linear1D_Col, kwargs={"gather_output": not self.shard_config.parallel_output})
                     ],
-                    method_replacement={"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)},
                 )
             }
+            if self.shard_config.parallel_output:
+                new_item[LlamaForCausalLM].method_replacement = {"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)}
             policy.update(new_item)

         if self.pipeline_stage_manager:
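
The two halves of this hunk enforce one invariant: gather_output is the negation of parallel_output. With parallel_output=True, lm_head leaves its logits sharded along the vocabulary dimension and the replaced forward computes the loss directly on the shards; with parallel_output=False, Linear1D_Col all-gathers so every rank sees the full logits. Below is a forward-only sketch of cross entropy over vocab-sharded logits, assuming an initialized torch.distributed process group; the function name and signature are illustrative, not those of get_lm_forward_with_dist_cross_entropy.

    import torch
    import torch.distributed as dist

    def cross_entropy_on_vocab_shards(
        logits_shard: torch.Tensor,  # [tokens, shard_vocab], this rank's slice
        labels: torch.Tensor,        # [tokens], global token ids
        vocab_start: int,            # first vocab id owned by this rank
        group=None,
    ) -> torch.Tensor:
        vocab_end = vocab_start + logits_shard.size(-1)

        # Global per-token max for a numerically stable softmax.
        row_max = logits_shard.max(dim=-1, keepdim=True).values
        dist.all_reduce(row_max, op=dist.ReduceOp.MAX, group=group)

        # Global softmax denominator: sum local exps, then all-reduce.
        denom = (logits_shard - row_max).exp().sum(dim=-1, keepdim=True)
        dist.all_reduce(denom, op=dist.ReduceOp.SUM, group=group)

        # Each label lives on exactly one rank; zero elsewhere, sum across ranks.
        owned = (labels >= vocab_start) & (labels < vocab_end)
        local_ids = (labels - vocab_start).clamp(min=0, max=logits_shard.size(-1) - 1)
        target = logits_shard.gather(-1, local_ids.unsqueeze(-1)).squeeze(-1)
        target = torch.where(owned, target, torch.zeros_like(target))
        dist.all_reduce(target, op=dist.ReduceOp.SUM, group=group)

        # loss = log(sum_j exp(logit_j)) - logit_target, in stabilized form.
        return (denom.log().squeeze(-1) + row_max.squeeze(-1) - target).mean()

Computing the loss on shards avoids all-gathering a [batch, seq, vocab]-sized logits tensor, which is why the patch ties the two flags together rather than always gathering.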