Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-01 17:17:05 +00:00)
[Shardformer] Add parallel output for shardformer models (bloom, falcon) (#5702)
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* add parallel cross-entropy output for the Falcon model & fix some typos in bloom.py
* fix module name error: self.model -> self.transformers in the Bloom and Falcon models
* fix the overflow bug in the distributed cross-entropy loss function when training with fp16
* add dtype to the parallel cross-entropy loss function
* fix dtype-related typos and prettify loss.py
* fix grad dtype and update the dtype-mismatch error
* fix typo bugs
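The message describes the fix concretely: the tensor-parallel (vocab-sharded) cross entropy now accumulates its softmax statistics in fp32 so fp16 logits cannot overflow, and it takes a dtype argument recording the caller's activation precision. Below is a minimal, forward-only sketch of that idea, assuming an initialized torch.distributed process group; the function name and signature are illustrative, not ColossalAI's actual API (the real implementation in loss.py is a custom torch.autograd.Function, so its backward can also hand gradients back in the activation dtype, per the "fix grad dtype" item above).

import torch
import torch.distributed as dist


@torch.no_grad()  # sketch of the forward math only; see the note on backward above
def vocab_parallel_cross_entropy(
    logits: torch.Tensor,   # (tokens, vocab_size // world_size): this rank's shard
    targets: torch.Tensor,  # (tokens,): ids in the *global* vocabulary
    process_group: dist.ProcessGroup,
    vocab_size: int,
    dtype: torch.dtype,     # caller's activation dtype; a full implementation
                            # would cast the grad w.r.t. logits back to it
) -> torch.Tensor:
    rank = dist.get_rank(process_group)
    shard = vocab_size // dist.get_world_size(process_group)
    lo = rank * shard

    # The fp16 fix: exp/sum can exceed fp16's maximum of 65504, so the softmax
    # statistics are always accumulated in float32, whatever `dtype` is.
    logits = logits.float()

    # Subtract the global max for numerical stability.
    local_max = logits.max(dim=-1, keepdim=True).values
    dist.all_reduce(local_max, op=dist.ReduceOp.MAX, group=process_group)
    logits -= local_max

    # Denominator of the softmax: exp-sum over the full, sharded vocabulary.
    sum_exp = logits.exp().sum(dim=-1)
    dist.all_reduce(sum_exp, group=process_group)

    # Numerator: each target logit lives on exactly one rank's shard.
    in_shard = (targets >= lo) & (targets < lo + shard)
    local_idx = (targets - lo).clamp(0, shard - 1)
    picked = logits.gather(-1, local_idx.unsqueeze(-1)).squeeze(-1)
    picked = torch.where(in_shard, picked, torch.zeros_like(picked))
    dist.all_reduce(picked, group=process_group)

    # loss = logsumexp(logits) - logit[target], averaged over tokens.
    return (sum_exp.log() - picked).mean()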
@@ -389,6 +389,7 @@ class GPT2PipelineForwards:
                     shift_labels,
                     process_group=shard_config.tensor_parallel_process_group,
                     vocab_size=self.lm_head.out_features,
+                    dtype=self.transformer.dtype,
                 )
             else:
                 loss = loss_fct(shift_logits, shift_labels)
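For context on the overflow the message refers to: fp16's largest finite value is 65504, so exponentiating even moderately large logits yields inf unless the computation is upcast first. A stand-alone illustration in plain PyTorch (no parallelism involved):

import torch

x = torch.tensor([20.0], dtype=torch.float16)
print(torch.exp(x))          # tensor([inf], dtype=torch.float16): exp(20) ≈ 4.85e8 > 65504
print(torch.exp(x.float()))  # tensor([4.8517e+08]): finite once upcast to float32

This is why the loss sketch above always computes exp/sum in float32, regardless of the dtype it is handed.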
@@ -1294,6 +1295,7 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
                 shift_labels,
                 process_group=shard_config.tensor_parallel_process_group,
                 vocab_size=self.lm_head.out_features,
+                dtype=self.transformer.dtype,
             )

         if not return_dict:
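Both hunks make the same one-line change: the pipeline forwards now thread the model's dtype into the distributed loss call. Mirroring the diff against the sketch above, the call shape would be (hypothetical wiring, not the repository's exact code):

loss = vocab_parallel_cross_entropy(
    shift_logits,
    shift_labels,
    process_group=shard_config.tensor_parallel_process_group,
    vocab_size=self.lm_head.out_features,
    dtype=self.transformer.dtype,  # e.g. torch.float16 under mixed-precision training
)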