[Shardformer] Add parallel output for shardformer models (bloom, falcon) (#5702)

* [pre-commit.ci] auto fixes from pre-commit.com hooks

* add parallel cross entropy output for falcon model & fix some typos in bloom.py

* fix module name error (self.model -> self.transformers) in the bloom and falcon models

* Fix the overflow bug in the distributed cross entropy loss function when training with fp16 (see the sketch below)

* add dtype to parallel cross entropy loss function

* fix dtype-related typos and prettify loss.py

* fix grad dtype and update dtype mismatch error

* fix typo bugs
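
The core of the fp16 fix described above is to upcast the vocab-sharded logits to fp32 before the exp/sum reductions, then cast the result back to the dtype the caller passes in; with BLOOM's roughly 250k-token vocabulary, the softmax denominator alone can exceed the fp16 maximum of 65504. The forward-only sketch below illustrates the idea. The names tp_cross_entropy and vocab_start and the process-group handling are illustrative placeholders, not the helper in the repository's loss.py, which (per the "fix grad dtype" bullet) also has to keep the backward pass in the right dtype.

# Forward-only sketch of a tensor-parallel cross entropy with fp32 upcast.
# Illustrative only: tp_cross_entropy / vocab_start are placeholder names,
# not the identifiers used in the repository's loss.py.
from typing import Optional

import torch
import torch.distributed as dist


def tp_cross_entropy(
    vocab_parallel_logits: torch.Tensor,   # (tokens, local_vocab) shard of the logits
    target: torch.Tensor,                  # (tokens,) global token ids
    process_group=None,                    # tensor-parallel group to reduce over
    vocab_start: int = 0,                  # first global vocab id owned by this rank
    dtype: Optional[torch.dtype] = None,   # dtype to return the loss in (model dtype)
) -> torch.Tensor:
    input_dtype = vocab_parallel_logits.dtype
    # Upcast before exp/sum: in fp16 the denominator can overflow 65504
    # (e.g. ~250k near-uniform vocab entries at initialization).
    logits = vocab_parallel_logits.float()

    # Global max over the full vocabulary for numerical stability.
    logits_max = logits.max(dim=-1, keepdim=True).values
    if process_group is not None and dist.is_initialized():
        dist.all_reduce(logits_max, op=dist.ReduceOp.MAX, group=process_group)
    logits = logits - logits_max

    # Softmax denominator summed over every rank's vocab shard.
    sum_exp = logits.exp().sum(dim=-1, keepdim=True)
    if process_group is not None and dist.is_initialized():
        dist.all_reduce(sum_exp, op=dist.ReduceOp.SUM, group=process_group)

    # Target logit: zero on ranks that do not own the target id, then summed
    # across ranks so every rank ends up with the true value.
    local_vocab = logits.size(-1)
    in_shard = (target >= vocab_start) & (target < vocab_start + local_vocab)
    local_target = (target - vocab_start).clamp(0, local_vocab - 1)
    target_logits = logits.gather(-1, local_target.unsqueeze(-1)).squeeze(-1)
    target_logits = torch.where(in_shard, target_logits, torch.zeros_like(target_logits))
    if process_group is not None and dist.is_initialized():
        dist.all_reduce(target_logits, op=dist.ReduceOp.SUM, group=process_group)

    # loss = log(sum(exp(logits))) - logit[target], averaged over tokens
    # (no ignore_index handling in this sketch).
    loss = (sum_exp.squeeze(-1).log() - target_logits).mean()
    # Hand the loss back in the requested dtype, matching the new dtype argument.
    return loss.to(dtype if dtype is not None else input_dtype)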
Author: Haze188
Date: 2024-05-21 11:07:13 +08:00
Committed by: GitHub
Parent: 9d83c6d715
Commit: 22ce873c3f
9 changed files with 230 additions and 17 deletions

@@ -389,6 +389,7 @@ class GPT2PipelineForwards:
shift_labels,
process_group=shard_config.tensor_parallel_process_group,
vocab_size=self.lm_head.out_features,
+dtype=self.transformer.dtype,
)
else:
loss = loss_fct(shift_logits, shift_labels)
@@ -1294,6 +1295,7 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
shift_labels,
process_group=shard_config.tensor_parallel_process_group,
vocab_size=self.lm_head.out_features,
+dtype=self.transformer.dtype,
)
if not return_dict:
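
As a follow-up, here is a runnable single-process usage sketch of tp_cross_entropy from above, mirroring the dtype=self.transformer.dtype argument added in the hunks: with BLOOM-sized, near-uniform fp16 logits the softmax denominator alone overflows fp16, while the upcasted path returns a finite loss in the requested dtype. Shapes and values are made up for illustration.

# Single-process usage sketch of tp_cross_entropy above (one "shard" holding the
# whole vocabulary), mirroring the dtype argument added in the hunks. Shapes and
# values are illustrative only.
import torch

vocab_size = 250_880                                        # BLOOM-sized vocabulary
logits = torch.zeros(8, vocab_size, dtype=torch.float16)    # near-uniform logits, as at init
labels = torch.randint(0, vocab_size, (8,))

# The softmax denominator is ~vocab_size, which fp16 cannot represent.
denom = torch.exp(logits.float() - logits.float().max()).sum(dim=-1)
print(denom.half())            # tensor([inf, ..., inf], dtype=torch.float16)

# The upcasted loss stays finite and comes back in the requested (model) dtype.
loss = tp_cross_entropy(logits, labels, dtype=torch.float16)
print(loss, loss.dtype)        # ~12.43 (= log(250880)), torch.float16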