mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-05 19:13:01 +00:00
[Feat]Tensor Model Parallel Support For Inference (#5563)
* tensor parallel support naive source * [fix]precision, model load and refactor the framework * add tp unit test * docstring * fix do_sample
This commit is contained in:
@@ -2,8 +2,12 @@
|
||||
Utils for model inference
|
||||
"""
|
||||
import os
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
|
||||
def init_to_get_rotary(self, base=10000, use_elem=False):
|
||||
@@ -49,3 +53,52 @@ def init_to_get_rotary(self, base=10000, use_elem=False):
|
||||
|
||||
self._cos_cached = torch.cos(freqs).to(self.dtype).cuda()
|
||||
self._sin_cached = torch.sin(freqs).to(self.dtype).cuda()
|
||||
|
||||
|
||||
def has_index_file(checkpoint_path: str) -> Tuple[bool, Optional[Path]]:
|
||||
"""
|
||||
Check whether the checkpoint has an index file.
|
||||
|
||||
Args:
|
||||
checkpoint_path (str): path to the checkpoint.
|
||||
|
||||
Returns:
|
||||
Tuple[bool, Optional[Path]]: a tuple of (has_index_file, index_file_path)
|
||||
"""
|
||||
checkpoint_path = Path(checkpoint_path)
|
||||
if checkpoint_path.is_file():
|
||||
# check if it is .index.json
|
||||
reg = re.compile("(.*?).index((\..*)?).json")
|
||||
if reg.fullmatch(checkpoint_path.name) is not None:
|
||||
return True, checkpoint_path
|
||||
else:
|
||||
return False, None
|
||||
elif checkpoint_path.is_dir():
|
||||
index_files = list(checkpoint_path.glob("*.index.*json"))
|
||||
|
||||
for index_file in index_files:
|
||||
if "safetensors" in index_file.__str__():
|
||||
return True, index_file.__str__() # return the safetensors file first
|
||||
|
||||
if len(index_files) == 1:
|
||||
return True, index_files[0]
|
||||
else:
|
||||
assert (
|
||||
len(index_files) == 1
|
||||
), f"Expected to find one .index.json file in {checkpoint_path}, but found {len(index_files)}"
|
||||
return False, None
|
||||
else:
|
||||
raise RuntimeError(f"Invalid checkpoint path {checkpoint_path}. Expected a file or a directory.")
|
||||
|
||||
|
||||
def get_model_size(model: nn.Module):
|
||||
"""Calculates the total size of the model weights (including biases) in bytes.
|
||||
Args:
|
||||
model: The PyTorch model to analyze.
|
||||
Returns:
|
||||
The total size of the model weights in bytes.
|
||||
"""
|
||||
total_size = 0
|
||||
for key, param in model.named_parameters():
|
||||
total_size += param.element_size() * param.numel()
|
||||
return total_size / (1024**3)
|
||||
|
Reference in New Issue
Block a user