[Feature] Split cross-entropy computation in SP (#5959)

* halfway

* fix cross-PP-stage position id length diff bug

* fix typo

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* unified cross entropy func for all shardformer models

* remove redundant lines

* add basic ring attn; debug cross entropy

* fwd bwd logic complete

* fwd bwd logic complete; add experimental triton rescale

* precision tests passed

* precision tests passed

* fix typos and remove misc files

* update softmax_lse shape by new interface

* change tester name

* remove buffer clone; support packed seq layout

* add varlen tests

* fix typo

* all tests passed

* add dkv_group; fix mask

* remove debug statements

* adapt chatglm, command-R, qwen

* debug

* halfway

* fix cross-PP-stage position id length diff bug

* fix typo

* fix typo

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* unified cross entropy func for all shardformer models

* remove redundant lines

* add basic ring attn; debug cross entropy

* fwd bwd logic complete

* fwd bwd logic complete; add experimental triton rescale

* precision tests passed

* precision tests passed

* fix typos and remove misc files

* add sp_mode to benchmark; fix varlen interface

* update softmax_lse shape by new interface

* add varlen tests

* fix typo

* all tests passed

* add dkv_group; fix mask

* remove debug statements

* add comments

* q1 index only once

* remove events to simplify stream sync

* simplify forward/backward logic

* 2d ring forward passed

* 2d ring backward passed

* fixes

* fix ring attn loss

* 2D ring backward + llama passed

* merge

* update logger

* fix typo

* rebase

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix typo

* remove typos

* fixes

* support GPT

---------

Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Wenxuan Tan
2024-09-10 12:06:50 +08:00
committed by GitHub
parent b3db1058ec
commit 8fd25d6e09
25 changed files with 527 additions and 1173 deletions

View File

@@ -4,7 +4,6 @@ from typing import List, Optional, Tuple
import torch
import torch.utils.checkpoint
from torch.nn import CrossEntropyLoss
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.utils import logging
@@ -13,10 +12,13 @@ from colossalai.shardformer import ShardConfig
from colossalai.shardformer.layer import AttnMaskType, ColoAttention
from colossalai.shardformer.layer._operation import (
all_to_all_comm,
gather_forward_split_backward,
gather_sp_output,
is_share_sp_tp,
split_forward_gather_backward,
)
from ..layer import dist_cross_entropy
def get_flash_core_attention_forward():
from .chatglm2_6b.modeling_chatglm import CoreAttention
@@ -138,6 +140,7 @@ class ChatGLMPipelineForwards:
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
shard_config: ShardConfig = None,
force_sp_output_gather: Optional[bool] = True,
):
logger = logging.get_logger(__name__)
output_hidden_states = (
@@ -180,6 +183,15 @@ class ChatGLMPipelineForwards:
if full_attention_mask is None:
if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1):
full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask)
# Support SP + PP
sp_size = shard_config.sequence_parallel_size
sp_mode = shard_config.sequence_parallelism_mode
sp_group = shard_config.sequence_parallel_process_group
# For generating full positions ids (the states will be gathered along the seq dim before attention fwd).
if sp_mode != "ring_attn" and not stage_manager.is_first_stage():
seq_length *= sp_size
# Rotary positional embeddings
rotary_pos_emb = self.rotary_pos_emb(self.seq_length)
if position_ids is not None:
@@ -200,29 +212,23 @@ class ChatGLMPipelineForwards:
all_hidden_states = () if output_hidden_states else None
start_idx, end_idx = stage_index[0], stage_index[1]
if shard_config and shard_config.enable_sequence_parallelism:
if shard_config.sequence_parallelism_mode == "split_gather":
hidden_states = split_forward_gather_backward(
hidden_states,
dim=0,
process_group=shard_config.tensor_parallel_process_group,
fp8_communication=shard_config.fp8_communication,
)
elif shard_config.sequence_parallelism_mode == "all_to_all":
hidden_states = split_forward_gather_backward(
hidden_states,
dim=0,
process_group=shard_config.sequence_parallel_process_group,
grad_scale=1 / shard_config.sequence_parallel_size,
fp8_communication=shard_config.fp8_communication,
)
elif shard_config.sequence_parallelism_mode == "all_to_all":
hidden_states = split_forward_gather_backward(
hidden_states,
dim=0,
process_group=shard_config.sequence_parallel_process_group,
grad_scale=1 / shard_config.sequence_parallel_size,
)
# Keep the input split across all PP stages
if stage_manager.is_first_stage():
if shard_config.enable_sequence_parallelism:
if sp_mode == "split_gather":
hidden_states = split_forward_gather_backward(
hidden_states,
dim=0,
process_group=sp_group,
)
elif shard_config.sequence_parallelism_mode == "all_to_all":
hidden_states = split_forward_gather_backward(
hidden_states,
dim=0,
process_group=shard_config.sequence_parallel_process_group,
grad_scale=1 / shard_config.sequence_parallel_size,
)
for idx in range(start_idx, end_idx):
layer = self.encoder._get_layer(idx)
if output_hidden_states:
@@ -248,35 +254,19 @@ class ChatGLMPipelineForwards:
if use_cache:
presents = presents + (kv_cache,)
if shard_config and shard_config.enable_sequence_parallelism:
if shard_config.sequence_parallelism_mode == "split_gather":
hidden_states = gather_forward_split_backward(
hidden_states,
dim=0,
process_group=shard_config.tensor_parallel_process_group,
fp8_communication=shard_config.fp8_communication,
)
elif shard_config.sequence_parallelism_mode == "all_to_all":
hidden_states = gather_forward_split_backward(
hidden_states,
dim=0,
process_group=shard_config.sequence_parallel_process_group,
grad_scale=shard_config.sequence_parallel_size,
fp8_communication=shard_config.fp8_communication,
)
elif shard_config.sequence_parallelism_mode == "all_to_all":
hidden_states = gather_forward_split_backward(
hidden_states,
dim=0,
process_group=shard_config.sequence_parallel_process_group,
grad_scale=shard_config.sequence_parallel_size,
)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if stage_manager.is_last_stage():
# final layer_norm
if self.encoder.post_layer_norm:
hidden_states = self.encoder.final_layernorm(hidden_states)
# Gather seq-wise in the final output stage
if shard_config.enable_sequence_parallelism:
sp_mode = shard_config.sequence_parallelism_mode
if (not shard_config.parallel_output) or force_sp_output_gather or is_share_sp_tp(sp_mode):
hidden_states = gather_sp_output(hidden_states, shard_config, sp_dim=0)
if not return_dict:
return tuple(
v
@@ -333,6 +323,7 @@ class ChatGLMPipelineForwards:
hidden_states=hidden_states,
stage_index=stage_index,
shard_config=shard_config,
force_sp_output_gather=False,
)
if stage_manager.is_last_stage():
hidden_states = transformer_outputs[0]
@@ -340,17 +331,21 @@ class ChatGLMPipelineForwards:
hidden_states = hidden_states[-1:]
lm_logits = self.transformer.output_layer(hidden_states)
lm_logits = lm_logits.transpose(0, 1).contiguous()
loss = None
if labels is not None:
lm_logits = lm_logits.to(torch.float32)
# Shift so that tokens < n predict n
shift_logits = lm_logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss(ignore_index=-100)
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
lm_logits = lm_logits.to(hidden_states.dtype)
loss = loss.to(hidden_states.dtype)
# ChatGLM doesn't have lm_head split
enable_tp = shard_config.enable_tensor_parallelism
shard_config.enable_tensor_parallelism = False
loss = dist_cross_entropy(
labels,
lm_logits,
shard_config,
self.transformer.output_layer.out_features,
lm_logits.dtype,
)
shard_config.enable_tensor_parallelism = enable_tp
if not return_dict:
output = (lm_logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output
@@ -379,6 +374,7 @@ def get_chatglm_sequence_parallel_forward_fn(shard_config: ShardConfig, sp_mode,
use_cache: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
force_sp_output_gather: Optional[bool] = True,
):
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -456,22 +452,9 @@ def get_chatglm_sequence_parallel_forward_fn(shard_config: ShardConfig, sp_mode,
use_cache=use_cache,
output_hidden_states=output_hidden_states,
)
if sp_mode in ["split_gather"]:
hidden_states = gather_forward_split_backward(
hidden_states,
dim=0,
process_group=shard_config.tensor_parallel_process_group,
fp8_communication=shard_config.fp8_communication,
)
elif sp_mode == "all_to_all":
hidden_states = gather_forward_split_backward(
hidden_states,
dim=0,
process_group=sp_group,
grad_scale=sp_size,
fp8_communication=shard_config.fp8_communication,
)
if shard_config.enable_sequence_parallelism:
if (not shard_config.parallel_output) or force_sp_output_gather or is_share_sp_tp(sp_mode):
hidden_states = gather_sp_output(hidden_states, shard_config, sp_dim=0)
if not return_dict:
return tuple(