From e96a0761ea697cc2d372b7b3d41df6b304353859 Mon Sep 17 00:00:00 2001
From: Guangyao Zhang
Date: Thu, 29 Aug 2024 14:49:23 +0800
Subject: [PATCH] [FP8] unsqueeze scale to make it compatible with torch.compile (#6040)

---
 colossalai/quantization/fp8.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/colossalai/quantization/fp8.py b/colossalai/quantization/fp8.py
index 5b606616e..c022fab15 100644
--- a/colossalai/quantization/fp8.py
+++ b/colossalai/quantization/fp8.py
@@ -56,7 +56,7 @@ def cast_to_fp8(inp: torch.Tensor, fp8_format="e4m3", per_channel_scale=False) -
         scale_inv = 1.0 / scale
     ret = (scale * inp.float()).to(fp8_type)
-    return ret, scale_inv
+    return ret, torch.unsqueeze(scale_inv, dim=0)
 
 
 def cast_from_fp8(
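
Context for the one-line change: the patch turns the returned scale_inv from a 0-dim scalar tensor into a rank-1 tensor of shape (1,) via torch.unsqueeze, which the subject line ties to torch.compile compatibility. Below is a minimal, self-contained sketch of a per-tensor FP8 cast with the patched return line; everything except that final line is reconstructed for illustration only (it assumes PyTorch >= 2.1 for the torch.float8_e4m3fn / torch.float8_e5m2 dtypes and is not the exact ColossalAI implementation, which also supports per-channel scaling).

    import torch

    def cast_to_fp8_sketch(inp: torch.Tensor, fp8_format: str = "e4m3"):
        # Assumed FP8 dtypes from PyTorch >= 2.1; pick the dtype and its max range.
        fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
        fp8_max = torch.finfo(fp8_type).max

        # Per-tensor scaling: map the largest |value| onto the FP8 dynamic range.
        per_tensor_max = inp.abs().max().float().clamp(min=1e-12)  # avoid div-by-zero
        scale = fp8_max / per_tensor_max
        scale_inv = 1.0 / scale  # 0-dim scalar tensor at this point

        ret = (scale * inp.float()).to(fp8_type)
        # The patched return: unsqueeze the 0-dim scale_inv to shape (1,) so the
        # returned tensor has a fixed rank, per the patch's torch.compile rationale.
        return ret, torch.unsqueeze(scale_inv, dim=0)

    x = torch.randn(4, 8)
    q, s_inv = cast_to_fp8_sketch(x)
    print(q.dtype, s_inv.shape)  # torch.float8_e4m3fn torch.Size([1])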