Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-06 03:20:52 +00:00
Refactor modeling by adding attention backend
Signed-off-by: char-1ee <xingjianli59@gmail.com>
@@ -3,6 +3,7 @@ Utils for model inference
"""
import os
import re
import math
from pathlib import Path
from typing import Optional, Tuple
@@ -10,6 +11,9 @@ import torch
from torch import nn

from colossalai.testing import free_port
from colossalai.logging import get_dist_logger

logger = get_dist_logger(__name__)


def init_to_get_rotary(self, base=10000, use_elem=False):
@@ -113,3 +117,45 @@ def find_available_ports(num: int):
        print(f"An OS error occurred: {e}")
        raise RuntimeError("Error finding available ports")
    return free_ports


def get_alibi_slopes(num_heads: int, device: torch.device) -> torch.Tensor:
    """
    Alibi slopes calculation adapted from https://github.com/huggingface/transformers/blob/v4.36.0/src/transformers/models/bloom/modeling_bloom.py#L57

    Args:
        num_heads (int): The number of attention heads.
        device (torch.device): The device to use.

    Returns:
        torch.Tensor: The Alibi slopes.
    """
    closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
    base = torch.tensor(2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), dtype=torch.float32, device=device)
    powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32, device=device)
    slopes = torch.pow(base, powers)
    if closest_power_of_2 != num_heads:
        extra_base = torch.tensor(
            2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), dtype=torch.float32, device=device
        )
        num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2)
        extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, dtype=torch.int32, device=device)
        slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
    return slopes
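# Editor's note: an illustrative sanity check, not part of the diff above.
# When num_heads is a power of two, only the first branch runs and the slopes
# are 2 ** (-8 * i / num_heads) for i = 1..num_heads; for num_heads = 8 this is
# the geometric sequence 2**-1, 2**-2, ..., 2**-8. The `if` branch only fires
# when num_heads is not a power of two. Minimal sketch, assuming
# get_alibi_slopes and torch are in scope as in the module patched above:
_slopes = get_alibi_slopes(num_heads=8, device=torch.device("cpu"))
assert _slopes.shape == (8,)
assert torch.allclose(_slopes, 0.5 ** torch.arange(1, 9, dtype=torch.float32))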
def can_use_flash_attn2(dtype: torch.dtype) -> bool:
    """
    Check FlashAttention-2 availability.
    """
    if dtype not in (torch.float16, torch.bfloat16):
        logger.warning("Flash attn2 currently only supports float16 and bfloat16.")
        return False

    try:
        from flash_attn import __version__

        logger.info(f"flash_attn2 version {__version__}.")
        return True
    except ImportError:
        logger.warning("flash_attn2 has not been installed yet; using the Triton flash attention kernel instead.")
        return False
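For context, a minimal sketch of how a dtype-gated check like can_use_flash_attn2 is typically consumed when picking an attention backend. The selector function and backend labels below are hypothetical and not part of this commit; the sketch only assumes can_use_flash_attn2 is importable from the module patched above.

import torch

def select_attn_backend(dtype: torch.dtype) -> str:
    # Prefer FlashAttention-2 only for fp16/bf16 and only if flash_attn imports;
    # otherwise fall back to a Triton kernel path, as the warning above suggests.
    return "flash_attn2" if can_use_flash_attn2(dtype) else "triton"

# float32 inputs always fall back, matching the dtype guard in can_use_flash_attn2.
assert select_attn_backend(torch.float32) == "triton"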