Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-08 04:24:47 +00:00)
[shardformer] update transformers (#5583)
* flash_attention forward upgrade
* llama_model_forward
* remove useless comment
* update the requirements.txt
* add the transformers version requirements
* remove the LATEST VERSION try
* [shardformer] update bloom model (#5518): update bloom model; remove the version restriction
* [shardformer] update_falcon (#5520)
* [shardformer] update mistral model (#5511)
* [shardformer] update gpt2 (#5502)
* [shardformer] update gptj model (#5503)
* [shardformer] update opt (#5522)
* [shardformer] update t5 model (#5524)
* [shardformer] update whisper model (#5529)
* [shardformer] update vit model (#5530): remove the output_hidden_states
* [shardformer] fix llama modeling
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* [zero] support multiple (partial) backward passes (#5596)
* [misc] update requirements
* fix conflicts
* [doc] fix ColossalMoE readme (#5599)
* merge with main
* [hotfix] fix examples' no-pad-token bug & auto parallel codegen bug (#5606): fixed some auto parallel codegen bugs, but they might not run on torch 2.1
* [shardformer] fix pipeline grad ckpt (#5620)
* [shardformer] fix whisper (#5628)
* [test] fix llama model test
* fix the opt upgrade (#5634)
* [shardformer] fix attn replacement (#5636)
* [shardformer] update flashattention replacement (#5637): update transformers
* [test] fix llama test (#5638)
* [gemini] fix buffer cast (#5639)
* Fix shardformer upgrade (#5640): fix llama model; fix the mistral; fix the shardformer model
* [shardformer] support pipeline parallelism for mistral (#5642)
* [Feature] Support LLaMA-3 CPT and ST (#5619): support LLaMA-3; run pre-commit
* [example] update llama example (#5626): [plugin] support dp inside for hybrid parallel; update llama benchmark; update llama readme
* [example] llama3 (#5631): release llama3
* support pp for mistral
* fix

---------

Co-authored-by: Hongxin Liu <lhx0217@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Camille Zhong <44392324+Camille7777@users.noreply.github.com>
Co-authored-by: Edenzzzz <wenxuan.tan@wisc.edu>
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: flybird11111 <1829166702@qq.com>
Co-authored-by: Tong Li <tong.li352711588@gmail.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>

The excerpt below shows the resulting changes to the OPT modeling code.
```diff
@@ -3,6 +3,7 @@ from typing import List, Optional, Tuple, Union
 import torch
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
 from transformers.modeling_outputs import (
     BaseModelOutputWithPast,
     CausalLMOutputWithPast,
```
```diff
@@ -42,7 +43,7 @@ def _get_attention_mask(
             is_causal=True,
         )
     else:
-        attention_mask = self.decoder._prepare_decoder_attention_mask(
+        attention_mask = _prepare_4d_causal_attention_mask(
             attention_mask,
             (batch_size, seq_length),
             hidden_states,
```
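The API change in the first two hunks: the per-model private helper `self.decoder._prepare_decoder_attention_mask`, removed by newer transformers releases, is replaced with the shared `_prepare_4d_causal_attention_mask` from `transformers.modeling_attn_mask_utils`. A minimal sketch of what that helper does, assuming a transformers version that ships `modeling_attn_mask_utils` (roughly 4.35+); the tensor values here are illustrative only:

```python
# Minimal sketch: turn a 2-D padding mask into the 4-D additive causal mask
# ([bsz, 1, tgt_len, src_len]) that the decoder layers consume. Note the
# leading underscore: this is a private transformers helper and may change.
import torch
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask

batch_size, seq_length, hidden_size = 2, 5, 8
inputs_embeds = torch.randn(batch_size, seq_length, hidden_size)
padding_mask = torch.tensor([[1, 1, 1, 1, 0],
                             [1, 1, 1, 0, 0]])  # 1 = attend, 0 = padded

causal_mask = _prepare_4d_causal_attention_mask(
    padding_mask,              # [bsz, seq_len] padding mask (or None)
    (batch_size, seq_length),  # input shape
    inputs_embeds,             # only dtype/device are taken from this
    0,                         # past_key_values_length
)
print(causal_mask.shape)  # torch.Size([2, 1, 5, 5]); blocked slots hold the dtype minimum
```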
```diff
@@ -57,6 +58,20 @@ class OPTPipelineForwards:
     under pipeline setting.
     """
 
+    @staticmethod
+    def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
+        """
+        Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
+        """
+        bsz, src_len = mask.size()
+        tgt_len = tgt_len if tgt_len is not None else src_len
+
+        expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
+
+        inverted_mask = 1.0 - expanded_mask
+
+        return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
+
     @staticmethod
     def opt_model_forward(
         self: OPTModel,
```
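The `_expand_mask` staticmethod added above mirrors the module-level helper that older transformers releases exposed and newer ones removed; keeping a copy on `OPTPipelineForwards` makes the pipeline forward independent of that private API. A small worked example of what it returns (hypothetical inputs; the import path is an assumption about where this class lives):

```python
# Worked example (hypothetical inputs): the result is an additive attention
# bias, 0.0 where attention is allowed and dtype-min where it is masked out.
import torch
from colossalai.shardformer.modeling.opt import OPTPipelineForwards  # assumed import path

mask = torch.tensor([[1, 1, 0]])  # [bsz=1, seq_len=3]; last position is padding
bias = OPTPipelineForwards._expand_mask(mask, torch.float32)

print(bias.shape)     # torch.Size([1, 1, 3, 3])
print(bias[0, 0, 0])  # tensor([0.0000e+00, 0.0000e+00, -3.4028e+38])
```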
```diff
@@ -112,7 +127,7 @@ class OPTPipelineForwards:
                 inputs_embeds = decoder.project_in(inputs_embeds)
             device = input_ids.device if input_ids is not None else inputs_embeds.device
             inputs_embeds.dtype
 
             hidden_states = inputs_embeds
         else:
             if hidden_states is None:
                 raise ValueError("hidden_states shouldn't be None for intermediate stages.")
```
```diff
@@ -125,12 +140,25 @@
         # required mask seq length can be calculated via length of past
         mask_seq_length = past_key_values_length + seq_length
         # embed positions
-        if attention_mask is None:
-            attention_mask = torch.ones(batch_size, mask_seq_length, device=device)
-        elif attention_mask.shape[1] != mask_seq_length:
-            raise ValueError(
-                f"The provided attention mask has length {attention_mask.shape[1]}, but its length should be "
-                f"{mask_seq_length} (sum of the lengths of current and past inputs)"
-            )
+        if self.decoder._use_flash_attention_2:
+            # 2d mask is passed through the layers
+            causal_attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
+            attention_mask = (
+                torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
+                if attention_mask is None
+                else attention_mask
+            )
+        else:
+            # 4d mask is passed through the layers
+            if attention_mask is None:
+                attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
+            elif attention_mask.shape[1] != mask_seq_length:
+                raise ValueError(
+                    f"The provided attention mask has length {attention_mask.shape[1]}, but its length should be "
+                    f"{mask_seq_length} (sum of the lengths of current and past inputs)"
+                )
+            causal_attention_mask = _prepare_4d_causal_attention_mask(
+                attention_mask, input_shape, hidden_states, past_key_values_length
+            )
 
         if stage_manager.is_first_stage():
```
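This hunk is the behavioral core of the flash-attention update: with FlashAttention-2 enabled, the decoder layers receive the raw 2-D padding mask (or None when nothing is padded, since the kernel handles causality internally), while the eager path still builds the 4-D causal mask. A sketch of the 2-D routing rule, with an illustrative helper name:

```python
# Sketch of the mask routing above (helper name is illustrative): FA2 applies
# causality on its own, so only a real padding pattern needs forwarding; an
# all-ones mask is dropped to None so the kernel can skip masking entirely.
from typing import Optional
import torch

def route_mask_for_fa2(attention_mask: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
    # `0 in attention_mask` is True iff at least one position is padded
    return attention_mask if (attention_mask is not None and 0 in attention_mask) else None

print(route_mask_for_fa2(torch.ones(2, 5)))           # None -> no masking needed
print(route_mask_for_fa2(torch.tensor([[1, 1, 0]])))  # tensor([[1, 1, 0]]) -> padding kept
```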
```diff
@@ -205,20 +233,14 @@
             past_key_value = past_key_values[idx] if past_key_values is not None else None
 
             if decoder.gradient_checkpointing and decoder.training:
-
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        # None for past_key_value
-                        return module(*inputs, output_attentions, None)
-
-                    return custom_forward
-
-                layer_outputs = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(decoder_layer),
+                layer_outputs = self._gradient_checkpointing_func(
+                    decoder_layer.__call__,
                     hidden_states,
                     causal_attention_mask,
                     head_mask[idx] if head_mask is not None else None,
                     None,
+                    output_attentions,
+                    use_cache,
                 )
             else:
                 layer_outputs = decoder_layer(
```
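The last hunk drops the hand-rolled `create_custom_forward` closure in favor of `self._gradient_checkpointing_func`, which recent transformers versions install on the module when gradient checkpointing is enabled. A minimal sketch of the pattern with a toy layer; wiring the function up via `functools.partial` is an assumption about what `gradient_checkpointing_enable` does in spirit, not ColossalAI's or transformers' exact code:

```python
# Toy illustration: checkpoint a layer call directly instead of wrapping it in
# a closure. Activations are discarded in forward and recomputed in backward.
import functools
import torch
import torch.utils.checkpoint

class TinyLayer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(4, 4)

    def forward(self, hidden_states, attention_mask=None):
        return torch.relu(self.proj(hidden_states))

layer = TinyLayer()
# Roughly what gradient_checkpointing_enable() sets up (assumed):
_gradient_checkpointing_func = functools.partial(
    torch.utils.checkpoint.checkpoint, use_reentrant=False
)

x = torch.randn(2, 4, requires_grad=True)
out = _gradient_checkpointing_func(layer.__call__, x, None)
out.sum().backward()  # activations recomputed here
print(x.grad.shape)   # torch.Size([2, 4])
```

Passing `decoder_layer.__call__` (rather than the module itself) keeps hooks and the usual `nn.Module` call path intact while giving the checkpoint wrapper a plain callable.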