Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-06 03:20:52 +00:00)
Pass inference model shard configs for module init
Signed-off-by: char-1ee <xingjianli59@gmail.com>
@@ -1,17 +1,17 @@
"""
Utils for model inference
"""
import math
import os
import re
import math
from pathlib import Path
from typing import Optional, Tuple

import torch
from torch import nn

from colossalai.testing import free_port
from colossalai.logging import get_dist_logger
from colossalai.testing import free_port

logger = get_dist_logger(__name__)

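As an aside, a minimal sketch of how a port-finding helper such as the find_available_ports referenced in the next hunk header could be built on the colossalai.testing.free_port import kept above. The helper name and return type here are illustrative assumptions, not the repository's implementation.

from typing import List

from colossalai.testing import free_port

def find_free_ports(num: int) -> List[int]:
    # Illustrative helper: free_port() asks the OS for a currently unused
    # port number; collect `num` of them, e.g. for distributed process groups.
    return [free_port() for _ in range(num)]

ports = find_free_ports(2)
print(ports)  # e.g. [29500, 29501]
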
@@ -122,11 +122,11 @@ def find_available_ports(num: int):
def get_alibi_slopes(num_heads: int, device: torch.device) -> torch.Tensor:
    """
    Alibi slopes calculation adapted from https://github.com/huggingface/transformers/blob/v4.36.0/src/transformers/models/bloom/modeling_bloom.py#L57

    Args:
        num_heads (int): The number of attention heads.
        device (torch.device): The device to use.

    Returns:
        torch.Tensor: The Alibi slopes.
    """
@@ -142,20 +142,17 @@ def get_alibi_slopes(num_heads: int, device: torch.device) -> torch.Tensor:
        extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, dtype=torch.int32, device=device)
        slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
    return slopes


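For context, a minimal usage sketch of the slope computation above, assuming the standard ALiBi recipe from the Hugging Face Bloom implementation cited in the docstring; the import path is an assumption based on this file's "Utils for model inference" docstring and is not shown in the diff.

import torch

from colossalai.inference.utils import get_alibi_slopes  # module path assumed

slopes = get_alibi_slopes(num_heads=12, device=torch.device("cpu"))
print(slopes.shape)  # torch.Size([12]): one bias slope per attention head
# Under the cited recipe, the first 8 heads (closest power of two) get the
# geometric sequence 2^-1, 2^-2, ..., 2^-8, and the remaining 4 heads take
# the extra_base/extra_powers branch above: 2^-0.5, 2^-1.5, 2^-2.5, 2^-3.5.
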
def can_use_flash_attn2(dtype: torch.dtype) -> bool:
    """
    Check flash attention2 availability.
    """
    if dtype not in (torch.float16, torch.bfloat16):
        logger.warning(f"Flash attn2 currently only supports float16 and bfloat16.")
        return False

    try:
        from flash_attn import __version__
        logger.info(f"flash_attn2 version {__version__}.")
        return True
    except ImportError:
        logger.warning(f"flash_attn2 has not been installed yet, we will use triton flash attn instead.")
        return False
    return False

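A short dispatch sketch built on can_use_flash_attn2: use flash-attn 2 when the dtype and installation allow it, and otherwise fall back to a Triton kernel, as the warning above suggests. The import path and backend labels are illustrative assumptions, not the repository's API.

import torch

from colossalai.inference.utils import can_use_flash_attn2  # module path assumed

def select_attn_backend(dtype: torch.dtype) -> str:
    # Illustrative label-returning dispatcher; real code would pick a kernel.
    if can_use_flash_attn2(dtype):
        return "flash_attn2"
    return "triton_flash_attn"

print(select_attn_backend(torch.float16))  # "flash_attn2" if flash-attn is installed
print(select_attn_backend(torch.float32))  # always the fallback: fp32 is unsupported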