mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-03 18:19:58 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit * [misc] run pre-commit * [misc] remove useless configuration files * [misc] ignore cuda for clang-format
This commit is contained in:
@@ -30,9 +30,9 @@ logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def build_bloom_alibi_tensor_fn(process_group: ProcessGroup) -> torch.Tensor:
|
||||
|
||||
def build_bloom_alibi_tensor(self, attention_mask: torch.Tensor, num_heads: int,
|
||||
dtype: torch.dtype) -> torch.Tensor:
|
||||
def build_bloom_alibi_tensor(
|
||||
self, attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Link to paper: https://arxiv.org/abs/2108.12409 Alibi tensor is not causal as the original paper mentions, it
|
||||
relies on a translation invariance of softmax for quick implementation: with l being a tensor, and a fixed value
|
||||
@@ -56,23 +56,23 @@ def build_bloom_alibi_tensor_fn(process_group: ProcessGroup) -> torch.Tensor:
|
||||
num_heads = num_heads * world_size
|
||||
|
||||
batch_size, seq_length = attention_mask.shape
|
||||
closest_power_of_2 = 2**math.floor(math.log2(num_heads))
|
||||
base = torch.tensor(2**(-(2**-(math.log2(closest_power_of_2) - 3))),
|
||||
device=attention_mask.device,
|
||||
dtype=torch.float32)
|
||||
closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
|
||||
base = torch.tensor(
|
||||
2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), device=attention_mask.device, dtype=torch.float32
|
||||
)
|
||||
powers = torch.arange(1, 1 + closest_power_of_2, device=attention_mask.device, dtype=torch.int32)
|
||||
slopes = torch.pow(base, powers)
|
||||
|
||||
if closest_power_of_2 != num_heads:
|
||||
extra_base = torch.tensor(2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))),
|
||||
device=attention_mask.device,
|
||||
dtype=torch.float32)
|
||||
extra_base = torch.tensor(
|
||||
2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))),
|
||||
device=attention_mask.device,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2)
|
||||
extra_powers = torch.arange(1,
|
||||
1 + 2 * num_remaining_heads,
|
||||
2,
|
||||
device=attention_mask.device,
|
||||
dtype=torch.int32)
|
||||
extra_powers = torch.arange(
|
||||
1, 1 + 2 * num_remaining_heads, 2, device=attention_mask.device, dtype=torch.int32
|
||||
)
|
||||
slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
|
||||
|
||||
# Note: alibi will be added to the attention bias that will be applied to the query, key product of attention
|
||||
@@ -87,7 +87,7 @@ def build_bloom_alibi_tensor_fn(process_group: ProcessGroup) -> torch.Tensor:
|
||||
num_heads_per_rank = int(num_heads / dist.get_world_size(process_group))
|
||||
offset = dist.get_rank(process_group) * num_heads_per_rank
|
||||
alibi = alibi.view(batch_size, num_heads, 1, seq_length)
|
||||
alibi = alibi[:, offset:num_heads_per_rank + offset, :, :]
|
||||
alibi = alibi[:, offset : num_heads_per_rank + offset, :, :]
|
||||
return alibi.reshape(batch_size * num_heads_per_rank, 1, seq_length).to(dtype)
|
||||
else:
|
||||
return alibi.reshape(batch_size * num_heads, 1, seq_length).to(dtype)
|
||||
@@ -96,9 +96,9 @@ def build_bloom_alibi_tensor_fn(process_group: ProcessGroup) -> torch.Tensor:
|
||||
|
||||
|
||||
class BloomPipelineForwards:
|
||||
'''
|
||||
"""
|
||||
This class serves as a micro library for bloom pipeline forwards.
|
||||
'''
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def bloom_model_forward(
|
||||
@@ -117,8 +117,7 @@ class BloomPipelineForwards:
|
||||
stage_index: Optional[List[int]] = None,
|
||||
shard_config: ShardConfig = None,
|
||||
**deprecated_arguments,
|
||||
) -> Union[Tuple[torch.Tensor, ...], 'BaseModelOutputWithPastAndCrossAttentions']:
|
||||
|
||||
) -> Union[Tuple[torch.Tensor, ...], "BaseModelOutputWithPastAndCrossAttentions"]:
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
if deprecated_arguments.pop("position_ids", False) is not False:
|
||||
@@ -132,20 +131,21 @@ class BloomPipelineForwards:
|
||||
raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
|
||||
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (output_hidden_states
|
||||
if output_hidden_states is not None else self.config.output_hidden_states)
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
# add warnings here
|
||||
if output_attentions:
|
||||
logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
|
||||
logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
|
||||
output_attentions = False
|
||||
if output_hidden_states:
|
||||
logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
|
||||
logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
|
||||
output_hidden_states = False
|
||||
if use_cache:
|
||||
logger.warning_once('use_cache=True is not supported for pipeline models at the moment.')
|
||||
logger.warning_once("use_cache=True is not supported for pipeline models at the moment.")
|
||||
use_cache = False
|
||||
# Prepare head mask if needed
|
||||
# 1.0 in head_mask indicates we keep the head
|
||||
@@ -184,7 +184,8 @@ class BloomPipelineForwards:
|
||||
if self.gradient_checkpointing and self.training:
|
||||
if use_cache:
|
||||
logger.warning_once(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
if past_key_values is None:
|
||||
@@ -193,7 +194,7 @@ class BloomPipelineForwards:
|
||||
seq_length_with_past = seq_length
|
||||
past_key_values_length = 0
|
||||
if past_key_values[0] is not None:
|
||||
past_key_values_length = past_key_values[0][0].shape[2] # source_len
|
||||
past_key_values_length = past_key_values[0][0].shape[2] # source_len
|
||||
|
||||
seq_length_with_past = seq_length_with_past + past_key_values_length
|
||||
if attention_mask is None:
|
||||
@@ -213,20 +214,20 @@ class BloomPipelineForwards:
|
||||
# split the input tensor along sequence dimension
|
||||
# [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
|
||||
if shard_config.enable_sequence_parallelism:
|
||||
hidden_states = split_forward_gather_backward(hidden_states,
|
||||
dim=1,
|
||||
process_group=shard_config.tensor_parallel_process_group)
|
||||
hidden_states = split_forward_gather_backward(
|
||||
hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group
|
||||
)
|
||||
|
||||
start_idx, end_idx = stage_index[0], stage_index[1]
|
||||
for i, (block, layer_past) in enumerate(zip(self.h[start_idx:end_idx], past_key_values[start_idx:end_idx]),
|
||||
start=start_idx):
|
||||
for i, (block, layer_past) in enumerate(
|
||||
zip(self.h[start_idx:end_idx], past_key_values[start_idx:end_idx]), start=start_idx
|
||||
):
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
|
||||
def create_custom_forward(module):
|
||||
|
||||
def custom_forward(*inputs):
|
||||
# None for past_key_value
|
||||
return module(*inputs, use_cache=use_cache, output_attentions=output_attentions)
|
||||
@@ -257,14 +258,13 @@ class BloomPipelineForwards:
|
||||
if use_cache is True:
|
||||
presents = presents + (outputs[1],)
|
||||
if output_attentions:
|
||||
all_self_attentions = all_self_attentions + \
|
||||
(outputs[2 if use_cache else 1],)
|
||||
all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
|
||||
|
||||
# When sequence parallelism is done, gather the output tensor in forward and split it in backward
|
||||
if shard_config.enable_sequence_parallelism:
|
||||
hidden_states = gather_forward_split_backward(hidden_states,
|
||||
dim=1,
|
||||
process_group=shard_config.tensor_parallel_process_group)
|
||||
hidden_states = gather_forward_split_backward(
|
||||
hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group
|
||||
)
|
||||
|
||||
if stage_manager.is_last_stage():
|
||||
# Add last hidden state
|
||||
@@ -277,7 +277,8 @@ class BloomPipelineForwards:
|
||||
if stage_manager.is_last_stage():
|
||||
if not return_dict:
|
||||
return tuple(
|
||||
v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
|
||||
v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None
|
||||
)
|
||||
|
||||
# attention_mask is not returned ; presents = past_key_values
|
||||
return BaseModelOutputWithPastAndCrossAttentions(
|
||||
@@ -288,25 +289,27 @@ class BloomPipelineForwards:
|
||||
)
|
||||
else:
|
||||
# always return dict for intermediate stage
|
||||
return {'hidden_states': hidden_states}
|
||||
return {"hidden_states": hidden_states}
|
||||
|
||||
@staticmethod
|
||||
def bloom_for_causal_lm_forward(self: BloomForCausalLM,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
head_mask: Optional[torch.Tensor] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
labels: Optional[torch.Tensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
stage_manager: Optional[PipelineStageManager] = None,
|
||||
hidden_states: Optional[torch.FloatTensor] = None,
|
||||
stage_index: Optional[List[int]] = None,
|
||||
shard_config: ShardConfig = None,
|
||||
**deprecated_arguments):
|
||||
def bloom_for_causal_lm_forward(
|
||||
self: BloomForCausalLM,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
head_mask: Optional[torch.Tensor] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
labels: Optional[torch.Tensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
stage_manager: Optional[PipelineStageManager] = None,
|
||||
hidden_states: Optional[torch.FloatTensor] = None,
|
||||
stage_index: Optional[List[int]] = None,
|
||||
shard_config: ShardConfig = None,
|
||||
**deprecated_arguments,
|
||||
):
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
|
||||
@@ -328,30 +331,29 @@ class BloomPipelineForwards:
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
# TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future.
|
||||
if output_attentions:
|
||||
logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
|
||||
logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
|
||||
output_attentions = False
|
||||
if output_hidden_states:
|
||||
logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
|
||||
logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
|
||||
output_hidden_states = False
|
||||
|
||||
transformer_outputs = BloomPipelineForwards.bloom_model_forward(self.transformer,
|
||||
input_ids,
|
||||
past_key_values=past_key_values,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
stage_manager=stage_manager,
|
||||
hidden_states=hidden_states,
|
||||
stage_index=stage_index,
|
||||
shard_config=shard_config)
|
||||
transformer_outputs = BloomPipelineForwards.bloom_model_forward(
|
||||
self.transformer,
|
||||
input_ids,
|
||||
past_key_values=past_key_values,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
stage_manager=stage_manager,
|
||||
hidden_states=hidden_states,
|
||||
stage_index=stage_index,
|
||||
shard_config=shard_config,
|
||||
)
|
||||
past_key_values = None
|
||||
all_hidden_states = None
|
||||
all_self_attentions = None
|
||||
all_cross_attentions = None
|
||||
if stage_manager.is_last_stage():
|
||||
hidden_states = transformer_outputs[0]
|
||||
lm_logits = self.lm_head(hidden_states)
|
||||
@@ -366,8 +368,9 @@ class BloomPipelineForwards:
|
||||
batch_size, seq_length, vocab_size = shift_logits.shape
|
||||
# Flatten the tokens
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(shift_logits.view(batch_size * seq_length, vocab_size),
|
||||
shift_labels.view(batch_size * seq_length))
|
||||
loss = loss_fct(
|
||||
shift_logits.view(batch_size * seq_length, vocab_size), shift_labels.view(batch_size * seq_length)
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
output = (lm_logits,) + transformer_outputs[1:]
|
||||
@@ -381,8 +384,8 @@ class BloomPipelineForwards:
|
||||
attentions=transformer_outputs.attentions,
|
||||
)
|
||||
else:
|
||||
hidden_states = transformer_outputs.get('hidden_states')
|
||||
return {'hidden_states': hidden_states}
|
||||
hidden_states = transformer_outputs.get("hidden_states")
|
||||
return {"hidden_states": hidden_states}
|
||||
|
||||
@staticmethod
|
||||
def bloom_for_sequence_classification_forward(
|
||||
@@ -425,10 +428,10 @@ class BloomPipelineForwards:
|
||||
|
||||
# TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future.
|
||||
if output_attentions:
|
||||
logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
|
||||
logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
|
||||
output_attentions = False
|
||||
if output_hidden_states:
|
||||
logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
|
||||
logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
|
||||
output_hidden_states = False
|
||||
|
||||
transformer_outputs = BloomPipelineForwards.bloom_model_forward(
|
||||
@@ -448,9 +451,6 @@ class BloomPipelineForwards:
|
||||
shard_config=shard_config,
|
||||
)
|
||||
past_key_values = None
|
||||
all_hidden_states = None
|
||||
all_self_attentions = None
|
||||
all_cross_attentions = None
|
||||
if stage_manager.is_last_stage():
|
||||
batch_size = hidden_states.shape[0]
|
||||
# update batch size
|
||||
@@ -468,7 +468,8 @@ class BloomPipelineForwards:
|
||||
sequence_lengths = -1
|
||||
logger.warning(
|
||||
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
|
||||
"unexpected if using padding tokens in conjunction with `inputs_embeds.`")
|
||||
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
|
||||
)
|
||||
|
||||
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
|
||||
|
||||
@@ -506,8 +507,8 @@ class BloomPipelineForwards:
|
||||
attentions=transformer_outputs.attentions,
|
||||
)
|
||||
else:
|
||||
hidden_states = transformer_outputs.get('hidden_states')
|
||||
return {'hidden_states': hidden_states}
|
||||
hidden_states = transformer_outputs.get("hidden_states")
|
||||
return {"hidden_states": hidden_states}
|
||||
|
||||
@staticmethod
|
||||
def bloom_for_token_classification_forward(
|
||||
@@ -550,10 +551,10 @@ class BloomPipelineForwards:
|
||||
|
||||
# TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future.
|
||||
if output_attentions:
|
||||
logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
|
||||
logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
|
||||
output_attentions = False
|
||||
if output_hidden_states:
|
||||
logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
|
||||
logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
|
||||
output_hidden_states = False
|
||||
|
||||
transformer_outputs = BloomPipelineForwards.bloom_model_forward(
|
||||
@@ -573,9 +574,6 @@ class BloomPipelineForwards:
|
||||
shard_config=shard_config,
|
||||
)
|
||||
past_key_values = None
|
||||
all_hidden_states = None
|
||||
all_self_attentions = None
|
||||
all_cross_attentions = None
|
||||
|
||||
if stage_manager.is_last_stage():
|
||||
hidden_states = transformer_outputs[0]
|
||||
@@ -588,8 +586,9 @@ class BloomPipelineForwards:
|
||||
labels = labels.to(logits.device)
|
||||
batch_size, seq_length = labels.shape
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(logits.view(batch_size * seq_length, self.num_labels),
|
||||
labels.view(batch_size * seq_length))
|
||||
loss = loss_fct(
|
||||
logits.view(batch_size * seq_length, self.num_labels), labels.view(batch_size * seq_length)
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + transformer_outputs[2:]
|
||||
@@ -602,8 +601,8 @@ class BloomPipelineForwards:
|
||||
attentions=transformer_outputs.attentions,
|
||||
)
|
||||
else:
|
||||
hidden_states = transformer_outputs.get('hidden_states')
|
||||
return {'hidden_states': hidden_states}
|
||||
hidden_states = transformer_outputs.get("hidden_states")
|
||||
return {"hidden_states": hidden_states}
|
||||
|
||||
@staticmethod
|
||||
def bloom_for_question_answering_forward(
|
||||
@@ -638,10 +637,10 @@ class BloomPipelineForwards:
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
# TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future.
|
||||
if output_attentions:
|
||||
logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
|
||||
logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
|
||||
output_attentions = False
|
||||
if output_hidden_states:
|
||||
logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
|
||||
logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
|
||||
output_hidden_states = False
|
||||
|
||||
outputs = BloomPipelineForwards.bloom_model_forward(
|
||||
@@ -659,10 +658,6 @@ class BloomPipelineForwards:
|
||||
stage_index=stage_index,
|
||||
shard_config=shard_config,
|
||||
)
|
||||
past_key_values = None
|
||||
all_hidden_states = None
|
||||
all_self_attentions = None
|
||||
all_cross_attentions = None
|
||||
|
||||
if stage_manager.is_last_stage():
|
||||
sequence_output = outputs[0]
|
||||
@@ -700,12 +695,11 @@ class BloomPipelineForwards:
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
else:
|
||||
hidden_states = outputs.get('hidden_states')
|
||||
return {'hidden_states': hidden_states}
|
||||
hidden_states = outputs.get("hidden_states")
|
||||
return {"hidden_states": hidden_states}
|
||||
|
||||
|
||||
def get_bloom_flash_attention_forward(enabel_jit_fused=False):
|
||||
|
||||
try:
|
||||
from xformers.ops import memory_efficient_attention as me_attention
|
||||
except:
|
||||
@@ -723,7 +717,6 @@ def get_bloom_flash_attention_forward(enabel_jit_fused=False):
|
||||
use_cache: bool = False,
|
||||
output_attentions: bool = False,
|
||||
):
|
||||
|
||||
fused_qkv = self.query_key_value(hidden_states)
|
||||
(query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)
|
||||
batch_size, tgt_len, _ = query_layer.size()
|
||||
@@ -750,29 +743,35 @@ def get_bloom_flash_attention_forward(enabel_jit_fused=False):
|
||||
|
||||
tgt_len = key_layer.size()[1]
|
||||
|
||||
attention_numerical_mask = torch.zeros((batch_size, self.num_heads, tgt_len, kv_length),
|
||||
dtype=torch.float32,
|
||||
device=query_layer.device,
|
||||
requires_grad=True)
|
||||
attention_numerical_mask = attention_numerical_mask + alibi.view(batch_size, self.num_heads, 1,
|
||||
kv_length) * self.beta
|
||||
attention_numerical_mask = torch.masked_fill(attention_numerical_mask, attention_mask,
|
||||
torch.finfo(torch.float32).min)
|
||||
attention_numerical_mask = torch.zeros(
|
||||
(batch_size, self.num_heads, tgt_len, kv_length),
|
||||
dtype=torch.float32,
|
||||
device=query_layer.device,
|
||||
requires_grad=True,
|
||||
)
|
||||
attention_numerical_mask = (
|
||||
attention_numerical_mask + alibi.view(batch_size, self.num_heads, 1, kv_length) * self.beta
|
||||
)
|
||||
attention_numerical_mask = torch.masked_fill(
|
||||
attention_numerical_mask, attention_mask, torch.finfo(torch.float32).min
|
||||
)
|
||||
|
||||
context_layer = me_attention(query_layer,
|
||||
key_layer,
|
||||
value_layer,
|
||||
attn_bias=attention_numerical_mask,
|
||||
scale=self.inv_norm_factor,
|
||||
p=self.attention_dropout.p)
|
||||
context_layer = me_attention(
|
||||
query_layer,
|
||||
key_layer,
|
||||
value_layer,
|
||||
attn_bias=attention_numerical_mask,
|
||||
scale=self.inv_norm_factor,
|
||||
p=self.attention_dropout.p,
|
||||
)
|
||||
context_layer = context_layer.reshape(-1, kv_length, self.hidden_size)
|
||||
if self.pretraining_tp > 1 and self.slow_but_exact:
|
||||
slices = self.hidden_size / self.pretraining_tp
|
||||
output_tensor = torch.zeros_like(context_layer)
|
||||
for i in range(self.pretraining_tp):
|
||||
output_tensor = output_tensor + F.linear(
|
||||
context_layer[:, :, int(i * slices):int((i + 1) * slices)],
|
||||
self.dense.weight[:, int(i * slices):int((i + 1) * slices)],
|
||||
context_layer[:, :, int(i * slices) : int((i + 1) * slices)],
|
||||
self.dense.weight[:, int(i * slices) : int((i + 1) * slices)],
|
||||
)
|
||||
else:
|
||||
output_tensor = self.dense(context_layer)
|
||||
@@ -787,7 +786,6 @@ def get_bloom_flash_attention_forward(enabel_jit_fused=False):
|
||||
|
||||
|
||||
def get_jit_fused_bloom_attention_forward():
|
||||
|
||||
from transformers.models.bloom.modeling_bloom import BloomAttention
|
||||
|
||||
def forward(
|
||||
@@ -801,7 +799,7 @@ def get_jit_fused_bloom_attention_forward():
|
||||
use_cache: bool = False,
|
||||
output_attentions: bool = False,
|
||||
):
|
||||
fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size]
|
||||
fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size]
|
||||
|
||||
# 3 x [batch_size, seq_length, num_heads, head_dim]
|
||||
(query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)
|
||||
@@ -867,8 +865,8 @@ def get_jit_fused_bloom_attention_forward():
|
||||
output_tensor = torch.zeros_like(context_layer)
|
||||
for i in range(self.pretraining_tp):
|
||||
output_tensor = output_tensor + F.linear(
|
||||
context_layer[:, :, int(i * slices):int((i + 1) * slices)],
|
||||
self.dense.weight[:, int(i * slices):int((i + 1) * slices)],
|
||||
context_layer[:, :, int(i * slices) : int((i + 1) * slices)],
|
||||
self.dense.weight[:, int(i * slices) : int((i + 1) * slices)],
|
||||
)
|
||||
else:
|
||||
output_tensor = self.dense(context_layer)
|
||||
@@ -885,7 +883,6 @@ def get_jit_fused_bloom_attention_forward():
|
||||
|
||||
|
||||
def get_jit_fused_bloom_mlp_forward():
|
||||
|
||||
from transformers.models.bloom.modeling_bloom import BloomMLP
|
||||
|
||||
def forward(self: BloomMLP, hidden_states: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:
|
||||
@@ -896,8 +893,8 @@ def get_jit_fused_bloom_mlp_forward():
|
||||
slices = self.dense_4h_to_h.weight.shape[-1] / self.pretraining_tp
|
||||
for i in range(self.pretraining_tp):
|
||||
intermediate_output = intermediate_output + F.linear(
|
||||
hidden_states[:, :, int(i * slices):int((i + 1) * slices)],
|
||||
self.dense_4h_to_h.weight[:, int(i * slices):int((i + 1) * slices)],
|
||||
hidden_states[:, :, int(i * slices) : int((i + 1) * slices)],
|
||||
self.dense_4h_to_h.weight[:, int(i * slices) : int((i + 1) * slices)],
|
||||
)
|
||||
else:
|
||||
intermediate_output = self.dense_4h_to_h(hidden_states)
|
||||
@@ -908,7 +905,6 @@ def get_jit_fused_bloom_mlp_forward():
|
||||
|
||||
|
||||
def get_jit_fused_bloom_gelu_forward():
|
||||
|
||||
from transformers.models.bloom.modeling_bloom import BloomGelu
|
||||
|
||||
from colossalai.kernel.jit.bias_gelu import GeLUFunction as JitGeLUFunction
|
||||
@@ -924,7 +920,6 @@ def get_jit_fused_bloom_gelu_forward():
|
||||
|
||||
|
||||
def get_bloom_sequence_parallel_forward_fn(shard_config: ShardConfig):
|
||||
|
||||
from transformers import BloomModel
|
||||
|
||||
def forward(
|
||||
@@ -951,8 +946,9 @@ def get_bloom_sequence_parallel_forward_fn(shard_config: ShardConfig):
|
||||
raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
|
||||
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (output_hidden_states
|
||||
if output_hidden_states is not None else self.config.output_hidden_states)
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
@@ -986,7 +982,8 @@ def get_bloom_sequence_parallel_forward_fn(shard_config: ShardConfig):
|
||||
if self.gradient_checkpointing and self.training:
|
||||
if use_cache:
|
||||
logger.warning_once(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
# Compute alibi tensor: check build_alibi_tensor documentation
|
||||
@@ -1009,9 +1006,9 @@ def get_bloom_sequence_parallel_forward_fn(shard_config: ShardConfig):
|
||||
)
|
||||
# split the input tensor along sequence dimension
|
||||
# [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
|
||||
hidden_states = split_forward_gather_backward(hidden_states,
|
||||
dim=1,
|
||||
process_group=shard_config.tensor_parallel_process_group)
|
||||
hidden_states = split_forward_gather_backward(
|
||||
hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group
|
||||
)
|
||||
|
||||
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
|
||||
if output_hidden_states:
|
||||
@@ -1020,7 +1017,6 @@ def get_bloom_sequence_parallel_forward_fn(shard_config: ShardConfig):
|
||||
if self.gradient_checkpointing and self.training:
|
||||
|
||||
def create_custom_forward(module):
|
||||
|
||||
def custom_forward(*inputs):
|
||||
# None for past_key_value
|
||||
return module(*inputs, use_cache=use_cache, output_attentions=output_attentions)
|
||||
@@ -1054,9 +1050,9 @@ def get_bloom_sequence_parallel_forward_fn(shard_config: ShardConfig):
|
||||
all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
|
||||
|
||||
# When sequence parallelism is done, gather the output tensor in forward and split it in backward
|
||||
hidden_states = gather_forward_split_backward(hidden_states,
|
||||
dim=1,
|
||||
process_group=shard_config.tensor_parallel_process_group)
|
||||
hidden_states = gather_forward_split_backward(
|
||||
hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group
|
||||
)
|
||||
# Add last hidden state
|
||||
hidden_states = self.ln_f(hidden_states)
|
||||
|
||||
|
Reference in New Issue
Block a user