Pass inference model shard configs for module init

Signed-off-by: char-1ee <xingjianli59@gmail.com>
Author: char-1ee
Date: 2024-06-07 08:28:19 +00:00
parent eec77e5702
commit 5f398fc000
11 changed files with 238 additions and 136 deletions
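
The title describes passing a model shard config into module construction at init time. Below is a minimal, hypothetical sketch of that pattern only; the names ModelShardConfig and SelfAttention are illustrative and are not taken from this commit:

from dataclasses import dataclass

import torch
from torch import nn


@dataclass
class ModelShardConfig:
    # Hypothetical fields for illustration; the fields in the real config may differ.
    dtype: torch.dtype = torch.float16
    use_flash_attn2: bool = False


class SelfAttention(nn.Module):
    def __init__(self, hidden_size: int, model_shard_infer_config: ModelShardConfig):
        super().__init__()
        # Backend choices (dtype, attention kernel) arrive via the config at init time,
        # so every shard is constructed with the same settings instead of each module
        # re-deriving them on its own.
        self.dtype = model_shard_infer_config.dtype
        self.use_flash_attn2 = model_shard_infer_config.use_flash_attn2
        self.qkv_proj = nn.Linear(hidden_size, 3 * hidden_size, dtype=self.dtype)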

@@ -1,17 +1,17 @@
"""
Utils for model inference
"""
import math
import os
import re
import math
from pathlib import Path
from typing import Optional, Tuple
import torch
from torch import nn
from colossalai.testing import free_port
from colossalai.logging import get_dist_logger
from colossalai.testing import free_port
logger = get_dist_logger(__name__)
@@ -122,11 +122,11 @@ def find_available_ports(num: int):
def get_alibi_slopes(num_heads: int, device: torch.device) -> torch.Tensor:
    """
    Alibi slopes calculation adapted from https://github.com/huggingface/transformers/blob/v4.36.0/src/transformers/models/bloom/modeling_bloom.py#L57

    Args:
        num_heads (int): The number of attention heads.
        device (torch.device): The device to use.

    Returns:
        torch.Tensor: The Alibi slopes.
    """
@@ -142,20 +142,17 @@ def get_alibi_slopes(num_heads: int, device: torch.device) -> torch.Tensor:
        extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, dtype=torch.int32, device=device)
        slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
    return slopes
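
As a quick usage sketch (not part of the diff), assuming get_alibi_slopes above is in scope: for a power-of-two head count the slopes reduce to base**1, ..., base**num_heads with base = 2 ** (-8 / num_heads), so eight heads give successive powers of 1/2.

# Illustrative only: eight heads give base = 2 ** (-8 / 8) = 0.5.
slopes = get_alibi_slopes(num_heads=8, device=torch.device("cpu"))
print(slopes)
# tensor([0.5000, 0.2500, 0.1250, 0.0625, 0.0312, 0.0156, 0.0078, 0.0039])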
def can_use_flash_attn2(dtype: torch.dtype) -> bool:
    """
    Check flash attention2 availability.
    """
    if dtype not in (torch.float16, torch.bfloat16):
        logger.warning(f"Flash attn2 currently only supports float16 and bfloat16.")
        return False

    try:
        from flash_attn import __version__

        logger.info(f"flash_attn2 version {__version__}.")
        return True
    except ImportError:
        logger.warning(f"flash_attn2 has not been installed yet, we will use triton flash attn instead.")
        return False
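
A minimal caller-side sketch (illustrative, not from this commit) of how this check can drive backend selection; the attn_backend variable is hypothetical:

# Hypothetical dispatch: prefer flash_attn2 when available, else the triton kernel.
dtype = torch.float16
if can_use_flash_attn2(dtype):
    attn_backend = "flash_attn2"  # flash-attn installed and dtype supported
else:
    attn_backend = "triton"  # fallback path mentioned in the warning above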