mirror of https://github.com/hpcaitech/ColossalAI.git
[shardformer] added development protocol for standardization (#4149)
colossalai/shardformer/modeling/bloom.py (new file, +69 lines)
@@ -0,0 +1,69 @@
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup


def build_bloom_alibi_tensor_fn(process_group: ProcessGroup) -> torch.Tensor:

    def build_bloom_alibi_tensor(self, attention_mask: torch.Tensor, num_heads: int,
                                 dtype: torch.dtype) -> torch.Tensor:
        """
        Link to paper: https://arxiv.org/abs/2108.12409. The alibi tensor is not causal as the original paper
        suggests; it relies on the translation invariance of softmax for a quick implementation: for a tensor l
        and a fixed value a, `softmax(l + a) = softmax(l)`. Based on
        https://github.com/ofirpress/attention_with_linear_biases/blob/a35aaca144e0eb6b789dfcb46784c4b8e31b7983/fairseq/models/transformer.py#L742
        TODO @thomasw21: this doesn't work as nicely due to the masking strategy, and so masking varies slightly.

        Args:
            attention_mask (`torch.Tensor`):
                Token-wise attention mask, this should be of shape (batch_size, max_seq_len).
            num_heads (`int`, *required*):
                number of heads
            dtype (`torch.dtype`, *optional*, default=`torch.bfloat16`):
                dtype of the output tensor

        Returns:
            A tensor of shape (batch_size * num_heads, 1, max_seq_len).
        """
        import math

        if dist.is_initialized():
            # Under tensor parallelism the caller passes the per-rank head count; recover the global
            # head count so the slopes match the unsharded model before slicing per rank below.
            world_size = dist.get_world_size(process_group)
            num_heads = num_heads * world_size

        batch_size, seq_length = attention_mask.shape
        closest_power_of_2 = 2**math.floor(math.log2(num_heads))
        base = torch.tensor(2**(-(2**-(math.log2(closest_power_of_2) - 3))),
                            device=attention_mask.device,
                            dtype=torch.float32)
        powers = torch.arange(1, 1 + closest_power_of_2, device=attention_mask.device, dtype=torch.int32)
        slopes = torch.pow(base, powers)

        if closest_power_of_2 != num_heads:
            extra_base = torch.tensor(2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))),
                                      device=attention_mask.device,
                                      dtype=torch.float32)
            num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2)
            extra_powers = torch.arange(1,
                                        1 + 2 * num_remaining_heads,
                                        2,
                                        device=attention_mask.device,
                                        dtype=torch.int32)
            slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)

        # Note: alibi will be added to the attention bias that will be applied to the query, key product of attention
        # => therefore alibi will have to be of shape (batch_size, num_heads, query_length, key_length)
        # => here we set (batch_size=1, num_heads=num_heads, query_length=1, key_length=max_length)
        # => the query_length dimension will then be broadcasted correctly
        # This is more or less identical to T5's relative position bias:
        # https://github.com/huggingface/transformers/blob/f681437203baa7671de3174b0fa583c349d9d5e1/src/transformers/models/t5/modeling_t5.py#L527
        arange_tensor = ((attention_mask.cumsum(dim=-1) - 1) * attention_mask)[:, None, :]
        alibi = slopes[..., None] * arange_tensor
        if dist.is_initialized():
            # Attention heads are sharded across the tensor-parallel group, so each rank keeps only the
            # slice of alibi that corresponds to its local heads.
            num_heads_per_rank = int(num_heads / dist.get_world_size(process_group))
            offset = dist.get_rank(process_group) * num_heads_per_rank
            alibi = alibi.view(batch_size, num_heads, 1, seq_length)
            alibi = alibi[:, offset:num_heads_per_rank + offset, :, :]
            return alibi.reshape(batch_size * num_heads_per_rank, 1, seq_length).to(dtype)
        else:
            return alibi.reshape(batch_size * num_heads, 1, seq_length).to(dtype)

    return build_bloom_alibi_tensor
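
The docstring above leans on the shift invariance of softmax (`softmax(l + a) = softmax(l)` for a fixed value a), which is what allows the non-causal alibi bias to be folded directly into the attention logits. A minimal sketch verifying that property numerically, assuming only a local PyTorch install:

import torch

# Shift invariance: adding a fixed scalar to every logit leaves the softmax unchanged,
# because the constant factor exp(a) cancels in the normalization.
logits = torch.randn(4, 8)
shift = 3.5  # any fixed value

same = torch.allclose(torch.softmax(logits, dim=-1),
                      torch.softmax(logits + shift, dim=-1))
print(same)  # True (up to floating-point tolerance)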
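
For a quick shape check of the builder itself, a minimal sketch is shown below. It assumes the ColossalAI package containing this file is importable and that no `torch.distributed` process group has been initialized, so the single-process branch is taken and `process_group` can be `None`. The inner function takes `self`, presumably so the shardformer policy can bind it as a method replacement on the model; since `self` is unused in the body, `None` is passed here only to satisfy the signature.

import torch

from colossalai.shardformer.modeling.bloom import build_bloom_alibi_tensor_fn

batch_size, num_heads, seq_len = 2, 8, 16
attention_mask = torch.ones(batch_size, seq_len, dtype=torch.int64)

# No process group is initialized here, so the non-distributed branch is used.
build_alibi = build_bloom_alibi_tensor_fn(process_group=None)
alibi = build_alibi(None, attention_mask, num_heads, torch.bfloat16)

print(alibi.shape)  # torch.Size([16, 1, 16]) == (batch_size * num_heads, 1, seq_len)
print(alibi.dtype)  # torch.bfloat16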