ColossalAI/colossalai/kernel/extensions/flash_attention/utils.py
Hongxin Liu d202cc28c0
[npu] change device to accelerator api (#5239)
* update accelerator

* fix timer

* fix amp

* update

* fix

* update bug

* add error raise

* fix autocast

* fix set device

* remove doc accelerator

* update doc

* update doc

* update doc

* use nullcontext

* update cpu

* update null context

* change time limit for example

* update

* update

* update

* update

* [npu] polish accelerator code

---------

Co-authored-by: Xuanlei Zhao <xuanlei.zhao@gmail.com>
Co-authored-by: zxl <43881818+oahzxl@users.noreply.github.com>
2024-01-09 10:20:05 +08:00

92 lines
2.8 KiB
Python

import enum
from dataclasses import dataclass
from typing import Iterable, Tuple
import torch
import torch.nn.functional as F
from einops import rearrange
from colossalai.accelerator import get_accelerator
class Unpad(torch.autograd.Function):
    """
    Gather the kept tokens of a padded batch into one packed tensor.

    Adapted from
    https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/bert_padding.py

    Forward: ``[b, s, ...]`` -> ``[ntokens, ...]`` by selecting the rows at
    ``indices`` of the batch/sequence-merged ``[b*s, ...]`` view.
    Backward: writes the incoming gradient into a zero-filled ``[b*s, ...]``
    tensor at ``indices`` and splits it back to ``[b, s, ...]``; ``indices``
    receives no gradient.
    """

    @staticmethod
    def forward(ctx, tensor: torch.Tensor, indices: torch.Tensor):
        ctx.save_for_backward(indices)
        # Expect [b, s, ...] with at least one trailing feature dimension.
        assert tensor.ndim >= 3
        ctx.bsz = tensor.shape[0]
        ctx.seq_len = tensor.shape[1]
        # Merge the batch and sequence axes: [b, s, ...] -> [b*s, ...].
        merged = tensor.flatten(0, 1)
        ctx.shape = merged.shape
        # Advanced indexing copies out just the selected rows -> [ntokens, ...].
        return merged[indices]

    @staticmethod
    def backward(ctx, grad_output):
        (indices,) = ctx.saved_tensors
        # Rows that were dropped in forward get a zero gradient.
        full_grad = grad_output.new_zeros(ctx.shape)
        full_grad[indices] = grad_output
        # Split [b*s, ...] back into [b, s, ...].
        unmerged = full_grad.view(ctx.bsz, ctx.seq_len, *ctx.shape[1:])
        return unmerged, None
class Repad(torch.autograd.Function):
    """
    Scatter packed tokens back into a zero-padded, batch/sequence-merged tensor.

    Adapted from
    https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/bert_padding.py

    Forward: ``[ntokens, ...]`` -> ``[batch_size * seq_len, ...]``. Row
    ``indices[i]`` of the output receives ``tensor[i]`` and every other row is
    zero. The output keeps batch and sequence merged in its first axis.
    Backward: gathers the gradient rows at ``indices``. Only ``tensor``
    receives a gradient.
    """

    @staticmethod
    def forward(ctx, tensor: torch.Tensor, indices: torch.Tensor, batch_size: int, seq_len: int):
        ctx.save_for_backward(indices)
        padded_shape = (batch_size * seq_len, *tensor.shape[1:])
        # Same dtype/device as the packed input; unselected rows stay zero.
        padded = tensor.new_zeros(padded_shape)
        padded[indices] = tensor
        return padded

    @staticmethod
    def backward(ctx, grad_output):
        (indices,) = ctx.saved_tensors
        # [b*s, ...] -> [ntokens, ...]; batch_size/seq_len/indices get no grad.
        packed_grad = grad_output[indices]
        return packed_grad, None, None, None
@dataclass
class SeqLenInfo:
    """Per-batch sequence-length metadata for variable-length attention.

    Attributes:
        seqlens: number of kept tokens in each sequence of the batch (Python ints).
        indices: positions of the kept tokens in the flattened token axis
            (``attn_mask.flatten()`` order, or ``0 .. batch*seq_len-1`` when no
            mask is given).
        max_seqlen: the largest entry of ``seqlens``.
        cu_seqlens: int32 cumulative sums of ``seqlens`` with a leading 0, so
            ``cu_seqlens[i]:cu_seqlens[i + 1]`` spans sequence ``i`` in the
            packed token axis.
    """

    seqlens: Iterable[int] = None
    indices: torch.Tensor = None
    max_seqlen: int = None
    cu_seqlens: torch.Tensor = None

    @staticmethod
    def materialize(attn_mask: torch.Tensor = None, size: Tuple[int, ...] = None, device=None) -> "SeqLenInfo":
        """Build a ``SeqLenInfo`` from a padding mask or from a dense batch shape.

        Args:
            attn_mask: mask whose non-zero entries mark tokens to keep. Lengths
                are summed over the last dim, so a ``(batch, seq_len)`` mask is
                assumed (TODO confirm for other ranks). Takes priority over
                ``size`` when given.
            size: ``(batch_size, seq_len, ...)``. Used only when ``attn_mask`` is
                None, in which case every sequence is full length.
            device: device for ``indices`` and ``cu_seqlens``. Defaults to the
                accelerator's current device, looked up at call time.

        Returns:
            A populated ``SeqLenInfo``.

        Raises:
            ValueError: if neither ``attn_mask`` nor ``size`` is provided.
        """
        if device is None:
            # Resolve per call. A default argument would be evaluated once at
            # import time and pin whatever device was current back then.
            device = get_accelerator().get_current_device()

        if attn_mask is not None:
            indices = torch.nonzero(attn_mask.flatten(), as_tuple=False).flatten().to(device)
            seqlens = attn_mask.sum(dim=-1, dtype=torch.int32).flatten()
        elif size is not None:
            batch_size, tgt_len = size[0], size[1]
            indices = torch.arange(batch_size * tgt_len, dtype=torch.long, device=device)
            # torch.full instead of the legacy torch.LongTensor(...) constructor,
            # which rejects non-CPU devices.
            seqlens = torch.full((batch_size,), tgt_len, dtype=torch.int32, device=device)
        else:
            raise ValueError("SeqLenInfo.materialize requires either attn_mask or size")

        seqlens_list = seqlens.tolist()
        # Take the max as a plain int, matching the field annotation. This uses
        # the list already materialized above instead of iterating the tensor
        # element by element.
        max_seqlen = max(seqlens_list)
        cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0)).to(device)
        return SeqLenInfo(seqlens_list, indices, max_seqlen, cu_seqlens)
class AttnMaskType(enum.Enum):
    """Kind of attention mask to apply.

    NOTE(review): the meanings below are inferred from the member names only;
    confirm them against the consumers of this enum.
    - ``padding``: presumably masks out padded key positions.
    - ``causal``: presumably an autoregressive (lower-triangular) mask.
    - ``paddedcausal``: presumably the combination of the two.
    """

    padding = 1
    causal = 2
    paddedcausal = 3