[inference] refactor examples and fix schedule (#5077)

* [setup] refactor infer setup

* [hotfix] fix inference behavior on 1x1 gpu

* [example] refactor inference examples
Hongxin Liu
2023-11-21 10:46:03 +08:00
committed by GitHub
parent 4e3959d316
commit 1cd7efc520
9 changed files with 209 additions and 274 deletions


@@ -7,11 +7,17 @@ from transformers import LlamaForCausalLM, LlamaTokenizer
import colossalai
from colossalai.inference import InferenceEngine
from colossalai.testing import spawn
from colossalai.utils.device import get_current_device
INPUT_TEXTS = [
"What is the longest river in the world?",
"Explain the difference between process and thread in compouter science.",
]
def run_inference(args):
llama_model_path = args.model_path
llama_tokenize_path = args.tokenizer_path
llama_tokenize_path = args.tokenizer_path or args.model_path
max_input_len = args.max_input_len
max_output_len = args.max_output_len
@@ -22,11 +28,10 @@ def run_inference(args):
rank = dist.get_rank()
tokenizer = LlamaTokenizer.from_pretrained(llama_tokenize_path, padding_side="left")
tokenizer.pad_token_id = tokenizer.unk_token_id
tokenizer.pad_token_id = tokenizer.eos_token_id
if args.quant is None:
model = LlamaForCausalLM.from_pretrained(llama_model_path, pad_token_id=tokenizer.unk_token_id)
model = model.half()
model = LlamaForCausalLM.from_pretrained(llama_model_path, pad_token_id=tokenizer.pad_token_id)
elif args.quant == "gptq":
from auto_gptq import AutoGPTQForCausalLM
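
For reference, a minimal sketch of the two model-loading branches this hunk touches. Only the fp16 branch is fully visible in the diff; the GPTQ branch below is an assumption, and the from_quantized arguments are illustrative, not copied from the commit.

import torch
from transformers import LlamaForCausalLM


def load_model(model_path, quant, pad_token_id):
    # fp16 branch (as in the new code): dtype handling is left to the engine,
    # and the pad token id mirrors the tokenizer's eos token.
    if quant is None:
        return LlamaForCausalLM.from_pretrained(model_path, pad_token_id=pad_token_id)
    # GPTQ branch (assumed): auto_gptq loads an already-quantized checkpoint.
    if quant == "gptq":
        from auto_gptq import AutoGPTQForCausalLM

        return AutoGPTQForCausalLM.from_quantized(
            model_path, device=torch.cuda.current_device(), inject_fused_attention=False
        )
    raise ValueError(f"unsupported quant type: {quant}")
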
@@ -45,18 +50,21 @@ def run_inference(args):
model=model,
max_input_len=max_input_len,
max_output_len=max_output_len,
max_batch_size=max_batch_size,
micro_batch_size=micro_batch_size,
quant=args.quant,
dtype=args.dtype,
)
input_tokens = {
"input_ids": torch.randint(1, 1000, (max_batch_size, max_input_len), device="cuda"),
"attention_mask": torch.ones((max_batch_size, max_input_len), device="cuda"),
}
inputs = tokenizer(INPUT_TEXTS, return_tensors="pt", padding="longest", max_length=max_input_len, truncation=True)
inputs = {k: v.to(get_current_device()) for k, v in inputs.items()}
outputs = engine.generate(inputs)
outputs = engine.generate(input_tokens)
if rank == 0:
print(tokenizer.batch_decode(outputs))
output_texts = tokenizer.batch_decode(outputs, skip_special_tokens=True)
for input_text, output_text in zip(INPUT_TEXTS, output_texts):
print(f"Input: {input_text}")
print(f"Output: {output_text}")
def run_tp_pipeline_inference(rank, world_size, port, args):
@@ -67,8 +75,8 @@ def run_tp_pipeline_inference(rank, world_size, port, args):
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("-p", "--model_path", type=str, help="Model path", required=True)
parser.add_argument("--tokenizer_path", type=str, help="Tokenizer path", required=True)
parser.add_argument("-i", "--input", default="What is the longest river in the world?")
parser.add_argument("-t", "--tokenizer_path", type=str, help="Tokenizer path", default=None)
parser.add_argument(
"-q",
"--quant",
@@ -78,12 +86,13 @@ if __name__ == "__main__":
help="quantization type: 'gptq' or 'smoothquant'",
)
parser.add_argument("--smoothquant_base_name", type=str, default=None, help="soothquant base name")
parser.add_argument("-tp", "--tp_size", type=int, default=2, help="Tensor parallel size")
parser.add_argument("-pp", "--pp_size", type=int, default=2, help="Pipeline parallel size")
parser.add_argument("--tp_size", type=int, default=1, help="Tensor parallel size")
parser.add_argument("--pp_size", type=int, default=1, help="Pipeline parallel size")
parser.add_argument("-b", "--batch_size", type=int, default=4, help="Maximum batch size")
parser.add_argument("--max_input_len", type=int, default=32, help="Maximum input length")
parser.add_argument("--max_output_len", type=int, default=16, help="Maximum output length")
parser.add_argument("--max_input_len", type=int, default=2048, help="Maximum input length")
parser.add_argument("--max_output_len", type=int, default=64, help="Maximum output length")
parser.add_argument("--micro_batch_size", type=int, default=1, help="Micro batch size")
parser.add_argument("--dtype", default="fp16", type=str)
args = parser.parse_args()
spawn(run_tp_pipeline_inference, nprocs=args.tp_size * args.pp_size, args=args)
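
The hunk above only shows the entry point's signature and the spawn call; the sketch below shows how the pieces fit together. run_inference and args refer to the names defined earlier in the example, and the colossalai.launch arguments are an assumption, since the body of run_tp_pipeline_inference is not visible in this diff.

import colossalai
from colossalai.testing import spawn


def run_tp_pipeline_inference(rank, world_size, port, args):
    # Assumption: initialize the global process group before running inference.
    colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port)
    run_inference(args)


# One process per device: tensor-parallel size times pipeline-parallel size,
# which with the new defaults (--tp_size 1 --pp_size 1) is a single process on a single gpu.
spawn(run_tp_pipeline_inference, nprocs=args.tp_size * args.pp_size, args=args)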