Mirror of https://github.com/hpcaitech/ColossalAI.git
[Inference] Adapt Baichuan2-13B TP (#5659)
* adapt to baichuan2 13B
* add baichuan2 13B TP
* update baichuan tp logic
* rm unused code
* Fix TP logic
* fix alibi slopes tp logic
* rm nn.Module
* Polished the code.
* change BAICHUAN_MODEL_NAME_OR_PATH
* Modified the logic for loading Baichuan weights.
* fix typos
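The commit message mentions fixing the alibi-slope TP logic: Baichuan2-13B uses ALiBi position biases, and under tensor parallelism each rank holds only a slice of the attention heads, so it should also keep only the slopes for that slice. The sketch below is illustrative background for that item, not code from this commit; the helper names and the contiguous head-block split are assumptions.

import math

# Hedged sketch: standard ALiBi slopes plus a per-rank split under tensor
# parallelism. Illustrative only; not the code changed by this commit.
def get_alibi_slopes(num_heads: int) -> list:
    # Geometric slopes 2^(-8(i+1)/n), with the usual interleaving trick when
    # the head count is not a power of two (Baichuan2-13B has 40 heads).
    closest = 2 ** math.floor(math.log2(num_heads))
    base = 2.0 ** (-8.0 / closest)
    slopes = [base ** (i + 1) for i in range(closest)]
    if closest != num_heads:
        extra_base = 2.0 ** (-4.0 / closest)
        slopes += [extra_base ** (2 * i + 1) for i in range(num_heads - closest)]
    return slopes

def alibi_slopes_for_rank(num_heads: int, tp_rank: int, tp_size: int) -> list:
    # Each TP rank owns a contiguous block of heads, so it only needs the
    # slopes for that block.
    slopes = get_alibi_slopes(num_heads)
    heads_per_rank = num_heads // tp_size
    return slopes[tp_rank * heads_per_rank : (tp_rank + 1) * heads_per_rank]

print(len(alibi_slopes_for_rank(40, tp_rank=0, tp_size=2)))  # -> 20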
@@ -112,11 +112,23 @@ class InferenceEngine:
             model_policy (Policy): the policy to replace the model
         """
+        casuallm = None
         if isinstance(model_or_path, str):
             try:
                 hf_config = AutoConfig.from_pretrained(model_or_path, trust_remote_code=True)
                 arch = getattr(hf_config, "architectures")[0]
-                model = _supported_models[arch](hf_config)
+                if arch in _supported_models.keys():
+                    casuallm = _supported_models[arch](hf_config)
+                    if isinstance(casuallm, AutoModelForCausalLM):
+                        # NOTE(caidi) It's necessary to add half() here, otherwise baichuan13B will overflow the memory.
+                        model = (
+                            AutoModelForCausalLM.from_pretrained(model_or_path, trust_remote_code=True).half().cuda()
+                        )
+                    else:
+                        model = _supported_models[arch](hf_config)
+                else:
+                    raise ValueError(f"Model {arch} is not supported.")
             except Exception as e:
                 self.logger.error(
                     f"An exception occurred during loading model: {e}, model should be loaded by transformers\n"
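The hunk above reworks the string-path branch: an architecture found in `_supported_models` is instantiated from its config, while the Baichuan2 entry (which resolves to `AutoModelForCausalLM`) is instead loaded with `from_pretrained` in half precision to avoid the memory overflow noted in the comment. A minimal, self-contained sketch of that dispatch follows; `build_inference_model` and the two-entry registry are hypothetical, and the runtime `isinstance` check is simplified to a class identity check so the sketch stays runnable.

from transformers import AutoConfig, AutoModelForCausalLM, LlamaForCausalLM

# Hypothetical stand-in for the engine's `_supported_models` registry.
_supported_models = {
    "LlamaForCausalLM": LlamaForCausalLM,          # built from config; weights loaded later
    "BaichuanForCausalLM": AutoModelForCausalLM,   # remote-code model; loaded via from_pretrained
}

def build_inference_model(model_or_path: str):
    # Sketch of the dispatch in the hunk above (names are illustrative).
    hf_config = AutoConfig.from_pretrained(model_or_path, trust_remote_code=True)
    arch = getattr(hf_config, "architectures")[0]
    if arch not in _supported_models:
        raise ValueError(f"Model {arch} is not supported.")
    model_cls = _supported_models[arch]
    if model_cls is AutoModelForCausalLM:
        # Baichuan2-13B path: pull pretrained weights directly and cast to half
        # precision so the 13B model fits in GPU memory.
        return AutoModelForCausalLM.from_pretrained(model_or_path, trust_remote_code=True).half().cuda()
    # Other supported architectures: instantiate from config only; the engine's
    # checkpoint IO fills in the (sharded) weights afterwards.
    return model_cls(hf_config)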
@@ -164,7 +176,7 @@ class InferenceEngine:
                 f"After the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(self.model)} GB, model's device is: {model.device}"
             )

-        if isinstance(model_or_path, str):
+        if isinstance(model_or_path, str) and not isinstance(casuallm, AutoModelForCausalLM):
             from colossalai.inference.core.plugin import InferCheckpoint_io

             cpt_io = InferCheckpoint_io()
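The second hunk keeps the `InferCheckpoint_io` path only for models that were built from config; a Baichuan2 model that was already fully materialized through `AutoModelForCausalLM.from_pretrained` must not have its weights reloaded. A minimal sketch of that guard, with `load_sharded_weights` as a hypothetical placeholder for the checkpoint-IO step:

def load_sharded_weights(model, checkpoint_path: str) -> None:
    # Hypothetical placeholder for the InferCheckpoint_io-based loading in the engine.
    print(f"loading sharded weights from {checkpoint_path}")

def finalize_weights(model, model_or_path, loaded_via_auto_model: bool):
    # Mirror of the guard in the hunk above: only config-built models still need
    # their weights; the from_pretrained path (Baichuan2-13B) is left untouched.
    if isinstance(model_or_path, str) and not loaded_via_auto_model:
        load_sharded_weights(model, model_or_path)
    return model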