[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit

* [misc] run pre-commit

* [misc] remove useless configuration files

* [misc] ignore cuda for clang-format
This commit is contained in:
Hongxin Liu
2023-09-19 14:20:26 +08:00
committed by GitHub
parent 3c6b831c26
commit 079bf3cb26
1268 changed files with 50037 additions and 38444 deletions

View File

@@ -30,9 +30,9 @@ logger = logging.get_logger(__name__)
def build_bloom_alibi_tensor_fn(process_group: ProcessGroup) -> torch.Tensor:
def build_bloom_alibi_tensor(self, attention_mask: torch.Tensor, num_heads: int,
dtype: torch.dtype) -> torch.Tensor:
def build_bloom_alibi_tensor(
self, attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype
) -> torch.Tensor:
"""
Link to paper: https://arxiv.org/abs/2108.12409 Alibi tensor is not causal as the original paper mentions, it
relies on a translation invariance of softmax for quick implementation: with l being a tensor, and a fixed value
@@ -56,23 +56,23 @@ def build_bloom_alibi_tensor_fn(process_group: ProcessGroup) -> torch.Tensor:
num_heads = num_heads * world_size
batch_size, seq_length = attention_mask.shape
closest_power_of_2 = 2**math.floor(math.log2(num_heads))
base = torch.tensor(2**(-(2**-(math.log2(closest_power_of_2) - 3))),
device=attention_mask.device,
dtype=torch.float32)
closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
base = torch.tensor(
2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), device=attention_mask.device, dtype=torch.float32
)
powers = torch.arange(1, 1 + closest_power_of_2, device=attention_mask.device, dtype=torch.int32)
slopes = torch.pow(base, powers)
if closest_power_of_2 != num_heads:
extra_base = torch.tensor(2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))),
device=attention_mask.device,
dtype=torch.float32)
extra_base = torch.tensor(
2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))),
device=attention_mask.device,
dtype=torch.float32,
)
num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2)
extra_powers = torch.arange(1,
1 + 2 * num_remaining_heads,
2,
device=attention_mask.device,
dtype=torch.int32)
extra_powers = torch.arange(
1, 1 + 2 * num_remaining_heads, 2, device=attention_mask.device, dtype=torch.int32
)
slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
# Note: alibi will added to the attention bias that will be applied to the query, key product of attention
@@ -87,7 +87,7 @@ def build_bloom_alibi_tensor_fn(process_group: ProcessGroup) -> torch.Tensor:
num_heads_per_rank = int(num_heads / dist.get_world_size(process_group))
offset = dist.get_rank(process_group) * num_heads_per_rank
alibi = alibi.view(batch_size, num_heads, 1, seq_length)
alibi = alibi[:, offset:num_heads_per_rank + offset, :, :]
alibi = alibi[:, offset : num_heads_per_rank + offset, :, :]
return alibi.reshape(batch_size * num_heads_per_rank, 1, seq_length).to(dtype)
else:
return alibi.reshape(batch_size * num_heads, 1, seq_length).to(dtype)
@@ -96,9 +96,9 @@ def build_bloom_alibi_tensor_fn(process_group: ProcessGroup) -> torch.Tensor:
class BloomPipelineForwards:
'''
"""
This class serves as a micro library for bloom pipeline forwards.
'''
"""
@staticmethod
def bloom_model_forward(
@@ -117,8 +117,7 @@ class BloomPipelineForwards:
stage_index: Optional[List[int]] = None,
shard_config: ShardConfig = None,
**deprecated_arguments,
) -> Union[Tuple[torch.Tensor, ...], 'BaseModelOutputWithPastAndCrossAttentions']:
) -> Union[Tuple[torch.Tensor, ...], "BaseModelOutputWithPastAndCrossAttentions"]:
logger = logging.get_logger(__name__)
if deprecated_arguments.pop("position_ids", False) is not False:
@@ -132,20 +131,21 @@ class BloomPipelineForwards:
raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (output_hidden_states
if output_hidden_states is not None else self.config.output_hidden_states)
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# add warnings here
if output_attentions:
logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
output_attentions = False
if output_hidden_states:
logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
output_hidden_states = False
if use_cache:
logger.warning_once('use_cache=True is not supported for pipeline models at the moment.')
logger.warning_once("use_cache=True is not supported for pipeline models at the moment.")
use_cache = False
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
@@ -184,7 +184,8 @@ class BloomPipelineForwards:
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
if past_key_values is None:
@@ -193,7 +194,7 @@ class BloomPipelineForwards:
seq_length_with_past = seq_length
past_key_values_length = 0
if past_key_values[0] is not None:
past_key_values_length = past_key_values[0][0].shape[2] # source_len
past_key_values_length = past_key_values[0][0].shape[2] # source_len
seq_length_with_past = seq_length_with_past + past_key_values_length
if attention_mask is None:
@@ -213,20 +214,20 @@ class BloomPipelineForwards:
# split the input tensor along sequence dimension
# [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
if shard_config.enable_sequence_parallelism:
hidden_states = split_forward_gather_backward(hidden_states,
dim=1,
process_group=shard_config.tensor_parallel_process_group)
hidden_states = split_forward_gather_backward(
hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group
)
start_idx, end_idx = stage_index[0], stage_index[1]
for i, (block, layer_past) in enumerate(zip(self.h[start_idx:end_idx], past_key_values[start_idx:end_idx]),
start=start_idx):
for i, (block, layer_past) in enumerate(
zip(self.h[start_idx:end_idx], past_key_values[start_idx:end_idx]), start=start_idx
):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, use_cache=use_cache, output_attentions=output_attentions)
@@ -257,14 +258,13 @@ class BloomPipelineForwards:
if use_cache is True:
presents = presents + (outputs[1],)
if output_attentions:
all_self_attentions = all_self_attentions + \
(outputs[2 if use_cache else 1],)
all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
# When sequence parallelism done, gather the output tensor in forward and split it in backward
if shard_config.enable_sequence_parallelism:
hidden_states = gather_forward_split_backward(hidden_states,
dim=1,
process_group=shard_config.tensor_parallel_process_group)
hidden_states = gather_forward_split_backward(
hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group
)
if stage_manager.is_last_stage():
# Add last hidden state
@@ -277,7 +277,8 @@ class BloomPipelineForwards:
if stage_manager.is_last_stage():
if not return_dict:
return tuple(
v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None
)
# attention_mask is not returned ; presents = past_key_values
return BaseModelOutputWithPastAndCrossAttentions(
@@ -288,25 +289,27 @@ class BloomPipelineForwards:
)
else:
# always return dict for imediate stage
return {'hidden_states': hidden_states}
return {"hidden_states": hidden_states}
@staticmethod
def bloom_for_causal_lm_forward(self: BloomForCausalLM,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
shard_config: ShardConfig = None,
**deprecated_arguments):
def bloom_for_causal_lm_forward(
self: BloomForCausalLM,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
shard_config: ShardConfig = None,
**deprecated_arguments,
):
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
@@ -328,30 +331,29 @@ class BloomPipelineForwards:
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future.
if output_attentions:
logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
output_attentions = False
if output_hidden_states:
logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
output_hidden_states = False
transformer_outputs = BloomPipelineForwards.bloom_model_forward(self.transformer,
input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
stage_manager=stage_manager,
hidden_states=hidden_states,
stage_index=stage_index,
shard_config=shard_config)
transformer_outputs = BloomPipelineForwards.bloom_model_forward(
self.transformer,
input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
stage_manager=stage_manager,
hidden_states=hidden_states,
stage_index=stage_index,
shard_config=shard_config,
)
past_key_values = None
all_hidden_states = None
all_self_attentions = None
all_cross_attentions = None
if stage_manager.is_last_stage():
hidden_states = transformer_outputs[0]
lm_logits = self.lm_head(hidden_states)
@@ -366,8 +368,9 @@ class BloomPipelineForwards:
batch_size, seq_length, vocab_size = shift_logits.shape
# Flatten the tokens
loss_fct = CrossEntropyLoss()
loss = loss_fct(shift_logits.view(batch_size * seq_length, vocab_size),
shift_labels.view(batch_size * seq_length))
loss = loss_fct(
shift_logits.view(batch_size * seq_length, vocab_size), shift_labels.view(batch_size * seq_length)
)
if not return_dict:
output = (lm_logits,) + transformer_outputs[1:]
@@ -381,8 +384,8 @@ class BloomPipelineForwards:
attentions=transformer_outputs.attentions,
)
else:
hidden_states = transformer_outputs.get('hidden_states')
return {'hidden_states': hidden_states}
hidden_states = transformer_outputs.get("hidden_states")
return {"hidden_states": hidden_states}
@staticmethod
def bloom_for_sequence_classification_forward(
@@ -425,10 +428,10 @@ class BloomPipelineForwards:
# TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future.
if output_attentions:
logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
output_attentions = False
if output_hidden_states:
logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
output_hidden_states = False
transformer_outputs = BloomPipelineForwards.bloom_model_forward(
@@ -448,9 +451,6 @@ class BloomPipelineForwards:
shard_config=shard_config,
)
past_key_values = None
all_hidden_states = None
all_self_attentions = None
all_cross_attentions = None
if stage_manager.is_last_stage():
batch_size = hidden_states.shape[0]
# update batch size
@@ -468,7 +468,8 @@ class BloomPipelineForwards:
sequence_lengths = -1
logger.warning(
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
"unexpected if using padding tokens in conjunction with `inputs_embeds.`")
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
)
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
@@ -506,8 +507,8 @@ class BloomPipelineForwards:
attentions=transformer_outputs.attentions,
)
else:
hidden_states = transformer_outputs.get('hidden_states')
return {'hidden_states': hidden_states}
hidden_states = transformer_outputs.get("hidden_states")
return {"hidden_states": hidden_states}
@staticmethod
def bloom_for_token_classification_forward(
@@ -550,10 +551,10 @@ class BloomPipelineForwards:
# TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future.
if output_attentions:
logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
output_attentions = False
if output_hidden_states:
logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
output_hidden_states = False
transformer_outputs = BloomPipelineForwards.bloom_model_forward(
@@ -573,9 +574,6 @@ class BloomPipelineForwards:
shard_config=shard_config,
)
past_key_values = None
all_hidden_states = None
all_self_attentions = None
all_cross_attentions = None
if stage_manager.is_last_stage():
hidden_states = transformer_outputs[0]
@@ -588,8 +586,9 @@ class BloomPipelineForwards:
labels = labels.to(logits.device)
batch_size, seq_length = labels.shape
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(batch_size * seq_length, self.num_labels),
labels.view(batch_size * seq_length))
loss = loss_fct(
logits.view(batch_size * seq_length, self.num_labels), labels.view(batch_size * seq_length)
)
if not return_dict:
output = (logits,) + transformer_outputs[2:]
@@ -602,8 +601,8 @@ class BloomPipelineForwards:
attentions=transformer_outputs.attentions,
)
else:
hidden_states = transformer_outputs.get('hidden_states')
return {'hidden_states': hidden_states}
hidden_states = transformer_outputs.get("hidden_states")
return {"hidden_states": hidden_states}
@staticmethod
def bloom_for_question_answering_forward(
@@ -638,10 +637,10 @@ class BloomPipelineForwards:
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future.
if output_attentions:
logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
output_attentions = False
if output_hidden_states:
logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
output_hidden_states = False
outputs = BloomPipelineForwards.bloom_model_forward(
@@ -659,10 +658,6 @@ class BloomPipelineForwards:
stage_index=stage_index,
shard_config=shard_config,
)
past_key_values = None
all_hidden_states = None
all_self_attentions = None
all_cross_attentions = None
if stage_manager.is_last_stage():
sequence_output = outputs[0]
@@ -700,12 +695,11 @@ class BloomPipelineForwards:
attentions=outputs.attentions,
)
else:
hidden_states = outputs.get('hidden_states')
return {'hidden_states': hidden_states}
hidden_states = outputs.get("hidden_states")
return {"hidden_states": hidden_states}
def get_bloom_flash_attention_forward(enabel_jit_fused=False):
try:
from xformers.ops import memory_efficient_attention as me_attention
except:
@@ -723,7 +717,6 @@ def get_bloom_flash_attention_forward(enabel_jit_fused=False):
use_cache: bool = False,
output_attentions: bool = False,
):
fused_qkv = self.query_key_value(hidden_states)
(query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)
batch_size, tgt_len, _ = query_layer.size()
@@ -750,29 +743,35 @@ def get_bloom_flash_attention_forward(enabel_jit_fused=False):
tgt_len = key_layer.size()[1]
attention_numerical_mask = torch.zeros((batch_size, self.num_heads, tgt_len, kv_length),
dtype=torch.float32,
device=query_layer.device,
requires_grad=True)
attention_numerical_mask = attention_numerical_mask + alibi.view(batch_size, self.num_heads, 1,
kv_length) * self.beta
attention_numerical_mask = torch.masked_fill(attention_numerical_mask, attention_mask,
torch.finfo(torch.float32).min)
attention_numerical_mask = torch.zeros(
(batch_size, self.num_heads, tgt_len, kv_length),
dtype=torch.float32,
device=query_layer.device,
requires_grad=True,
)
attention_numerical_mask = (
attention_numerical_mask + alibi.view(batch_size, self.num_heads, 1, kv_length) * self.beta
)
attention_numerical_mask = torch.masked_fill(
attention_numerical_mask, attention_mask, torch.finfo(torch.float32).min
)
context_layer = me_attention(query_layer,
key_layer,
value_layer,
attn_bias=attention_numerical_mask,
scale=self.inv_norm_factor,
p=self.attention_dropout.p)
context_layer = me_attention(
query_layer,
key_layer,
value_layer,
attn_bias=attention_numerical_mask,
scale=self.inv_norm_factor,
p=self.attention_dropout.p,
)
context_layer = context_layer.reshape(-1, kv_length, self.hidden_size)
if self.pretraining_tp > 1 and self.slow_but_exact:
slices = self.hidden_size / self.pretraining_tp
output_tensor = torch.zeros_like(context_layer)
for i in range(self.pretraining_tp):
output_tensor = output_tensor + F.linear(
context_layer[:, :, int(i * slices):int((i + 1) * slices)],
self.dense.weight[:, int(i * slices):int((i + 1) * slices)],
context_layer[:, :, int(i * slices) : int((i + 1) * slices)],
self.dense.weight[:, int(i * slices) : int((i + 1) * slices)],
)
else:
output_tensor = self.dense(context_layer)
@@ -787,7 +786,6 @@ def get_bloom_flash_attention_forward(enabel_jit_fused=False):
def get_jit_fused_bloom_attention_forward():
from transformers.models.bloom.modeling_bloom import BloomAttention
def forward(
@@ -801,7 +799,7 @@ def get_jit_fused_bloom_attention_forward():
use_cache: bool = False,
output_attentions: bool = False,
):
fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size]
fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size]
# 3 x [batch_size, seq_length, num_heads, head_dim]
(query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)
@@ -867,8 +865,8 @@ def get_jit_fused_bloom_attention_forward():
output_tensor = torch.zeros_like(context_layer)
for i in range(self.pretraining_tp):
output_tensor = output_tensor + F.linear(
context_layer[:, :, int(i * slices):int((i + 1) * slices)],
self.dense.weight[:, int(i * slices):int((i + 1) * slices)],
context_layer[:, :, int(i * slices) : int((i + 1) * slices)],
self.dense.weight[:, int(i * slices) : int((i + 1) * slices)],
)
else:
output_tensor = self.dense(context_layer)
@@ -885,7 +883,6 @@ def get_jit_fused_bloom_attention_forward():
def get_jit_fused_bloom_mlp_forward():
from transformers.models.bloom.modeling_bloom import BloomMLP
def forward(self: BloomMLP, hidden_states: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:
@@ -896,8 +893,8 @@ def get_jit_fused_bloom_mlp_forward():
slices = self.dense_4h_to_h.weight.shape[-1] / self.pretraining_tp
for i in range(self.pretraining_tp):
intermediate_output = intermediate_output + F.linear(
hidden_states[:, :, int(i * slices):int((i + 1) * slices)],
self.dense_4h_to_h.weight[:, int(i * slices):int((i + 1) * slices)],
hidden_states[:, :, int(i * slices) : int((i + 1) * slices)],
self.dense_4h_to_h.weight[:, int(i * slices) : int((i + 1) * slices)],
)
else:
intermediate_output = self.dense_4h_to_h(hidden_states)
@@ -908,7 +905,6 @@ def get_jit_fused_bloom_mlp_forward():
def get_jit_fused_bloom_gelu_forward():
from transformers.models.bloom.modeling_bloom import BloomGelu
from colossalai.kernel.jit.bias_gelu import GeLUFunction as JitGeLUFunction
@@ -924,7 +920,6 @@ def get_jit_fused_bloom_gelu_forward():
def get_bloom_sequence_parallel_forward_fn(shard_config: ShardConfig):
from transformers import BloomModel
def forward(
@@ -951,8 +946,9 @@ def get_bloom_sequence_parallel_forward_fn(shard_config: ShardConfig):
raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (output_hidden_states
if output_hidden_states is not None else self.config.output_hidden_states)
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
@@ -986,7 +982,8 @@ def get_bloom_sequence_parallel_forward_fn(shard_config: ShardConfig):
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
# Compute alibi tensor: check build_alibi_tensor documentation
@@ -1009,9 +1006,9 @@ def get_bloom_sequence_parallel_forward_fn(shard_config: ShardConfig):
)
# split the input tensor along sequence dimension
# [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
hidden_states = split_forward_gather_backward(hidden_states,
dim=1,
process_group=shard_config.tensor_parallel_process_group)
hidden_states = split_forward_gather_backward(
hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group
)
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
if output_hidden_states:
@@ -1020,7 +1017,6 @@ def get_bloom_sequence_parallel_forward_fn(shard_config: ShardConfig):
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, use_cache=use_cache, output_attentions=output_attentions)
@@ -1054,9 +1050,9 @@ def get_bloom_sequence_parallel_forward_fn(shard_config: ShardConfig):
all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
# When sequence parallelism done, gather the output tensor in forward and split it in backward
hidden_states = gather_forward_split_backward(hidden_states,
dim=1,
process_group=shard_config.tensor_parallel_process_group)
hidden_states = gather_forward_split_backward(
hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group
)
# Add last hidden state
hidden_states = self.ln_f(hidden_states)