Mirror of https://github.com/hpcaitech/ColossalAI.git
[Feat] Tensor Model Parallel Support For Inference (#5563)
* tensor parallel support naive source
* [fix] precision, model load and refactor the framework
* add tp unit test
* docstring
* fix do_sample
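For orientation, here is a minimal, hypothetical usage sketch of the tensor-parallel path this commit adds. Only the InferenceEngine(model_or_path, tokenizer, inference_config, ...) signature and the tp_size/dtype config fields are taken from the diff below; the import path of InferenceEngine, the launch_from_torch call, and the generate(prompts=...) call are assumptions about the surrounding ColossalAI API of this era.

# Hypothetical launch script; run with: torchrun --nproc_per_node 2 run_tp_inference.py
import colossalai
from transformers import AutoTokenizer

from colossalai.inference.config import InferenceConfig
from colossalai.inference.core.engine import InferenceEngine

# Initialize torch.distributed; the TP process group is derived from it.
# Note: the launch_from_torch signature varies across ColossalAI versions.
colossalai.launch_from_torch(config={})

model_path = "/path/to/llama-checkpoint"  # hypothetical local HF checkpoint with an index file
tokenizer = AutoTokenizer.from_pretrained(model_path)

# tp_size > 1 enables the ShardFormer-based tensor-parallel sharding introduced in this PR.
inference_config = InferenceConfig(dtype="fp16", tp_size=2)

# The engine now accepts either an nn.Module or a checkpoint path (model_or_path).
engine = InferenceEngine(model_path, tokenizer, inference_config, verbose=True)

output = engine.generate(prompts=["Introduce some landmarks in Beijing."])  # assumed generate() signature
print(output)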
@@ -5,8 +5,17 @@ from typing import Dict, List, Optional, Tuple, Union
 import numpy as np
 import torch
 import torch.nn as nn
-from transformers import GenerationConfig, PreTrainedTokenizer, PreTrainedTokenizerFast
+from torch import distributed as dist
+from transformers import (
+    AutoConfig,
+    AutoModelForCausalLM,
+    GenerationConfig,
+    PreTrainedTokenizer,
+    PreTrainedTokenizerFast,
+)
+from transformers.models.llama.modeling_llama import LlamaForCausalLM

+from colossalai.accelerator import get_accelerator
 from colossalai.cluster import ProcessGroupMesh
 from colossalai.inference.batch_bucket import BatchBucket
 from colossalai.inference.config import InferenceConfig, InputMetaData
@@ -14,6 +23,8 @@ from colossalai.inference.graph_runner import CUDAGraphRunner
 from colossalai.inference.modeling.policy import model_policy_map
 from colossalai.inference.spec import Drafter, GlideInput
 from colossalai.inference.struct import Sequence
+from colossalai.inference.utils import get_model_size, has_index_file
+from colossalai.interface import ModelWrapper
 from colossalai.logging import get_dist_logger
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer import ShardConfig, ShardFormer
@@ -25,10 +36,10 @@ __all__ = ["InferenceEngine"]

 PP_AXIS, TP_AXIS = 0, 1

-_supported_models = [
-    "LlamaForCausalLM",
-    "BaichuanForCausalLM",
-]
+_supported_models = {
+    "LlamaForCausalLM": LlamaForCausalLM,
+    "BaichuanForCausalLM": AutoModelForCausalLM,
+}

 _BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [8 * i for i in range(1, 33)]

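The list of supported model names becomes a mapping from architecture string to a constructor, so the engine can build a model skeleton directly from a Hugging Face config before loading weights (see init_model below). A minimal sketch of that lookup, with a hypothetical checkpoint path; the Baichuan entry routes through AutoModelForCausalLM because that model ships as remote code.

# Sketch of how the mapping is consumed (mirrors init_model below); the path is hypothetical.
from transformers import AutoConfig
from transformers.models.llama.modeling_llama import LlamaForCausalLM

_supported_models = {"LlamaForCausalLM": LlamaForCausalLM}

hf_config = AutoConfig.from_pretrained("/path/to/llama-checkpoint", trust_remote_code=True)
arch = hf_config.architectures[0]           # e.g. "LlamaForCausalLM"
model = _supported_models[arch](hf_config)  # randomly initialized; real weights are loaded later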
@@ -39,7 +50,7 @@ class InferenceEngine:
    InferenceEngine which manages the inference process..

    Args:
-        model (nn.Module): Path or nn.Module of this model.
+        model_or_path (nn.Module or str): Path or nn.Module of this model.
        tokenizer Optional[(Union[PreTrainedTokenizer, PreTrainedTokenizerFast])]: Path of the tokenizer to use.
        inference_config (Optional[InferenceConfig], optional): Store the configuration information related to inference.
        verbose (bool): Determine whether or not to log the generation process.
@@ -48,53 +59,25 @@ class InferenceEngine:

     def __init__(
         self,
-        model: nn.Module,
+        model_or_path: Union[nn.Module, str],
         tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
         inference_config: InferenceConfig,
         verbose: bool = False,
         model_policy: Policy = None,
     ) -> None:
         self.inference_config = inference_config
-        self.model_config = model.config
-        self.model = model
-        self.device = torch.device("cuda")
         self.dtype = inference_config.dtype
-        self.tokenizer = tokenizer
-        self.tokenizer.pad_token = self.tokenizer.eos_token
         self.high_precision = inference_config.high_precision
-        self._verify_args()
-
-        self.generation_config = inference_config.to_generation_config(self.model_config)
-        model.eval()
-        model = model.to(self.dtype)
-        model = model.to(self.device)
-
-        # Model and relatable attrs of speculative decoding will be set by `enable_spec_dec`
-        self.use_spec_dec = False
-        self.drafter_model = None
-        self.drafter = None
-        self.use_glide = False
-        self.n_spec_tokens = self.inference_config.max_n_spec_tokens
-
-        if model_policy is None:
-            if self.inference_config.pad_input:
-                model_type = "padding_" + self.model_config.model_type
-            else:
-                model_type = "nopadding_" + self.model_config.model_type
-            model_policy = model_policy_map[model_type]()
-
-        pg_mesh = ProcessGroupMesh(inference_config.pp_size, inference_config.tp_size)
-
-        self.model = self._shardformer(
-            model,
-            model_policy,
-            None,
-            pg_mesh.get_group_along_axis(TP_AXIS) if inference_config.pp_size * inference_config.tp_size > 1 else None,
-        )

         self.verbose = verbose
-        if verbose:
-            self.logger = get_dist_logger(__name__)
+        self.logger = get_dist_logger(__name__)
+
+        self.init_model(model_or_path, model_policy)
+
+        self.generation_config = inference_config.to_generation_config(self.model_config)
+
+        self.tokenizer = tokenizer
+        self.tokenizer.pad_token = self.tokenizer.eos_token

         self.request_handler = RequestHandler(self.inference_config, self.model_config)
         self.k_cache, self.v_cache = self.request_handler.get_kvcache()
@@ -111,6 +94,91 @@ class InferenceEngine:

             self.capture_model(self.k_cache, self.v_cache)

+        # Model and relatable attrs of speculative decoding will be set by `enable_spec_dec`
+        self.use_spec_dec = False
+        self.drafter_model = None
+        self.drafter = None
+        self.use_glide = False
+        self.n_spec_tokens = self.inference_config.max_n_spec_tokens
+
+        self._verify_args()
+
+    def init_model(self, model_or_path: Union[nn.Module, str], model_policy: Policy = None):
+        """
+        Shard model or/and Load weight
+
+        Args:
+            model_or_path Union[nn.Module, str]: path to the checkpoint or model of transformer format.
+            model_policy (Policy): the policy to replace the model
+        """
+
+        if isinstance(model_or_path, str):
+            try:
+                hf_config = AutoConfig.from_pretrained(model_or_path, trust_remote_code=True)
+                arch = getattr(hf_config, "architectures")[0]
+                model = _supported_models[arch](hf_config)
+            except Exception as e:
+                self.logger.error(
+                    f"An exception occurred during loading model: {e}, model should be loaded by transformers\n"
+                )
+        else:
+            model = model_or_path
+
+        self.model_config = model.config
+
+        torch.cuda.empty_cache()
+        init_gpu_memory = torch.cuda.mem_get_info()[0]
+
+        self.device = get_accelerator().get_current_device()
+        if self.verbose:
+            self.logger.info(f"the device is {self.device}")
+
+        model = model.to(self.dtype).eval()
+
+        if self.verbose:
+            self.logger.info(
+                f"Before the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(model)} GB, model's device is: {model.device}"
+            )
+
+        if model_policy is None:
+            if self.inference_config.pad_input:
+                model_type = "padding_" + self.model_config.model_type
+            else:
+                model_type = "nopadding_" + self.model_config.model_type
+            model_policy = model_policy_map[model_type]()
+
+        pg_mesh = ProcessGroupMesh(self.inference_config.pp_size, self.inference_config.tp_size)
+        tp_group = pg_mesh.get_group_along_axis(TP_AXIS)
+
+        self.model = self._shardformer(
+            model,
+            model_policy,
+            None,
+            tp_group=tp_group,
+        )
+
+        self.model = ModelWrapper(model).to(self.device)
+
+        if self.verbose:
+            self.logger.info(
+                f"After the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(self.model)} GB, model's device is: {model.device}"
+            )
+
+        if isinstance(model_or_path, str):
+            from colossalai.inference.core.plugin import InferCheckpoint_io
+
+            cpt_io = InferCheckpoint_io()
+            if_has_index_file, model_index_file = has_index_file(model_or_path)
+            assert if_has_index_file, "the model path is invalid"
+            cpt_io.load_model(self.model, model_index_file)
+
+        free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info()
+        peak_memory = init_gpu_memory - free_gpu_memory
+        if self.verbose:
+            self.logger.info(
+                f"Rank [{dist.get_rank()}], Model Weight Max Occupy {peak_memory / (1024 ** 3)} GB, Model size: {get_model_size(self.model)} GB"
+            )
+
     @torch.inference_mode()
     def capture_model(self, k_cache: List[torch.Tensor], v_cache: List[torch.Tensor]):
         assert self.use_cuda_graph, "please turn on the cuda graph"
@@ -194,8 +262,11 @@ class InferenceEngine:
             raise TypeError(
                 f"the tokenizer type must be PreTrainedTokenizer or PreTrainedTokenizerFast, but got {type(self.tokenizer)}"
             )
-        if self.model.__class__.__name__ not in _supported_models:
-            raise ValueError(f"Model {self.model.__class__.__name__} is not supported.")
+        if isinstance(self.model, ModelWrapper):
+            model = self.model.module
+        assert (
+            model.__class__.__name__ in _supported_models.keys()
+        ), f"Model {self.model.__class__.__name__} is not supported."

     def _shardformer(
         self,
colossalai/inference/core/plugin.py (new file, +140 lines)
@@ -0,0 +1,140 @@
import logging
import os
from functools import reduce
from pathlib import Path
from typing import Optional

import torch

from colossalai.checkpoint_io.general_checkpoint_io import GeneralCheckpointIO
from colossalai.checkpoint_io.index_file import CheckpointIndexFile
from colossalai.checkpoint_io.utils import is_safetensors_available, load_shard_state_dict, load_state_dict_into_model
from colossalai.cluster import DistCoordinator
from colossalai.interface import ModelWrapper

try:
    from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX
except ImportError:
    _EXTRA_STATE_KEY_SUFFIX = "_extra_state"


class InferCheckpoint_io(GeneralCheckpointIO):
    """
    This class is for inference model loading, most codes are copied from colossalai.checkpoint_io.hybrid_parallel_checkpoint_io.HybridParallelCheckpointIO.
    Origin HybridParallelCheckpointIO contains some codes about MixPrecision-Training, so we remove them and build a relatively clean class specifically for Inference.
    """

    def __init__(
        self,
        verbose: bool = True,
    ) -> None:
        super().__init__()
        self.verbose = verbose
        self.coordinator = DistCoordinator()

    def load_sharded_model(self, model: ModelWrapper, checkpoint_index_file: Path, strict: bool = False):
        """
        Load sharded model with the given path to index file of checkpoint folder.

        Args:
            model (nn.Module): The model to be loaded.
            checkpoint_index_file (str): Path to the index file of checkpointing folder.
            strict (bool, optional): For name matching during loading state_dict. Defaults to False.
                This argument should be manually set to False since params on same device might be stored in different files.
        """
        assert isinstance(model, ModelWrapper), "Please boost the model before loading!"
        model = model.unwrap()

        # Check whether the checkpoint uses safetensors.
        use_safetensors = False
        if "safetensors" in checkpoint_index_file.name:
            use_safetensors = True

        if use_safetensors and not is_safetensors_available():
            raise ImportError("`safe_serialization` requires the `safetensors` library: `pip install safetensors`.")

        # Read checkpoint index file.
        ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file)
        ckpt_root_path = ckpt_index_file.root_path
        weight_map = ckpt_index_file.weight_map
        strict = False

        # Load params & buffers to model.
        # Keep a record of loaded files so that file will not be repeatedly loaded.
        loaded_file = set()

        missing_keys = []
        missing_file_keys = []

        def _load(name: str):
            if name not in weight_map:
                missing_file_keys.append(name)
                return
            filename = weight_map[name]

            # If this param/buffer has been loaded before, directly return.
            if filename in loaded_file:
                return

            file_path = os.path.join(ckpt_root_path, filename)
            state_dict = load_shard_state_dict(Path(file_path), use_safetensors)

            load_state_dict_into_model(
                model, state_dict, missing_keys=missing_keys, strict=strict, load_sub_module=True
            )
            loaded_file.add(filename)

        # Load parameters.
        for name, _ in model.named_parameters():
            _load(name)

        # Load buffers.
        non_persistent_buffers = set()
        for n, m in model.named_modules():
            non_persistent_buffers |= set(".".join((n, b)) for b in m._non_persistent_buffers_set)
        for name, buf in model.named_buffers():
            if buf is not None and name not in non_persistent_buffers:
                _load(name)

        # Load extra states.
        extra_state_key = _EXTRA_STATE_KEY_SUFFIX
        if (
            getattr(model.__class__, "get_extra_state", torch.nn.Module.get_extra_state)
            is not torch.nn.Module.get_extra_state
        ):
            _load(extra_state_key)

        if self.verbose and self.coordinator.is_master():
            logging.info(f"The model has been successfully loaded from sharded checkpoint: {ckpt_root_path}.")

        if len(missing_keys) == 0:
            raise RuntimeError(
                "No weigth is loaded into the model. Please check the checkpoint files and the model structure."
            )

        remain_keys = reduce(lambda a, b: a & b, map(set, missing_keys))
        remain_keys = remain_keys.union(set(missing_file_keys))
        if len(remain_keys) > 0:
            if strict:
                error_msgs = "Missing key(s) in state_dict: {}. ".format(
                    ", ".join('"{}"'.format(k) for k in missing_keys)
                )
                raise RuntimeError(
                    "Error(s) in loading state_dict for {}:\n\t{}".format(
                        self.__class__.__name__, "\n\t".join(error_msgs)
                    )
                )
            else:
                if self.coordinator.is_master():
                    logging.info(f"The following keys are not loaded from checkpoint: {remain_keys}")

    def save_sharded_model(
        self,
        model: ModelWrapper,
        checkpoint: str,
        gather_dtensor: bool = True,
        prefix: Optional[str] = None,
        size_per_shard: int = 1024,
        use_safetensors: bool = False,
    ) -> None:
        return NotImplementedError
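In this commit the class is driven from InferenceEngine.init_model (see the engine diff above); condensed, the load path looks like the sketch below. The sharded_model, device, and checkpoint path are placeholders, and load_model is the checkpoint-IO entry point, which is expected to route to load_sharded_model here because it is handed an index file.

# Condensed from InferenceEngine.init_model above; names marked as placeholders are not defined here.
from colossalai.inference.core.plugin import InferCheckpoint_io
from colossalai.inference.utils import has_index_file
from colossalai.interface import ModelWrapper

model = ModelWrapper(sharded_model).to(device)  # placeholder: model already sharded by ShardFormer

cpt_io = InferCheckpoint_io()
if_has_index_file, model_index_file = has_index_file("/path/to/llama-checkpoint")  # placeholder path
assert if_has_index_file, "the model path is invalid"  # only indexed (sharded) checkpoints are handled
cpt_io.load_model(model, model_index_file)  # walks the weight map and loads the needed shard files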
@@ -140,7 +140,7 @@ class RequestHandler:

        fd_inter_tensor.initialize(
            max_batch_size=max_n_tokens,
-            num_attn_heads=model_config.num_attention_heads,
+            num_attn_heads=model_config.num_attention_heads // inference_config.tp_size,
            kv_max_split_num=kv_max_split_num,
            head_dim=head_dim,
            dtype=self.dtype,
@@ -150,7 +150,7 @@ class RequestHandler:
        # TODO In the continuous batching scenario, the batch size may be greater than max_batch_size,
        # which may cause bugs and this issue should be fixed later.
        self.running_bb = BatchBucket(
-            num_heads=model_config.num_attention_heads,
+            num_heads=model_config.num_attention_heads // inference_config.tp_size,
            head_dim=head_dim,
            max_batch_size=self.max_batch_size,
            max_length=inference_config.max_input_len + inference_config.max_output_len,
@@ -161,7 +161,7 @@ class RequestHandler:
            device=device,
        )
        self.prefill_bb = BatchBucket(
-            num_heads=model_config.num_attention_heads,
+            num_heads=model_config.num_attention_heads // inference_config.tp_size,
            head_dim=head_dim,
            max_batch_size=self.max_batch_size,
            max_length=inference_config.max_input_len + inference_config.max_output_len,
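These three RequestHandler changes all encode the same fact: after tensor-parallel sharding, each rank only holds num_attention_heads // tp_size attention heads, so the per-rank KV-cache and batch-bucket metadata must be sized for that local slice rather than for the full model. A small illustrative calculation (the Llama-7B-like numbers are hypothetical):

# Why the head count is divided by tp_size: per-rank buffers only cover local heads.
num_attention_heads = 32   # hypothetical full-model head count
head_dim = 128             # hypothetical head dimension
tp_size = 2

heads_per_rank = num_attention_heads // tp_size
print(heads_per_rank)  # 16 -> each rank allocates KV cache for 16 heads

# Per-token KV width on one rank (keys + values), in elements:
print(2 * heads_per_rank * head_dim)  # 4096, versus 8192 without tensor parallelism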