mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-01 17:17:05 +00:00
[inference] Adapted to Rotary Embedding and RMS Norm (#5283)
* adapted to rotary_embedding
* adapted to nopad rms norm
* fix bugs in benchmark
* fix flash_decoding.py
This commit is contained in:
@@ -95,10 +95,13 @@ def benchmark_inference(args):
     if args.dtype == "fp16":
         model = model.half()
-    elif args.dtype == "fp16":
+    elif args.dtype == "bf16":
         model = model.to(torch.bfloat16)

-    mbsz = args.mbsz
+    if args.continous_batching:
+        mbsz = args.mbsz
+    else:
+        mbsz = args.batch_size
     if args.mode == "caiinference":
         inference_config = InferenceConfig(
             dtype=args.dtype,
@@ -205,5 +208,8 @@ if __name__ == "__main__":
         choices=["caiinference", "transformers"],
         help="decide which inference framework to run",
     )
+    parser.add_argument(
+        "-cb", "--continous_batching", default=False, action="store_true", help="enable continous batching"
+    )
     args = parser.parse_args()
     benchmark(args)
Reference in New Issue
Block a user