ColossalAI/examples/inference/smoothquant_llama.py
Xu Kai 611a5a80ca
[inference] Add smoothquant for llama (#4904)
* [inference] add int8 rotary embedding kernel for smoothquant (#4843)

* [inference] add smoothquant llama attention (#4850)

* add smoothquant llama attention

* remove useless code

* remove useless code

* fix import error

* rename file name

* [inference] add silu linear fusion for smoothquant llama mlp  (#4853)

* add silu linear

* update skip condition

* catch smoothquant cuda lib exception

* process exception for tests

* [inference] add llama mlp for smoothquant (#4854)

* add llama mlp for smoothquant

* fix down out scale

* remove duplicate lines

* add llama mlp check

* delete useless code

* [inference] add smoothquant llama (#4861)

* add smoothquant llama

* fix attention accuracy

* fix accuracy

* add kv cache and save pretrained

* refactor example

* delete smooth

* refactor code

* [inference] add smooth function and delete useless code for smoothquant (#4895)

* add smooth function and delete useless code

* update datasets

* remove duplicate import

* delete useless file

* refactor codes (#4902)

* refactor code

* add license

* add torch-int and smoothquant license
2023-10-16 11:28:44 +08:00

70 lines
2.1 KiB
Python

import argparse
import os
import torch
from datasets import load_dataset
from transformers import LlamaTokenizer
from colossalai.inference.quant.smoothquant.models.llama import SmoothLlamaForCausalLM
def build_model_and_tokenizer(model_name):
    """Load the LLaMA tokenizer and the SmoothQuant-capable LLaMA model.

    Args:
        model_name: name or path passed to ``from_pretrained`` for both the
            tokenizer and the model.

    Returns:
        A ``(model, tokenizer)`` tuple. The model is loaded with
        ``torch_dtype=torch.float16`` and ``device_map="sequential"``, then cast
        to ``torch.float32`` before being returned.
    """
    load_options = dict(torch_dtype=torch.float16, device_map="sequential")
    llama_tokenizer = LlamaTokenizer.from_pretrained(model_name, model_max_length=512)
    smooth_model = SmoothLlamaForCausalLM.from_pretrained(model_name, **load_options)
    # Upcast after loading; presumably so calibration runs in fp32 — TODO confirm intent.
    return smooth_model.to(torch.float32), llama_tokenizer
def parse_args(argv=None):
    """Parse command-line options for the SmoothQuant LLaMA example.

    Args:
        argv: optional list of argument strings. When ``None`` (the default),
            ``sys.argv[1:]`` is parsed, matching the previous behavior.

    Returns:
        An ``argparse.Namespace`` with ``model_name``, ``output_path``,
        ``dataset_path``, ``num_samples`` and ``seq_len``.

    Raises:
        SystemExit: if a required option is missing or a value cannot be parsed
            (standard ``argparse`` behavior).
    """
    parser = argparse.ArgumentParser()
    # The three path-like options are mandatory: the script cannot do anything
    # useful without them, and leaving them unset previously surfaced later as an
    # opaque error (e.g. os.path.exists(None) raises TypeError).
    parser.add_argument("--model-name", type=str, required=True, help="model name")
    parser.add_argument(
        "--output-path",
        type=str,
        required=True,
        help="where to save the checkpoint",
    )
    parser.add_argument(
        "--dataset-path",
        type=str,
        required=True,
        help="location of the calibration dataset",
    )
    parser.add_argument(
        "--num-samples",
        type=int,
        default=512,
        help="number of calibration samples",
    )
    parser.add_argument(
        "--seq-len",
        type=int,
        default=512,
        help="sequence length used during calibration",
    )
    return parser.parse_args(argv)
@torch.no_grad()
def main():
    """Calibrate, quantize, save, reload and smoke-test a SmoothQuant LLaMA model.

    Steps:
      1. Parse CLI options via ``parse_args``.
      2. Load the model and tokenizer, and a JSON calibration dataset.
      3. Run ``model.quantized`` with ``--num-samples`` / ``--seq-len``.
      4. Move the model to CUDA and save it to ``--output-path``.
      5. Reload the checkpoint with ``SmoothLlamaForCausalLM.from_quantized`` and
         print a short ``do_sample=False`` generation as a sanity check.

    Runs entirely under ``torch.no_grad()`` and requires a CUDA device.

    Raises:
        FileNotFoundError: if ``--dataset-path`` is missing or does not exist.
    """
    args = parse_args()
    model_path = args.model_name
    dataset_path = args.dataset_path
    output_path = args.output_path
    # Honor the CLI options; these were previously hard-coded (10 / 512),
    # silently ignoring --num-samples and --seq-len.
    num_samples = args.num_samples
    seq_len = args.seq_len

    # Validate the dataset path before the expensive model load so a bad path
    # fails fast. Also guard against None, which os.path.exists rejects with TypeError.
    if not dataset_path or not os.path.exists(dataset_path):
        raise FileNotFoundError(f"Cannot find the dataset at {dataset_path}")

    model, tokenizer = build_model_and_tokenizer(model_path)
    dataset = load_dataset("json", data_files=dataset_path, split="train")
    model.quantized(tokenizer, dataset, num_samples=num_samples, seq_len=seq_len)
    model = model.cuda()
    model.save_quantized(output_path, model_basename="llama-7b")

    # Reload the checkpoint just written to output_path and run a short,
    # deterministic (do_sample=False) generation as a smoke test.
    model = SmoothLlamaForCausalLM.from_quantized(output_path, model_basename="llama-7b")
    model = model.cuda()
    generate_kwargs = dict(max_new_tokens=16, do_sample=False, use_cache=True)
    input_tokens = tokenizer(["today is "], return_tensors="pt").to("cuda")
    out = model.generate(**input_tokens, **generate_kwargs)
    text = tokenizer.batch_decode(out)
    print("out is:", text)


if __name__ == "__main__":
    main()