[inference] Adapted to Rotary Embedding and RMS Norm (#5283)

* adapted to rotary_embedding

* adapted to nopad rms norm

* fix bugs in benchmark

* fix flash_decoding.py
This commit is contained in:
yuehuayingxueluo
2024-01-22 10:55:34 +08:00
committed by GitHub
parent 6e487e7d3c
commit bfff9254ac
5 changed files with 140 additions and 43 deletions

View File

@@ -95,10 +95,13 @@ def benchmark_inference(args):
if args.dtype == "fp16":
model = model.half()
elif args.dtype == "fp16":
elif args.dtype == "bf16":
model = model.to(torch.bfloat16)
mbsz = args.mbsz
if args.continous_batching:
mbsz = args.mbsz
else:
mbsz = args.batch_size
if args.mode == "caiinference":
inference_config = InferenceConfig(
dtype=args.dtype,
@@ -205,5 +208,8 @@ if __name__ == "__main__":
choices=["caiinference", "transformers"],
help="decide which inference framework to run",
)
parser.add_argument(
"-cb", "--continous_batching", default=False, action="store_true", help="enable continous batching"
)
args = parser.parse_args()
benchmark(args)