[inference] add reference and fix some bugs (#4937)
* add reference and fix some bugs

* update gptq init

---------

Co-authored-by: Xu Kai <xukai16@foxamil.com>
@@ -132,6 +132,7 @@ class BaseSmoothForCausalLM(nn.Module, PushToHubMixin):
             mean_scale = np.mean([v["input"] for v in act_dict.values()])
             pbar.set_description(f"Mean input scale: {mean_scale:.2f}")
 
+    # Adapted from https://github.com/mit-han-lab/smoothquant/blob/main/smoothquant/calibration.py
     def get_act_scales(self, model, tokenizer, dataset, num_samples=512, seq_len=512):
         model.eval()
         device = next(model.parameters()).device
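For context on the calibration step referenced above: the SmoothQuant recipe records, for every linear layer, the per-channel maximum absolute value seen at its input over a small calibration set. A minimal sketch of that idea (illustrative only; the function and variable names below are not taken from the repository):

    import functools
    import torch
    import torch.nn as nn

    def collect_act_scales(model, calib_batches):
        # Record the per-channel max |activation| at the input of every nn.Linear;
        # this is the statistic the smoothing step consumes later.
        act_scales = {}

        def stat_hook(module, inputs, output, name):
            x = inputs[0].detach().abs().reshape(-1, inputs[0].shape[-1])
            cur = x.max(dim=0).values.float().cpu()
            act_scales[name] = torch.maximum(act_scales[name], cur) if name in act_scales else cur

        handles = [
            m.register_forward_hook(functools.partial(stat_hook, name=n))
            for n, m in model.named_modules()
            if isinstance(m, nn.Linear)
        ]
        with torch.no_grad():
            for batch in calib_batches:
                model(batch)
        for h in handles:
            h.remove()
        return act_scales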
@@ -163,6 +164,7 @@ class BaseSmoothForCausalLM(nn.Module, PushToHubMixin):
 
         return act_scales
 
+    # Adapted from https://github.com/mit-han-lab/smoothquant/blob/main/smoothquant/smooth.py
     @torch.no_grad()
     def smooth_ln_fcs(self, ln, fcs, act_scales, alpha=0.5):
         if not isinstance(fcs, list):
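The smoothing rule that smooth_ln_fcs adapts from the linked smoothquant/smooth.py balances activation and weight ranges per input channel: s_j = act_max_j**alpha / weight_max_j**(1 - alpha), with the norm divided by s and the following linears multiplied by s so the composition is unchanged. A rough sketch under that assumption (names are illustrative):

    import torch

    @torch.no_grad()
    def smooth_ln_fcs_sketch(ln, fcs, act_scales, alpha=0.5):
        # Illustrative restatement of the SmoothQuant balancing step.
        if not isinstance(fcs, list):
            fcs = [fcs]
        # Per input channel, the largest weight magnitude across all following linears.
        weight_max = torch.stack([fc.weight.abs().max(dim=0).values for fc in fcs]).max(dim=0).values
        scales = (act_scales.to(weight_max.device).pow(alpha) / weight_max.pow(1 - alpha)).clamp(min=1e-5)
        # Fold 1/s into the norm and s into the linears: the product is unchanged,
        # but the activations entering the linears become easier to quantize.
        ln.weight.div_(scales)
        if getattr(ln, "bias", None) is not None:
            ln.bias.div_(scales)
        for fc in fcs:
            fc.weight.mul_(scales.view(1, -1))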
@@ -189,6 +191,7 @@ class BaseSmoothForCausalLM(nn.Module, PushToHubMixin):
     def create_quantized_model(model):
         raise NotImplementedError("Not implement create_quantized_model method")
 
+    # Adapted from AutoGPTQ: https://github.com/PanQiWei/AutoGPTQ/blob/main/auto_gptq/modeling/_base.py
     def save_quantized(
         self,
         save_dir: str,
@@ -249,6 +252,7 @@ class BaseSmoothForCausalLM(nn.Module, PushToHubMixin):
 
         self.model.config.save_pretrained(save_dir)
 
+    # Adapted from AutoGPTQ: https://github.com/PanQiWei/AutoGPTQ/blob/main/auto_gptq/modeling/_base.py
     def save_pretrained(
         self,
         save_dir: str,
@@ -260,6 +264,7 @@ class BaseSmoothForCausalLM(nn.Module, PushToHubMixin):
         warnings.warn("you are using save_pretrained, which will re-direct to save_quantized.")
         self.save_quantized(save_dir, use_safetensors, safetensors_metadata)
 
+    # Adapted from AutoGPTQ: https://github.com/PanQiWei/AutoGPTQ/blob/main/auto_gptq/modeling/_base.py
     @classmethod
     def from_pretrained(
         cls,
@@ -354,6 +359,7 @@ class BaseSmoothForCausalLM(nn.Module, PushToHubMixin):
 
         return cls(model, False)
 
+    # Adapted from AutoGPTQ: https://github.com/PanQiWei/AutoGPTQ/blob/main/auto_gptq/modeling/_base.py
     @classmethod
     def from_quantized(
         cls,
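The save_quantized / save_pretrained / from_pretrained / from_quantized quartet mirrors AutoGPTQ's BaseGPTQForCausalLM. As a hedged illustration of what the save path typically amounts to (the helper name and arguments below are illustrative, not the class's actual signature):

    import os
    import torch

    def save_quantized_state(model, save_dir, model_base_name, use_safetensors=True, safetensors_metadata=None):
        # Dump the (already quantized) state dict either as .safetensors or as a
        # plain torch checkpoint, then write the HF config next to it.
        os.makedirs(save_dir, exist_ok=True)
        state_dict = {k: v.detach().cpu().clone().contiguous() for k, v in model.state_dict().items()}
        if use_safetensors:
            from safetensors.torch import save_file
            save_file(state_dict, os.path.join(save_dir, f"{model_base_name}.safetensors"), metadata=safetensors_metadata)
        else:
            torch.save(state_dict, os.path.join(save_dir, f"{model_base_name}.bin"))
        model.config.save_pretrained(save_dir)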
@@ -62,6 +62,7 @@ class W8A8BFP32O32LinearSiLU(torch.nn.Module):
         return int8_module
 
 
+# modified from torch-int: https://github.com/Guangxuan-Xiao/torch-int/blob/main/torch_int/nn/linear.py
 class W8A8B8O8Linear(torch.nn.Module):
     # For qkv_proj
     def __init__(self, in_features, out_features, alpha=1.0, beta=1.0):
@@ -117,6 +118,7 @@ class W8A8B8O8Linear(torch.nn.Module):
         return int8_module
 
 
+# modified from torch-int: https://github.com/Guangxuan-Xiao/torch-int/blob/main/torch_int/nn/linear.py
 class W8A8BFP32OFP32Linear(torch.nn.Module):
     # For fc2 and out_proj
     def __init__(self, in_features, out_features, alpha=1.0, beta=1.0):
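These W8A8 classes follow torch-int's pattern: weights are stored as symmetric per-tensor int8, and the input/weight (and, when the output is also int8, output) scales are folded into a single multiplier applied after the int8 GEMM. A simplified conversion sketch, not the torch-int API itself:

    import torch

    @torch.no_grad()
    def to_w8a8(fc, input_scale, output_scale=None):
        # Symmetric per-tensor int8 quantization of an nn.Linear weight.
        weight_scale = fc.weight.abs().max() / 127
        int8_weight = torch.clamp(torch.round(fc.weight / weight_scale), -128, 127).to(torch.int8)
        # alpha rescales the int32 GEMM output back to real magnitude; dividing by
        # output_scale additionally requantizes the output to int8 (the B8O8 case).
        alpha = float(input_scale) * float(weight_scale)
        if output_scale is not None:
            alpha = alpha / float(output_scale)
        return int8_weight, alpha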
@@ -419,6 +419,7 @@ class LlamaApplyRotary(nn.Module):
         return x_embed
 
 
+# Adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
 def llama_decoder_layer_forward(
     self,
     hidden_states: torch.Tensor,
@@ -559,6 +560,7 @@ def init_to_get_rotary(config, base=10000, use_elem=False):
     return _cos_cached, _sin_cached
 
 
+# Adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
 @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
 def llama_model_forward(
     self,
@@ -729,6 +731,7 @@ class SmoothLlamaForCausalLM(BaseSmoothForCausalLM):
     def __init__(self, model: PreTrainedModel, quantized: bool = False):
         super().__init__(model, quantized)
 
+    # Adapted from https://github.com/mit-han-lab/smoothquant/blob/main/smoothquant/calibration.py
     def get_act_dict(
         self,
         tokenizer,
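init_to_get_rotary, visible as context in the second hunk above, precomputes the rotary-embedding tables once so each decoding step can simply index into them. The standard computation it is built around looks like this (a sketch; the real function's config handling and scaling options are omitted):

    import torch

    def build_rotary_cache(head_dim, max_positions, base=10000.0, dtype=torch.float16):
        # Inverse frequency per even dimension, outer product with positions,
        # then cached cos/sin tables indexed by position at inference time.
        inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
        t = torch.arange(max_positions).float()
        freqs = torch.outer(t, inv_freq)  # (max_positions, head_dim // 2)
        return torch.cos(freqs).to(dtype), torch.sin(freqs).to(dtype)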
@@ -21,6 +21,8 @@ _supported_models = [
     "BloomForCausalLM",
     "ChatGLMModel",
     "ChatGLMForConditionalGeneration",
+    "LlamaGPTQForCausalLM",
+    "BloomGPTQForCausalLM",
 ]
 
 
@@ -213,11 +215,14 @@ class TPInferEngine:
         ), "Discrepancy between the tp size of TPInferEngine and the tp size of shard config"
         model_name = model.__class__.__name__
         assert model_name in self.supported_models, f"Unsupported model cls {model_name} for TP inference."
+
+        model = model.model if self.shard_config.inference_gptq else model
+
         policy = get_autopolicy(model, inference_only=True)
         self.model, _ = shardformer.optimize(model, policy)
 
         if self.shard_config.inference_gptq:
-            self._post_init_gptq_buffer(model)
+            self._post_init_gptq_buffer(self.model)
 
         self.model = self.model.cuda()
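The last hunk is the "update gptq init" part of the commit: for GPTQ checkpoints the engine now unwraps the *GPTQForCausalLM holder (model.model) before sharding, and _post_init_gptq_buffer is called on the sharded self.model rather than the original module. Restated as a free function with a hypothetical signature (get_autopolicy and shardformer are the objects already used in the hunk, passed in here only to keep the sketch self-contained):

    def shard_gptq_model_for_inference(engine, model, shard_config, shardformer, get_autopolicy):
        # Unwrap the GPTQ wrapper first so the policy sees the underlying transformer.
        if shard_config.inference_gptq:
            model = model.model
        policy = get_autopolicy(model, inference_only=True)
        engine.model, _ = shardformer.optimize(model, policy)
        # Initialize GPTQ kernel buffers on the *sharded* model, not the original one.
        if shard_config.inference_gptq:
            engine._post_init_gptq_buffer(engine.model)
        engine.model = engine.model.cuda()
        return engine.model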