mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-01 17:17:05 +00:00
[shardformer]Fix lm parallel. (#5480)
* fix * padding vocab_size when using pipeline parallelism padding vocab_size when using pipeline parallelism fix fix * fix * fix fix fix * fix gather output * fix * fix * fix fix resize embedding fix resize embedding * fix resize embedding fix * revert * revert * revert * fix lm forward distribution * fix * test ci * fix
This commit is contained in:
@@ -331,7 +331,7 @@ class GPT2PipelineForwards:
|
||||
loss_fct = CrossEntropyLoss()
|
||||
shift_logits = shift_logits.view(-1, shift_logits.size(-1))
|
||||
shift_labels = shift_labels.view(-1)
|
||||
if shard_config.enable_tensor_parallelism:
|
||||
if shard_config.enable_tensor_parallelism and shard_config.parallel_output:
|
||||
loss = cross_entropy_1d(
|
||||
shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group
|
||||
)
|
||||
@@ -1078,15 +1078,12 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
|
||||
shift_logits = lm_logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
# Flatten the tokens
|
||||
loss_fct = CrossEntropyLoss()
|
||||
shift_logits = shift_logits.view(-1, shift_logits.size(-1))
|
||||
shift_labels = shift_labels.view(-1)
|
||||
if shard_config.enable_tensor_parallelism:
|
||||
loss = cross_entropy_1d(
|
||||
shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group
|
||||
)
|
||||
else:
|
||||
loss = loss_fct(shift_logits, shift_labels)
|
||||
loss = cross_entropy_1d(
|
||||
shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group
|
||||
)
|
||||
|
||||
|
||||
if not shard_config.parallel_output:
|
||||
lm_logits = gather_forward_split_backward(lm_logits, -1, shard_config.tensor_parallel_process_group)
|
||||
|
Reference in New Issue
Block a user