Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-16 06:30:41 +00:00)
[refactor] refactor gptq and smoothquant llama (#5012)
* refactor gptq and smoothquant llama
* fix import error
* fix linear import torch-int
* fix smoothquant llama import error
* fix import accelerate error
* fix bug
* fix import smooth cuda
* fix smoothcuda
examples/inference/hybrid_gptq_llama.py (new file, 79 lines)
@@ -0,0 +1,79 @@
import argparse
import os

import torch
import torch.distributed as dist
from auto_gptq import AutoGPTQForCausalLM

import colossalai
from colossalai.inference import CaiInferEngine, LlamaModelInferPolicy
from colossalai.logging import disable_existing_loggers
from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn

os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"


def run_llama_test(args):
    quantized_model_dir = args.quantized_path
    max_batch_size = args.max_batch_size
    max_input_len = args.max_input_len
    max_output_len = args.max_output_len
    micro_batch_size = args.micro_batch_size

    # load the GPTQ-quantized model onto the current GPU
    model = AutoGPTQForCausalLM.from_quantized(
        quantized_model_dir, inject_fused_attention=False, device=torch.cuda.current_device()
    )

    engine = CaiInferEngine(
        tp_size=args.tp_size,
        pp_size=args.pp_size,
        model=model,
        model_policy=LlamaModelInferPolicy(),
        max_batch_size=max_batch_size,
        max_input_len=max_input_len,
        max_output_len=max_output_len,
        micro_batch_size=micro_batch_size,
        quant="gptq",
    )

    def data_gen():
        input_ids = torch.tensor([[15496, 11, 616, 3290, 318, 13779, 318, 13779]], dtype=torch.int64)
        attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64)
        return dict(input_ids=input_ids, attention_mask=attention_mask)

    inputs = data_gen()
    # replicate the single sample along the batch dimension and move it to the GPU
    for k, v in inputs.items():
        if torch.is_tensor(v) or "Tensor" in v.__class__.__name__:
            new_shape = [1] * v.dim()
            new_shape[0] = 16
            inputs[k] = v.to("cuda").repeat(*new_shape)

    output = engine.inference(inputs)
    if dist.get_rank() == 0:
        assert len(output[0]) == max_output_len, f"{len(output)}, {max_output_len}"


def check_llama(rank, world_size, port, args):
    disable_existing_loggers()
    colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
    run_llama_test(args)


@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_gptq_llama(args):
    spawn(check_llama, args.tp_size * args.pp_size, args=args)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("-q", "--quantized_path", type=str, help="Model path", required=True)
    parser.add_argument("--tp_size", type=int, default=2, help="Tensor parallel size")
    parser.add_argument("--pp_size", type=int, default=2, help="Pipeline parallel size")
    parser.add_argument("--max_batch_size", type=int, default=4, help="Maximum batch size")
    parser.add_argument("--micro_batch_size", type=int, default=4, help="Micro batch size")
    parser.add_argument("--max_input_len", type=int, default=32, help="Maximum input length")
    parser.add_argument("--max_output_len", type=int, default=32, help="Maximum output length")
    args = parser.parse_args()

    test_gptq_llama(args)
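A minimal way to launch this example, assuming a locally available GPTQ-quantized Llama checkpoint and at least tp_size * pp_size visible GPUs (the checkpoint path below is a placeholder):

python examples/inference/hybrid_gptq_llama.py -q /path/to/gptq-llama --tp_size 2 --pp_size 2 --max_batch_size 4 --max_input_len 32 --max_output_len 32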
examples/inference/hybrid_smoothquant_llama.py (new file, 76 lines)
@@ -0,0 +1,76 @@
import argparse

import torch
import torch.distributed as dist

import colossalai
from colossalai.inference import CaiInferEngine, LlamaModelInferPolicy
from colossalai.inference.quant.smoothquant.models.llama import SmoothLlamaForCausalLM
from colossalai.logging import disable_existing_loggers
from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn


@torch.no_grad()
def run_llama_test(args):
    quantized_model_dir = args.quantized_path
    max_batch_size = args.max_batch_size
    max_input_len = args.max_input_len
    max_output_len = args.max_output_len
    micro_batch_size = args.micro_batch_size

    def data_gen():
        input_ids = torch.tensor([[15496, 11, 616, 3290, 318, 13779, 318, 13779]], dtype=torch.int64)
        attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64)
        return dict(input_ids=input_ids, attention_mask=attention_mask)

    inputs = data_gen()
    # replicate the single sample along the batch dimension and move it to the GPU
    for k, v in inputs.items():
        if torch.is_tensor(v) or "Tensor" in v.__class__.__name__:
            new_shape = [1] * v.dim()
            new_shape[0] = 16
            inputs[k] = v.to("cuda").repeat(*new_shape)

    # load the SmoothQuant-quantized model onto the GPU
    model = SmoothLlamaForCausalLM.from_quantized(quantized_model_dir, model_basename="llama-7b")
    model = model.cuda()

    engine = CaiInferEngine(
        tp_size=args.tp_size,
        pp_size=args.pp_size,
        model=model,
        model_policy=LlamaModelInferPolicy(),
        max_batch_size=max_batch_size,
        max_input_len=max_input_len,
        max_output_len=max_output_len,
        micro_batch_size=micro_batch_size,
        quant="smoothquant",
    )

    output = engine.inference(inputs)
    if dist.get_rank() == 0:
        assert len(output[0]) == max_output_len, f"{len(output)}, {max_output_len}"


def check_llama(rank, world_size, port, args):
    disable_existing_loggers()
    colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
    run_llama_test(args)


@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_smoothquant_llama(args):
    spawn(check_llama, args.tp_size * args.pp_size, args=args)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("-q", "--quantized_path", type=str, help="Model path", required=True)
    parser.add_argument("--tp_size", type=int, default=2, help="Tensor parallel size")
    parser.add_argument("--pp_size", type=int, default=2, help="Pipeline parallel size")
    parser.add_argument("--max_batch_size", type=int, default=4, help="Maximum batch size")
    parser.add_argument("--micro_batch_size", type=int, default=4, help="Micro batch size")
    parser.add_argument("--max_input_len", type=int, default=32, help="Maximum input length")
    parser.add_argument("--max_output_len", type=int, default=32, help="Maximum output length")

    args = parser.parse_args()
    test_smoothquant_llama(args)
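Similarly, the SmoothQuant example can be launched as follows, assuming the quantized checkpoint directory contains weights saved under the "llama-7b" basename expected by from_quantized (the path below is a placeholder):

python examples/inference/hybrid_smoothquant_llama.py -q /path/to/smoothquant-llama --tp_size 2 --pp_size 2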